mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Update pre-commit to fixed versions, run remod for ck_tile (#2895)
* Fix ruff linter errors
* Fix remod dos2unix command
* Clang format
* Ignore utility in remod
* Run remod
* Specify clang-format version in pre-commit
* Specify ruff version
* Include PoolKernelArgs in reference_pool
* Add calculate_total_elements to reference batched contraction
* Fix calculate_total_elements declaration
* Refactor remod pre-commit hook
* Fix Aquant tests
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
[ROCm/composable_kernel commit: d40b50b9d5]
This commit is contained in:
27
.github/scripts/therock_configure_ci.py
vendored
27
.github/scripts/therock_configure_ci.py
vendored
@@ -6,6 +6,7 @@ import subprocess
|
||||
import sys
|
||||
from typing import Iterable, Optional, Mapping
|
||||
|
||||
|
||||
def gha_set_output(vars: Mapping[str, str | Path]):
|
||||
"""Sets values in a step's output parameters.
|
||||
|
||||
@@ -25,6 +26,7 @@ def gha_set_output(vars: Mapping[str, str | Path]):
|
||||
with open(step_output_file, "a") as f:
|
||||
f.writelines(f"{k}={str(v)}" + "\n" for k, v in vars.items())
|
||||
|
||||
|
||||
def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]:
|
||||
"""Returns the paths of modified files relative to the base reference."""
|
||||
try:
|
||||
@@ -42,11 +44,13 @@ def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]:
|
||||
file=sys.stderr,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
GITHUB_WORKFLOWS_CI_PATTERNS = [
|
||||
"therock*",
|
||||
]
|
||||
|
||||
|
||||
def is_path_workflow_file_related_to_ci(path: str) -> bool:
|
||||
return any(
|
||||
fnmatch.fnmatch(path, ".github/workflows/" + pattern)
|
||||
@@ -56,11 +60,13 @@ def is_path_workflow_file_related_to_ci(path: str) -> bool:
|
||||
for pattern in GITHUB_WORKFLOWS_CI_PATTERNS
|
||||
)
|
||||
|
||||
|
||||
def check_for_workflow_file_related_to_ci(paths: Optional[Iterable[str]]) -> bool:
|
||||
if paths is None:
|
||||
return False
|
||||
return any(is_path_workflow_file_related_to_ci(p) for p in paths)
|
||||
|
||||
|
||||
# Paths matching any of these patterns are considered to have no influence over
|
||||
# build or test workflows so any related jobs can be skipped if all paths
|
||||
# modified by a commit/PR match a pattern in this list.
|
||||
@@ -70,23 +76,26 @@ SKIPPABLE_PATH_PATTERNS = [
|
||||
"*.md",
|
||||
"*.pre-commit-config.*",
|
||||
"*LICENSE",
|
||||
'Jenkinsfile',
|
||||
'.github/ISSUE_TEMPLATE/*',
|
||||
'.github/CODEOWNERS',
|
||||
'.github/*.md',
|
||||
'.github/dependabot.yml',
|
||||
"Jenkinsfile",
|
||||
".github/ISSUE_TEMPLATE/*",
|
||||
".github/CODEOWNERS",
|
||||
".github/*.md",
|
||||
".github/dependabot.yml",
|
||||
]
|
||||
|
||||
|
||||
def is_path_skippable(path: str) -> bool:
|
||||
"""Determines if a given relative path to a file matches any skippable patterns."""
|
||||
return any(fnmatch.fnmatch(path, pattern) for pattern in SKIPPABLE_PATH_PATTERNS)
|
||||
|
||||
|
||||
def check_for_non_skippable_path(paths: Optional[Iterable[str]]) -> bool:
|
||||
"""Returns true if at least one path is not in the skippable set."""
|
||||
if paths is None:
|
||||
return False
|
||||
return any(not is_path_skippable(p) for p in paths)
|
||||
|
||||
|
||||
def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool:
|
||||
"""Returns true if CI workflows should run given a list of modified paths."""
|
||||
|
||||
@@ -118,16 +127,16 @@ def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool:
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def main(args):
|
||||
base_ref = args.get("base_ref")
|
||||
modified_paths = get_modified_paths(base_ref)
|
||||
print("modified_paths (max 200):", modified_paths[:200])
|
||||
enable_jobs = should_ci_run_given_modified_paths(modified_paths)
|
||||
output = {
|
||||
'enable_therock_ci': json.dumps(enable_jobs)
|
||||
}
|
||||
output = {"enable_therock_ci": json.dumps(enable_jobs)}
|
||||
gha_set_output(output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = {}
|
||||
args["base_ref"] = os.environ.get("BASE_REF", "HEAD^1")
|
||||
|
||||
@@ -1,11 +1,25 @@
|
||||
repos:
|
||||
- repo: local
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v18.1.3
|
||||
hooks:
|
||||
- id: clang-format
|
||||
name: clang-format
|
||||
entry: clang-format-18 -i --style=file
|
||||
language: system
|
||||
types_or: [c++, inc]
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.0
|
||||
hooks:
|
||||
- id: ruff-check
|
||||
args: [ --fix ]
|
||||
exclude: |
|
||||
(?x)^(
|
||||
docs/conf.py
|
||||
)$
|
||||
- id: ruff-format
|
||||
exclude: |
|
||||
(?x)^(
|
||||
docs/conf.py
|
||||
)$
|
||||
- repo: local
|
||||
hooks:
|
||||
# - id: copyright-year-checker
|
||||
# name: copyright-year-checker
|
||||
# entry: script/check_copyright_year.sh
|
||||
@@ -18,21 +32,9 @@ repos:
|
||||
language: script
|
||||
types_or: [c++, text]
|
||||
verbose: true
|
||||
- id: ruff-check
|
||||
name: Ruff Linter
|
||||
entry: ruff check --fix
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: [ruff]
|
||||
- id: ruff-format
|
||||
name: Ruff Formatter
|
||||
entry: ruff format
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: [ruff]
|
||||
- id: run-remod-if-ck-tile-changed
|
||||
name: Run remod.py if ck_tile files changed
|
||||
entry: script/remod_for_ck_tile.sh
|
||||
language: script
|
||||
always_run: true
|
||||
files: '^(include|example)/ck_tile/.*$'
|
||||
pass_filenames: false
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
GEN_DIR = "" # in Cmake, have to generate files in same folder
|
||||
GEN_DIR = "" # in Cmake, have to generate files in same folder
|
||||
|
||||
@@ -3,38 +3,35 @@
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
FWD_DTYPE_MAP = {
|
||||
"fp32" : "FmhaFwdFp32",
|
||||
"fp16" : "FmhaFwdFp16",
|
||||
"bf16" : "FmhaFwdBf16",
|
||||
"fp8" : "FmhaFwdFp8",
|
||||
"fp32": "FmhaFwdFp32",
|
||||
"fp16": "FmhaFwdFp16",
|
||||
"bf16": "FmhaFwdBf16",
|
||||
"fp8": "FmhaFwdFp8",
|
||||
"fp8fp16": "FmhaFwdFp8Fp16",
|
||||
"fp8bf16": "FmhaFwdFp8Bf16",
|
||||
"fp8fp32": "FmhaFwdFp8Fp32"
|
||||
"fp8fp32": "FmhaFwdFp8Fp32",
|
||||
}
|
||||
|
||||
BWD_DTYPE_MAP = {
|
||||
"fp32": "FmhaBwdFp32",
|
||||
"fp16": "FmhaBwdFp16",
|
||||
"bf16": "FmhaBwdBf16"
|
||||
}
|
||||
BWD_DTYPE_MAP = {"fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"}
|
||||
|
||||
MASK_IMPL = {
|
||||
"generic" : "ck_tile::GenericAttentionMask",
|
||||
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
|
||||
"generic": "ck_tile::GenericAttentionMask",
|
||||
"simplified": "ck_tile::SimplifiedGenericAttentionMask",
|
||||
}
|
||||
|
||||
_MASK_SIMPLIFIED_MAP = {
|
||||
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
|
||||
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
|
||||
"s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
|
||||
"s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
|
||||
}
|
||||
|
||||
_MASK_MAP = {
|
||||
"no" : "FmhaMasks::NoMask",
|
||||
"causal" : "FmhaMasks::CausalMask",
|
||||
"generic" : "FmhaMasks::GenericMask"
|
||||
"no": "FmhaMasks::NoMask",
|
||||
"causal": "FmhaMasks::CausalMask",
|
||||
"generic": "FmhaMasks::GenericMask",
|
||||
}
|
||||
|
||||
def get_mask_map(mask : str):
|
||||
|
||||
def get_mask_map(mask: str):
|
||||
if mask == "generic":
|
||||
return _MASK_MAP
|
||||
elif mask == "simplified":
|
||||
@@ -43,18 +40,20 @@ def get_mask_map(mask : str):
|
||||
assert False
|
||||
return None
|
||||
|
||||
|
||||
_MASK_CHECK_MAP = {
|
||||
"no" : "t.mask_type == mask_enum::no_mask",
|
||||
"causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
|
||||
"generic" : "t.mask_type == mask_enum::window_generic",
|
||||
"no": "t.mask_type == mask_enum::no_mask",
|
||||
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
|
||||
"generic": "t.mask_type == mask_enum::window_generic",
|
||||
}
|
||||
|
||||
_MASK_SIMPLIFIED_CHECK_MAP = {
|
||||
"s_no" : "t.mask_type == mask_enum::no_mask",
|
||||
"s_mask" : "t.mask_type != mask_enum::no_mask",
|
||||
"s_no": "t.mask_type == mask_enum::no_mask",
|
||||
"s_mask": "t.mask_type != mask_enum::no_mask",
|
||||
}
|
||||
|
||||
def get_mask_check_map(mask : str):
|
||||
|
||||
def get_mask_check_map(mask: str):
|
||||
if mask == "generic":
|
||||
return _MASK_CHECK_MAP
|
||||
elif mask == "simplified":
|
||||
@@ -63,76 +62,71 @@ def get_mask_check_map(mask : str):
|
||||
assert False
|
||||
return None
|
||||
|
||||
|
||||
BIAS_MAP = {
|
||||
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
|
||||
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
|
||||
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
|
||||
"no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
|
||||
"bias": "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
|
||||
"alibi": "ck_tile::BlockAttentionBiasEnum::ALIBI",
|
||||
}
|
||||
|
||||
# TODO: this is ugly
|
||||
BIAS_CHECK_MAP = {
|
||||
"no" : "bias_enum::no_bias",
|
||||
"bias" : "bias_enum::elementwise_bias",
|
||||
"alibi" : "bias_enum::alibi"
|
||||
"no": "bias_enum::no_bias",
|
||||
"bias": "bias_enum::elementwise_bias",
|
||||
"alibi": "bias_enum::alibi",
|
||||
}
|
||||
|
||||
DROPOUT_MAP = {
|
||||
"no" : "ck_tile::BlockDropoutBwd<false, true, false>",
|
||||
"dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
|
||||
"dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
|
||||
"dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
|
||||
"dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
|
||||
"no": "ck_tile::BlockDropoutBwd<false, true, false>",
|
||||
"dropout_wg32": "ck_tile::BlockDropoutBwd<true, true, false>",
|
||||
"dropout_wg32_storerandval": "ck_tile::BlockDropoutBwd<true, true, true >",
|
||||
"dropout_wg16": "ck_tile::BlockDropoutBwd<true, false, false>",
|
||||
"dropout_wg16_storerandval": "ck_tile::BlockDropoutBwd<true, false, true >",
|
||||
}
|
||||
|
||||
DROPOUT_CHECK_MAP = {
|
||||
"no" : "t.has_dropout == false",
|
||||
"dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false",
|
||||
"dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
|
||||
"dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false",
|
||||
"dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
|
||||
"no": "t.has_dropout == false",
|
||||
"dropout_wg32": "t.has_dropout == true && t.is_store_randval == false",
|
||||
"dropout_wg32_storerandval": "t.has_dropout == true && t.is_store_randval == true",
|
||||
"dropout_wg16": "t.has_dropout == true && t.is_store_randval == false",
|
||||
"dropout_wg16_storerandval": "t.has_dropout == true && t.is_store_randval == true",
|
||||
}
|
||||
|
||||
ROPE_MAP = {
|
||||
"no" : "ck_tile::RotaryEmbeddingEnum::NONE",
|
||||
"inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
|
||||
"half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
|
||||
"no": "ck_tile::RotaryEmbeddingEnum::NONE",
|
||||
"inter": "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
|
||||
"half": "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED",
|
||||
}
|
||||
|
||||
ROPE_CHECK_MAP = {
|
||||
"no" : "rope_enum::none",
|
||||
"inter" : "rope_enum::interleaved",
|
||||
"half" : "rope_enum::half_rotated"
|
||||
"no": "rope_enum::none",
|
||||
"inter": "rope_enum::interleaved",
|
||||
"half": "rope_enum::half_rotated",
|
||||
}
|
||||
|
||||
MODE_MAP = {
|
||||
"batch" : "false",
|
||||
"group" : "true"
|
||||
}
|
||||
MODE_MAP = {"batch": "false", "group": "true"}
|
||||
|
||||
LAYOUT_MAP = {
|
||||
"row" : "true",
|
||||
"col" : "false"
|
||||
}
|
||||
LAYOUT_MAP = {"row": "true", "col": "false"}
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
|
||||
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
"qs" : "ck_tile::BlockFmhaPipelineQSKSVS",
|
||||
"qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
|
||||
"qr": "ck_tile::BlockFmhaPipelineQRKSVS",
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
|
||||
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
|
||||
}
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
|
||||
"qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
|
||||
"qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
"qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
|
||||
"qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
"t" : "true",
|
||||
"f" : "false",
|
||||
True : "true",
|
||||
False : "false",
|
||||
"t": "true",
|
||||
"f": "false",
|
||||
True: "true",
|
||||
False: "false",
|
||||
}
|
||||
|
||||
@@ -9,28 +9,26 @@ import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
from codegen.cmake_config import GEN_DIR
|
||||
from codegen.cpp_symbol_map import (
|
||||
MODE_MAP,
|
||||
LAYOUT_MAP,
|
||||
BIAS_CHECK_MAP,
|
||||
get_mask_check_map,
|
||||
get_mask_map,
|
||||
BIAS_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
BOOL_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
)
|
||||
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8" : 8,
|
||||
"bf8" : 8
|
||||
}
|
||||
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32 : 32,
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
256: 256
|
||||
}
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
|
||||
"qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
|
||||
"qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
|
||||
}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
@@ -40,7 +38,7 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY="""
|
||||
FMHA_FWD_KERNEL_BODY = """
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
@@ -116,8 +114,8 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
#include <cstdio>
|
||||
|
||||
namespace {{
|
||||
@@ -167,173 +165,223 @@ float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a,
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return 'true'
|
||||
return "true"
|
||||
else:
|
||||
return f'{self.bool_expr}'
|
||||
return f"{self.bool_expr}"
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f'({str(self)}) && ({str(other)})')
|
||||
return CppConstraint(f"({str(self)}) && ({str(other)})")
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag : str
|
||||
pipeline_tag: str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
dropout : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
constraint : CppConstraint
|
||||
hdim: str
|
||||
dtype: str # data type
|
||||
mode: str # value from MODE_MAP
|
||||
bm0: int # tile size along q seqlen (block size)
|
||||
bn0: int # tile size along qk seqlen
|
||||
bk0: int # tile size along qk gemm unroll
|
||||
bn1: int # tile size along v head_dim
|
||||
bk1: int # tile size along kv gemm unroll
|
||||
bk0max: int
|
||||
vlayout: str
|
||||
logits: str
|
||||
mask: str
|
||||
bias: str #
|
||||
lse: str #
|
||||
dropout: str
|
||||
squant: str #
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
constraint: CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
|
||||
)
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.spad == 't' : return 'true' # always support
|
||||
else : return 'true'
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
if self.mode == "group":
|
||||
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == "qr_async":
|
||||
if self.spad == "t":
|
||||
return "true" # always support
|
||||
else:
|
||||
return "true"
|
||||
elif self.pipeline_tag in ["qr"]:
|
||||
if self.spad == "t":
|
||||
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.seqlen_q % {self.bm0} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
|
||||
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.pipeline_tag in ['qr', 'qr_fp8']:
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_k % {self.bn0} == 0'
|
||||
else: assert False
|
||||
if self.mode == "group":
|
||||
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == "qr_async":
|
||||
if self.skpad == "t":
|
||||
return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
|
||||
else:
|
||||
return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
|
||||
elif self.pipeline_tag in ["qr", "qr_fp8"]:
|
||||
if self.skpad == "t":
|
||||
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.seqlen_k % {self.bn0} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.pipeline_tag == "qr_async":
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
if self.dpad == "t":
|
||||
return f"a.hdim_q % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {bk0submax} == 0'
|
||||
else: assert False
|
||||
if self.dpad == "t":
|
||||
return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.hdim_q % {bk0submax} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.pipeline_tag == "qr_async":
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
if self.dvpad == "t":
|
||||
return f"a.hdim_v % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {bk0submax} == 0'
|
||||
else: assert False
|
||||
if self.dvpad == "t":
|
||||
return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.hdim_v % {bk0submax} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag : str
|
||||
tag: str
|
||||
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
F_vlayout: str # row/col
|
||||
F_spad: str # true/false
|
||||
F_skpad: str #
|
||||
F_dpad: str #
|
||||
F_dvpad: str #
|
||||
F_logits: str # t/f
|
||||
F_bias: str # true/false
|
||||
F_lse: str #
|
||||
F_dropout: str #
|
||||
F_squant: str #
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_skpad == 't' : n += 'sk'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
n = ""
|
||||
if self.F_spad == "t":
|
||||
n += "s"
|
||||
if self.F_skpad == "t":
|
||||
n += "sk"
|
||||
if self.F_dpad == "t":
|
||||
n += "d"
|
||||
if self.F_dvpad == "t":
|
||||
n += "dv"
|
||||
if n != "":
|
||||
n = "p" + n
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}_v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_logits == 't' : n += '_logits'
|
||||
else: n += '_nlogits'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else: n += '_nmask'
|
||||
n = f"{self.tag}_v{self.F_vlayout[0]}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
else: n += '_nmask'
|
||||
n += "_npad"
|
||||
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
else: n += '_nlse'
|
||||
if self.F_logits == "t":
|
||||
n += "_logits"
|
||||
else:
|
||||
n += "_nlogits"
|
||||
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
else: n += '_ndropout'
|
||||
if self.F_bias != "no":
|
||||
n += f"_{self.F_bias}"
|
||||
else:
|
||||
n += "_nbias"
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
if self.F_mask[0:2] == "s_":
|
||||
if self.F_mask == "s_mask":
|
||||
n += "_mask"
|
||||
else:
|
||||
n += "_nmask"
|
||||
else:
|
||||
if self.F_mask != "no":
|
||||
n += f"_m{self.F_mask[0]}"
|
||||
else:
|
||||
n += "_nmask"
|
||||
|
||||
if self.F_lse == "t":
|
||||
n += "_lse"
|
||||
else:
|
||||
n += "_nlse"
|
||||
|
||||
if self.F_dropout == "t":
|
||||
n += "_dropout"
|
||||
else:
|
||||
n += "_ndropout"
|
||||
|
||||
if self.F_squant == "t":
|
||||
n += "_squant"
|
||||
else:
|
||||
n += "_nsquant"
|
||||
return n
|
||||
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
|
||||
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
@@ -344,118 +392,152 @@ class FmhaFwdApiPool:
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
per_dtypes = str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
per_hdim_case = str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
inners=str()
|
||||
traits = self.pool[dtype][hdim]
|
||||
inners = str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if_k = "if" if k == 0 else "else if"
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
|
||||
F_if=if_k,
|
||||
F_mode=MODE_MAP[trait.mode],
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
|
||||
F_logits=BOOL_MAP[trait.logits],
|
||||
F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
|
||||
F_bias_check=BIAS_CHECK_MAP[trait.bias],
|
||||
F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse],
|
||||
F_dropout=BOOL_MAP[trait.dropout],
|
||||
F_squant=BOOL_MAP[trait.squant],
|
||||
F_scheck=trait.scheck,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0,
|
||||
F_bn0=trait.bn0,
|
||||
F_bk0=trait.bk0,
|
||||
F_bn1=trait.bn1,
|
||||
F_bk1=trait.bk1,
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners
|
||||
)
|
||||
if_i = "if" if i == 0 else "else if"
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
|
||||
)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
|
||||
per_dtypes += " (void)t ; (void)s ; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
F_bm0: int # tile size along q seqlen (block size)
|
||||
F_bn0: int # tile size along k seqlen
|
||||
F_bk0: int # tile size along qk gemm unroll
|
||||
F_bn1: int # tile size along v head_dim
|
||||
F_bk1: int # tile size along kv gemm unroll
|
||||
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0: int # number of warps for gemm0 along q seqlen
|
||||
F_rn0: int # number of warps for gemm0 along k seqlen
|
||||
F_rk0: int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1: int # number of warps for gemm1 along q seqlen
|
||||
F_rn1: int # number of warps for gemm1 along head dim v
|
||||
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0: int # gemm0 warp size along m
|
||||
F_wn0: int # gemm0 warp size along n
|
||||
F_wk0: int # gemm0 warp size along k
|
||||
F_wm1: int # gemm1 warp size along m
|
||||
F_wn1: int # gemm1 warp size along n
|
||||
F_wk1: int # gemm1 warp size along k
|
||||
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
|
||||
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
return (
|
||||
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
|
||||
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
|
||||
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
|
||||
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdTileSize
|
||||
F_pipeline : FmhaFwdPipeline
|
||||
mask_impl : str
|
||||
F_idx: int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_mode: str # value from MODE_MAP
|
||||
F_tile: FmhaFwdTileSize
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0max = self.F_tile.F_bk0max,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_wm0 = self.F_tile.F_wm0,
|
||||
F_wn0 = self.F_tile.F_wn0,
|
||||
F_wk0 = self.F_tile.F_wk0,
|
||||
F_wm1 = self.F_tile.F_wm1,
|
||||
F_wn1 = self.F_tile.F_wn1,
|
||||
F_wk1 = self.F_tile.F_wk1,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag])
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx=self.F_idx,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0=self.F_tile.F_bm0,
|
||||
F_bn0=self.F_tile.F_bn0,
|
||||
F_bk0=self.F_tile.F_bk0,
|
||||
F_bn1=self.F_tile.F_bn1,
|
||||
F_bk1=self.F_tile.F_bk1,
|
||||
F_bk0max=self.F_tile.F_bk0max,
|
||||
F_rm0=self.F_tile.F_rm0,
|
||||
F_rn0=self.F_tile.F_rn0,
|
||||
F_rk0=self.F_tile.F_rk0,
|
||||
F_rm1=self.F_tile.F_rm1,
|
||||
F_rn1=self.F_tile.F_rn1,
|
||||
F_rk1=self.F_tile.F_rk1,
|
||||
F_wm0=self.F_tile.F_wm0,
|
||||
F_wn0=self.F_tile.F_wn0,
|
||||
F_wk0=self.F_tile.F_wk0,
|
||||
F_wm1=self.F_tile.F_wm1,
|
||||
F_wn1=self.F_tile.F_wn1,
|
||||
F_wk1=self.F_tile.F_wk1,
|
||||
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits=BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias=BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse=BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
return (
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
@@ -463,35 +545,59 @@ class FmhaFwdKernel:
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
)
|
||||
|
||||
|
||||
class KernelComponentFactory:
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
return {
|
||||
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
128: [
|
||||
FmhaFwdTileSize(
|
||||
128,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
32,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
-1,
|
||||
)
|
||||
],
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -502,28 +608,94 @@ class KernelComponentFactory:
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
squant = "t" if dtype == "fp8" else "f"
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
):
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_async",
|
||||
"row",
|
||||
"t",
|
||||
"f",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_async",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
"t",
|
||||
logits,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
squant,
|
||||
mask,
|
||||
)
|
||||
)
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
if 128 in result.keys():
|
||||
result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')))
|
||||
result[128].insert(
|
||||
0,
|
||||
FmhaFwdTileSize(
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
64,
|
||||
128,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
-1,
|
||||
CppConstraint(
|
||||
"get_num_blocks(128) < num_cus * min_cu_util_rate"
|
||||
),
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
|
||||
def get_fwd_blobs(
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
|
||||
@@ -532,30 +704,41 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = CustomFactory.get_hdim_tile_size_dict(dtype)
|
||||
if d == None:
|
||||
if d is None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
|
||||
for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)):
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)
|
||||
):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
|
||||
if (
|
||||
pipeline.F_bias != "no"
|
||||
or pipeline.F_lse == "t"
|
||||
or pipeline.F_dropout == "t"
|
||||
):
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
if not (
|
||||
(pipeline.F_logits == "t" and pipeline.F_bias == "no")
|
||||
or pipeline.F_logits == "f"
|
||||
):
|
||||
continue
|
||||
k = FmhaFwdKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != '':
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
@@ -563,48 +746,48 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_squant == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_squant == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_squant == "f"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == 'fp32'
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
@@ -613,20 +796,28 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
|
||||
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
|
||||
def write_blobs(
|
||||
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
|
||||
def list_blobs(
|
||||
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -5,23 +5,27 @@
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
from codegen.cmake_config import GEN_DIR
|
||||
from codegen.cpp_symbol_map import (
|
||||
FWD_DTYPE_MAP,
|
||||
BOOL_MAP,
|
||||
ROPE_MAP,
|
||||
LAYOUT_MAP,
|
||||
ROPE_CHECK_MAP,
|
||||
)
|
||||
|
||||
from codegen.ops.fmha_fwd import (
|
||||
FmhaFwdApiTrait,
|
||||
DTYPE_BITS,
|
||||
FMHA_FWD_KERNEL_HEADER,
|
||||
FMHA_FWD_API_PER_DTYPE,
|
||||
FMHA_FWD_API_PER_HDIM_CASE,
|
||||
)
|
||||
|
||||
|
||||
FMHA_FWD_APPENDKV_KERNEL_BODY="""
|
||||
FMHA_FWD_APPENDKV_KERNEL_BODY = """
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
|
||||
@@ -66,8 +70,8 @@ float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fw
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_APPENDKV_API_FILENAME="fmha_fwd_appendkv_api.cpp"
|
||||
FMHA_FWD_APPENDKV_API="""
|
||||
FMHA_FWD_APPENDKV_API_FILENAME = "fmha_fwd_appendkv_api.cpp"
|
||||
FMHA_FWD_APPENDKV_API = """
|
||||
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
@@ -75,7 +79,7 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
|
||||
FMHA_FWD_APPENDKV_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
|
||||
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
|
||||
@@ -83,81 +87,101 @@ FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVApiTrait:
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
bs : int # tile size along q seqlen
|
||||
bsk : int # tile size along k seqlen
|
||||
bd : int # tile size along qk gemm unroll
|
||||
bdv : int # tile size along kv gemm unroll
|
||||
vlayout : str
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
rope : str # key from ROPE_MAP
|
||||
pagedkv : str
|
||||
hdim: str
|
||||
dtype: str # data type
|
||||
bs: int # tile size along q seqlen
|
||||
bsk: int # tile size along k seqlen
|
||||
bd: int # tile size along qk gemm unroll
|
||||
bdv: int # tile size along kv gemm unroll
|
||||
vlayout: str
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
rope: str # key from ROPE_MAP
|
||||
pagedkv: str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\
|
||||
f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}'
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-"
|
||||
+ f"{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}"
|
||||
)
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/'
|
||||
else : return f'a.seqlen_q % {self.bs} == 0'
|
||||
if self.spad == "t":
|
||||
return f"true /*a.seqlen_q % {self.bs} != 0*/"
|
||||
else:
|
||||
return f"a.seqlen_q % {self.bs} == 0"
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
# we do not check all the values in a.seqlen_k_ptr
|
||||
return 'true'
|
||||
return "true"
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {self.bd} == 0'
|
||||
if self.dpad == "t":
|
||||
return f"true /*a.hdim_q % {self.bd} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.hdim_q % {self.bd} == 0"
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {self.bdv} == 0'
|
||||
if self.dvpad == "t":
|
||||
return f"true /*a.hdim_v % {self.bdv} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.hdim_v % {self.bdv} == 0"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVPipeline:
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_rope : str # key from ROPE_MAP
|
||||
F_pagedkv : str # t/f
|
||||
F_vlayout: str # row/col
|
||||
F_spad: str # true/false
|
||||
F_skpad: str #
|
||||
F_dpad: str #
|
||||
F_dvpad: str #
|
||||
F_rope: str # key from ROPE_MAP
|
||||
F_pagedkv: str # t/f
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_skpad == 't' : n += 'sk'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
n = ""
|
||||
if self.F_spad == "t":
|
||||
n += "s"
|
||||
if self.F_skpad == "t":
|
||||
n += "sk"
|
||||
if self.F_dpad == "t":
|
||||
n += "d"
|
||||
if self.F_dvpad == "t":
|
||||
n += "dv"
|
||||
if n != "":
|
||||
n = "p" + n
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f'v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
if self.F_rope != 'no': n += f'_{self.F_rope}'
|
||||
if self.F_pagedkv == 't': n += '_pagedkv'
|
||||
n = f"v{self.F_vlayout[0]}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
if self.F_rope != "no":
|
||||
n += f"_{self.F_rope}"
|
||||
if self.F_pagedkv == "t":
|
||||
n += "_pagedkv"
|
||||
return n
|
||||
|
||||
|
||||
class FmhaFwdAppendKVApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
|
||||
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
@@ -168,74 +192,104 @@ class FmhaFwdAppendKVApiPool:
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
per_dtypes = str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
per_hdim_case = str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
inners=str()
|
||||
traits = self.pool[dtype][hdim]
|
||||
inners = str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if_k = "if" if k == 0 else "else if"
|
||||
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(
|
||||
F_if=if_k,
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_scheck=trait.scheck,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_rope_check=ROPE_CHECK_MAP[trait.rope],
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_rope=ROPE_MAP[trait.rope],
|
||||
F_bs=trait.bs,
|
||||
F_bsk=trait.bsk,
|
||||
F_bd=trait.bd,
|
||||
F_bdv=trait.bdv,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners
|
||||
)
|
||||
if_i = "if" if i == 0 else "else if"
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
|
||||
)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)
|
||||
per_dtypes += " (void)t ; (void)s ; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(
|
||||
F_dispatch=per_dtypes
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVTileSize:
|
||||
F_bs : int # tile size along q seqlen
|
||||
F_bsk : int # tile size along k seqlen
|
||||
F_bd : int # tile size along qk gemm unroll
|
||||
F_bdv : int # tile size along kv gemm unroll
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_bs: int # tile size along q seqlen
|
||||
F_bsk: int # tile size along k seqlen
|
||||
F_bd: int # tile size along qk gemm unroll
|
||||
F_bdv: int # tile size along kv gemm unroll
|
||||
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\
|
||||
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" + (
|
||||
"" if self.F_occupancy == -1 else f"_o{self.F_occupancy}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_tile : FmhaFwdAppendKVTileSize
|
||||
F_pipeline : FmhaFwdAppendKVPipeline
|
||||
mask_impl : str
|
||||
F_idx: int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_tile: FmhaFwdAppendKVTileSize
|
||||
F_pipeline: FmhaFwdAppendKVPipeline
|
||||
mask_impl: str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_APPENDKV_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bs = self.F_tile.F_bs,
|
||||
F_bsk = self.F_tile.F_bsk,
|
||||
F_bd = self.F_tile.F_bd,
|
||||
F_bdv = self.F_tile.F_bdv,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_rope = ROPE_MAP[self.F_pipeline.F_rope],
|
||||
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_occupancy = self.F_tile.F_occupancy)
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_KERNEL_BODY.format(
|
||||
F_idx=self.F_idx,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bs=self.F_tile.F_bs,
|
||||
F_bsk=self.F_tile.F_bsk,
|
||||
F_bd=self.F_tile.F_bd,
|
||||
F_bdv=self.F_tile.F_bdv,
|
||||
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_rope=ROPE_MAP[self.F_pipeline.F_rope],
|
||||
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
return (
|
||||
f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
@@ -243,40 +297,45 @@ class FmhaFwdAppendKVKernel:
|
||||
|
||||
def api_trait(self) -> FmhaFwdAppendKVApiTrait:
|
||||
return FmhaFwdAppendKVApiTrait(
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
bs=self.F_tile.F_bs,
|
||||
bsk=self.F_tile.F_bsk,
|
||||
bd=self.F_tile.F_bd,
|
||||
bdv=self.F_tile.F_bdv,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
rope=self.F_pipeline.F_rope,
|
||||
pagedkv=self.F_pipeline.F_pagedkv)
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
bs=self.F_tile.F_bs,
|
||||
bsk=self.F_tile.F_bsk,
|
||||
bd=self.F_tile.F_bd,
|
||||
bdv=self.F_tile.F_bdv,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
rope=self.F_pipeline.F_rope,
|
||||
pagedkv=self.F_pipeline.F_pagedkv,
|
||||
)
|
||||
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
return {
|
||||
'32' : FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
|
||||
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
|
||||
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
|
||||
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
|
||||
"32": FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
|
||||
"64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
|
||||
"128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
|
||||
"256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
elif dtype == "fp8" or dtype == "bf8":
|
||||
return {
|
||||
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
|
||||
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
|
||||
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1)
|
||||
"64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
|
||||
"128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
|
||||
"256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
|
||||
|
||||
def get_fwd_appendkv_blobs(
|
||||
kernel_filter: Optional[str], receipt, mask_impl, optdim_list
|
||||
) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
|
||||
@@ -284,25 +343,50 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
# NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
|
||||
# applying rotary embedding, so I just use 't' in inter/half pipelines
|
||||
for vlayout in ['row', 'col']:
|
||||
for vlayout in ["row", "col"]:
|
||||
for pagedkv in ["t", "f"]:
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv))
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv))
|
||||
pipelines.append(
|
||||
FmhaFwdAppendKVPipeline(
|
||||
vlayout, "f", "t", "f", "f", "no", pagedkv
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdAppendKVPipeline(
|
||||
vlayout, "t", "t", "t", "t", "no", pagedkv
|
||||
)
|
||||
)
|
||||
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv))
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv))
|
||||
pipelines.append(
|
||||
FmhaFwdAppendKVPipeline(
|
||||
vlayout, "f", "t", "t", "f", "inter", pagedkv
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdAppendKVPipeline(
|
||||
vlayout, "t", "t", "t", "t", "inter", pagedkv
|
||||
)
|
||||
)
|
||||
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv))
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv))
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
pipelines.append(
|
||||
FmhaFwdAppendKVPipeline(
|
||||
vlayout, "f", "t", "t", "f", "half", pagedkv
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdAppendKVPipeline(
|
||||
vlayout, "t", "t", "t", "t", "half", pagedkv
|
||||
)
|
||||
)
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
# rope/paged-kv is not supported
|
||||
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
pipelines.append(
|
||||
FmhaFwdAppendKVPipeline("col", "t", "t", "t", "t", "no", "f")
|
||||
)
|
||||
elif dtype in ["fp8fp16", "fp8bf16"]:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
@@ -314,19 +398,21 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
if d is None:
|
||||
continue
|
||||
for hdim_str in d.keys():
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
k = FmhaFwdAppendKVKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != '':
|
||||
k = FmhaFwdAppendKVKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
@@ -334,20 +420,20 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == 'fp32'
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
@@ -356,21 +442,33 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
|
||||
|
||||
def write_fwd_appendkv_api(api_pool: FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None:
|
||||
api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list)
|
||||
|
||||
def write_blobs(
|
||||
output_dir: Path, kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_appendkv_blobs(
|
||||
kernel_filter, receipt, mask_impl, optdim_list
|
||||
)
|
||||
for kernel in kernels:
|
||||
write_single_kernel(kernel, output_dir)
|
||||
write_fwd_appendkv_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list)
|
||||
|
||||
def list_blobs(
|
||||
file_path: Path, kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_appendkv_blobs(
|
||||
kernel_filter, receipt, mask_impl, optdim_list
|
||||
)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,28 +9,26 @@ import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
from codegen.cmake_config import GEN_DIR
|
||||
from codegen.cpp_symbol_map import (
|
||||
LAYOUT_MAP,
|
||||
BIAS_CHECK_MAP,
|
||||
get_mask_check_map,
|
||||
MODE_MAP,
|
||||
get_mask_map,
|
||||
BIAS_MAP,
|
||||
FWD_DTYPE_MAP,
|
||||
BOOL_MAP,
|
||||
PIPELINE_ENUM_MAP,
|
||||
)
|
||||
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8" : 8,
|
||||
"bf8" : 8
|
||||
}
|
||||
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32 : 32,
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
256: 256
|
||||
}
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
FMHA_FWD_PAGEDKV_PIPELINE_MAP = {
|
||||
"qr_pagedkv" : "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS"
|
||||
"qr_pagedkv": "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS"
|
||||
}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
@@ -40,7 +38,7 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY="""
|
||||
FMHA_FWD_KERNEL_BODY = """
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
@@ -115,8 +113,8 @@ float fmha_fwd_pagedkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_fwd_pagedkv_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
FMHA_FWD_API_FILENAME = "fmha_fwd_pagedkv_api.cpp"
|
||||
FMHA_FWD_API = """
|
||||
float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
@@ -124,164 +122,215 @@ float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, con
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
return fmha_fwd_pagedkv_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag : str
|
||||
pipeline_tag: str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
pagedkv : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
skip : str
|
||||
hdim: str
|
||||
dtype: str # data type
|
||||
mode: str # value from MODE_MAP
|
||||
bm0: int # tile size along q seqlen (block size)
|
||||
bn0: int # tile size along qk seqlen
|
||||
bk0: int # tile size along qk gemm unroll
|
||||
bn1: int # tile size along v head_dim
|
||||
bk1: int # tile size along kv gemm unroll
|
||||
bk0max: int
|
||||
vlayout: str
|
||||
logits: str
|
||||
mask: str
|
||||
bias: str #
|
||||
lse: str #
|
||||
pagedkv: str
|
||||
squant: str #
|
||||
spad: str
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
skip: str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
|
||||
)
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.spad == 't' : return 'true' # always support
|
||||
else : return 'true'
|
||||
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
if self.mode == "group":
|
||||
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == "qr_async":
|
||||
if self.spad == "t":
|
||||
return "true" # always support
|
||||
else:
|
||||
return "true"
|
||||
elif self.pipeline_tag in ["qr_pagedkv", "qs"]:
|
||||
if self.spad == "t":
|
||||
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.seqlen_q % {self.bm0} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
|
||||
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0'
|
||||
else: assert False
|
||||
if self.mode == "group":
|
||||
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == "qr_async":
|
||||
if self.skpad == "t":
|
||||
return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
|
||||
else:
|
||||
return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
|
||||
elif self.pipeline_tag in ["qr_pagedkv", "qs"]:
|
||||
if self.skpad == "t":
|
||||
return f"true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.pipeline_tag == "qr_async":
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
|
||||
if self.dpad == "t":
|
||||
return f"a.hdim_q % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr_pagedkv", "qs"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {bk0submax} == 0'
|
||||
else: assert False
|
||||
if self.dpad == "t":
|
||||
return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.hdim_q % {bk0submax} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.pipeline_tag == "qr_async":
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
|
||||
if self.dvpad == "t":
|
||||
return f"a.hdim_v % {vec} == 0"
|
||||
else:
|
||||
assert False
|
||||
elif self.pipeline_tag in ["qr_pagedkv", "qs"]:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {bk0submax} == 0'
|
||||
else: assert False
|
||||
if self.dvpad == "t":
|
||||
return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
|
||||
else:
|
||||
return f"a.hdim_v % {bk0submax} == 0"
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag : str
|
||||
tag: str
|
||||
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_pagedkv : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_skip : str # true/false
|
||||
F_vlayout: str # row/col
|
||||
F_spad: str # true/false
|
||||
F_skpad: str #
|
||||
F_dpad: str #
|
||||
F_dvpad: str #
|
||||
F_logits: str # t/f
|
||||
F_bias: str # true/false
|
||||
F_lse: str #
|
||||
F_pagedkv: str #
|
||||
F_squant: str #
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_skip: str # true/false
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_skpad == 't' : n += 'sk'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
n = ""
|
||||
if self.F_spad == "t":
|
||||
n += "s"
|
||||
if self.F_skpad == "t":
|
||||
n += "sk"
|
||||
if self.F_dpad == "t":
|
||||
n += "d"
|
||||
if self.F_dvpad == "t":
|
||||
n += "dv"
|
||||
if n != "":
|
||||
n = "p" + n
|
||||
return n
|
||||
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}_v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_logits == 't' : n += '_logits'
|
||||
else: n += '_nlogits'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else: n += '_nmask'
|
||||
n = f"{self.tag}_v{self.F_vlayout[0]}"
|
||||
if pn != "":
|
||||
n += f"_{pn}"
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
else: n += '_nmask'
|
||||
n += "_npad"
|
||||
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
else: n += '_nlse'
|
||||
if self.F_logits == "t":
|
||||
n += "_logits"
|
||||
else:
|
||||
n += "_nlogits"
|
||||
|
||||
if self.F_skip == 't' : n += '_skip'
|
||||
else: n += '_nskip'
|
||||
if self.F_bias != "no":
|
||||
n += f"_{self.F_bias}"
|
||||
else:
|
||||
n += "_nbias"
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
if self.F_mask[0:2] == "s_":
|
||||
if self.F_mask == "s_mask":
|
||||
n += "_mask"
|
||||
else:
|
||||
n += "_nmask"
|
||||
else:
|
||||
if self.F_mask != "no":
|
||||
n += f"_m{self.F_mask[0]}"
|
||||
else:
|
||||
n += "_nmask"
|
||||
|
||||
if self.F_pagedkv == 't' : n += '_pagedkv'
|
||||
else: n += '_npagedkv'
|
||||
if self.F_lse == "t":
|
||||
n += "_lse"
|
||||
else:
|
||||
n += "_nlse"
|
||||
|
||||
if self.F_skip == "t":
|
||||
n += "_skip"
|
||||
else:
|
||||
n += "_nskip"
|
||||
|
||||
if self.F_squant == "t":
|
||||
n += "_squant"
|
||||
else:
|
||||
n += "_nsquant"
|
||||
|
||||
if self.F_pagedkv == "t":
|
||||
n += "_pagedkv"
|
||||
else:
|
||||
n += "_npagedkv"
|
||||
|
||||
return n
|
||||
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
|
||||
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
@@ -292,117 +341,152 @@ class FmhaFwdApiPool:
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
per_dtypes = str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
per_hdim_case = str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
inners=str()
|
||||
traits = self.pool[dtype][hdim]
|
||||
inners = str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip],
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if_k = "if" if k == 0 else "else if"
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
|
||||
F_if=if_k,
|
||||
F_mode=MODE_MAP[trait.mode],
|
||||
F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
|
||||
F_logits=BOOL_MAP[trait.logits],
|
||||
F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
|
||||
F_bias_check=BIAS_CHECK_MAP[trait.bias],
|
||||
F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse],
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_skip=BOOL_MAP[trait.skip],
|
||||
F_squant=BOOL_MAP[trait.squant],
|
||||
F_scheck=trait.scheck,
|
||||
F_skcheck=trait.skcheck,
|
||||
F_dcheck=trait.dcheck,
|
||||
F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad],
|
||||
F_skpad=BOOL_MAP[trait.skpad],
|
||||
F_dpad=BOOL_MAP[trait.dpad],
|
||||
F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0,
|
||||
F_bn0=trait.bn0,
|
||||
F_bk0=trait.bk0,
|
||||
F_bn1=trait.bn1,
|
||||
F_bk1=trait.bk1,
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners
|
||||
)
|
||||
if_i = "if" if i == 0 else "else if"
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
|
||||
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
|
||||
)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
|
||||
per_dtypes += " (void)t ; (void)s ; (void)a;"
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_bm0: int # tile size along q seqlen (block size)
|
||||
F_bn0: int # tile size along k seqlen
|
||||
F_bk0: int # tile size along qk gemm unroll
|
||||
F_bn1: int # tile size along v head_dim
|
||||
F_bk1: int # tile size along kv gemm unroll
|
||||
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0: int # number of warps for gemm0 along q seqlen
|
||||
F_rn0: int # number of warps for gemm0 along k seqlen
|
||||
F_rk0: int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1: int # number of warps for gemm1 along q seqlen
|
||||
F_rn1: int # number of warps for gemm1 along head dim v
|
||||
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0: int # gemm0 warp size along m
|
||||
F_wn0: int # gemm0 warp size along n
|
||||
F_wk0: int # gemm0 warp size along k
|
||||
F_wm1: int # gemm1 warp size along m
|
||||
F_wn1: int # gemm1 warp size along n
|
||||
F_wk1: int # gemm1 warp size along k
|
||||
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
|
||||
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
return (
|
||||
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
|
||||
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
|
||||
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
|
||||
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdTileSize
|
||||
F_pipeline : FmhaFwdPipeline
|
||||
mask_impl : str
|
||||
F_idx: int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim: int # hdim
|
||||
F_dtype: str # data type
|
||||
F_mode: str # value from MODE_MAP
|
||||
F_tile: FmhaFwdTileSize
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0max = self.F_tile.F_bk0max,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_wm0 = self.F_tile.F_wm0,
|
||||
F_wn0 = self.F_tile.F_wn0,
|
||||
F_wk0 = self.F_tile.F_wk0,
|
||||
F_wm1 = self.F_tile.F_wm1,
|
||||
F_wn1 = self.F_tile.F_wn1,
|
||||
F_wk1 = self.F_tile.F_wk1,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag])
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx=self.F_idx,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0=self.F_tile.F_bm0,
|
||||
F_bn0=self.F_tile.F_bn0,
|
||||
F_bk0=self.F_tile.F_bk0,
|
||||
F_bn1=self.F_tile.F_bn1,
|
||||
F_bk1=self.F_tile.F_bk1,
|
||||
F_bk0max=self.F_tile.F_bk0max,
|
||||
F_rm0=self.F_tile.F_rm0,
|
||||
F_rn0=self.F_tile.F_rn0,
|
||||
F_rk0=self.F_tile.F_rk0,
|
||||
F_rm1=self.F_tile.F_rm1,
|
||||
F_rn1=self.F_tile.F_rn1,
|
||||
F_rk1=self.F_tile.F_rk1,
|
||||
F_wm0=self.F_tile.F_wm0,
|
||||
F_wn0=self.F_tile.F_wn0,
|
||||
F_wk0=self.F_tile.F_wk0,
|
||||
F_wm1=self.F_tile.F_wm1,
|
||||
F_wn1=self.F_tile.F_wn1,
|
||||
F_wk1=self.F_tile.F_wk1,
|
||||
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits=BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias=BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse=BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag],
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
return (
|
||||
f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
@@ -410,51 +494,64 @@ class FmhaFwdKernel:
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
pagedkv=self.F_pipeline.F_pagedkv,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip)
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
pagedkv=self.F_pipeline.F_pagedkv,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip,
|
||||
)
|
||||
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
return {
|
||||
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
"128": FmhaFwdTileSize(
|
||||
128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1
|
||||
),
|
||||
# '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
elif dtype == "fp8" or dtype == "bf8":
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
"64": FmhaFwdTileSize(
|
||||
128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1
|
||||
),
|
||||
"128": FmhaFwdTileSize(
|
||||
128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1
|
||||
),
|
||||
"256": FmhaFwdTileSize(
|
||||
128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1
|
||||
),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
|
||||
def get_fwd_blobs(
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
@@ -462,18 +559,90 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
squant = "t" if dtype == "fp8" else "f"
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]):
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
for logits, mask, bias, pagedkv, skip in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t"],
|
||||
["f"],
|
||||
):
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_pagedkv",
|
||||
"row",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"f",
|
||||
pagedkv,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_pagedkv",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"f",
|
||||
pagedkv,
|
||||
squant,
|
||||
mask,
|
||||
skip,
|
||||
)
|
||||
)
|
||||
elif dtype in ["fp8", "bf8"]:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
for logits, mask, bias in itertools.product(
|
||||
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
|
||||
):
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_pagedkv",
|
||||
"row",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"f",
|
||||
"t",
|
||||
squant,
|
||||
mask,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
pipelines.append(
|
||||
FmhaFwdPipeline(
|
||||
"qr_pagedkv",
|
||||
"row",
|
||||
"t",
|
||||
"t",
|
||||
"f",
|
||||
"f",
|
||||
logits,
|
||||
bias,
|
||||
"f",
|
||||
"t",
|
||||
squant,
|
||||
mask,
|
||||
"f",
|
||||
)
|
||||
)
|
||||
elif dtype in ["fp8fp16", "fp8bf16"]:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
@@ -485,9 +654,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
if d is None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
@@ -495,24 +664,29 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
# if pipeline.F_pagedkv == 'f':
|
||||
# continue
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' :
|
||||
if pipeline.F_bias != "no" or pipeline.F_lse == "t":
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
if not (
|
||||
(pipeline.F_logits == "t" and pipeline.F_bias == "no")
|
||||
or pipeline.F_logits == "f"
|
||||
):
|
||||
continue
|
||||
k = FmhaFwdKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != '':
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
@@ -520,49 +694,49 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_skip == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_squant == "f"
|
||||
cond &= pipeline.F_skip == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_squant == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_squant == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_squant == "f"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == 'fp32'
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
@@ -571,20 +745,28 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
|
||||
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
|
||||
def write_blobs(
|
||||
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
|
||||
def list_blobs(
|
||||
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
|
||||
@@ -6,30 +6,45 @@ import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import pkgutil
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
import codegen.ops
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cmake_config import GEN_DIR
|
||||
|
||||
|
||||
class HandlerId(IntEnum):
|
||||
LIST_BLOBS = 0
|
||||
WRITE_BLOBS = 1
|
||||
|
||||
|
||||
# inspect all modules under 'codegen.ops' and register API handlers
|
||||
ops = []
|
||||
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
|
||||
full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
|
||||
full_module_name = "%s.%s" % (codegen.ops.__name__, module_name)
|
||||
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
|
||||
unwanted_prefix = 'fmha_'
|
||||
unwanted_prefix = "fmha_"
|
||||
handlers = dict(
|
||||
[(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__,
|
||||
(op.list_blobs, op.write_blobs)) for op in ops]
|
||||
[
|
||||
(
|
||||
op.__name__[len(unwanted_prefix) :]
|
||||
if op.__name__.startswith(unwanted_prefix)
|
||||
else op.__name__,
|
||||
(op.list_blobs, op.write_blobs),
|
||||
)
|
||||
for op in ops
|
||||
]
|
||||
)
|
||||
assert 0 < len(handlers)
|
||||
|
||||
def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None:
|
||||
|
||||
def write_blobs(
|
||||
output_dir: Optional[str],
|
||||
api_list: List[str],
|
||||
filters_list: List[str],
|
||||
optdim_list: List[int],
|
||||
receipt,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
if output_dir is None:
|
||||
output_dir = Path(__file__).parent
|
||||
else:
|
||||
@@ -41,8 +56,16 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list :
|
||||
handler = handlers[api][HandlerId.WRITE_BLOBS]
|
||||
handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
|
||||
# list all the files that will be generated
|
||||
def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None:
|
||||
def list_blobs(
|
||||
output_file: Optional[str],
|
||||
api_list: List[str],
|
||||
filters_list: List[str],
|
||||
optdim_list: List[int],
|
||||
receipt,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
assert output_file is not None
|
||||
file_path = Path(output_file)
|
||||
|
||||
@@ -53,6 +76,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], filters_list :
|
||||
handler = handlers[api][HandlerId.LIST_BLOBS]
|
||||
handler(file_path, kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate",
|
||||
@@ -60,32 +84,29 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--direction", # we keep 'direction' option for backward compatibility
|
||||
"--direction", # we keep 'direction' option for backward compatibility
|
||||
"-a",
|
||||
"--api",
|
||||
default='fwd',
|
||||
default="fwd",
|
||||
required=False,
|
||||
help="supply API(s) to generate (default: fwd). separated by comma."
|
||||
help="supply API(s) to generate (default: fwd). separated by comma.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output_dir",
|
||||
required=False,
|
||||
help="write all the blobs into a directory"
|
||||
help="write all the blobs into a directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--list_blobs",
|
||||
required=False,
|
||||
help="list all the kernels to a file"
|
||||
"-l", "--list_blobs", required=False, help="list all the kernels to a file"
|
||||
)
|
||||
# TODO: if using filter, must apply same value to output_dir and list_blobs
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--filter",
|
||||
default='',
|
||||
default="",
|
||||
required=False,
|
||||
help="filter out kernels that need to generate, using fnmatch module"
|
||||
help="filter out kernels that need to generate, using fnmatch module",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -93,7 +114,7 @@ if __name__ == "__main__":
|
||||
"--mask",
|
||||
default="simplified",
|
||||
required=False,
|
||||
help="mask implementation, simplified/generic"
|
||||
help="mask implementation, simplified/generic",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -101,32 +122,46 @@ if __name__ == "__main__":
|
||||
"--receipt",
|
||||
default=0,
|
||||
required=False,
|
||||
help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
|
||||
" 1: generate more instance to cover all hdim\n" + \
|
||||
" 2: Only generate instance for Flash attention integration\n" + \
|
||||
" 4: Only generate instance for PyTorch integration\n" + \
|
||||
" 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \
|
||||
" 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \
|
||||
" 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \
|
||||
" 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \
|
||||
" 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration"
|
||||
help="codegen receipt. 0: generate only 8xhdim coverage\n"
|
||||
+ " 1: generate more instance to cover all hdim\n"
|
||||
+ " 2: Only generate instance for Flash attention integration\n"
|
||||
+ " 4: Only generate instance for PyTorch integration\n"
|
||||
+ " 100-199: Only generate instance for Aiter(mha_fwd) integration\n"
|
||||
+ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n"
|
||||
+ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n"
|
||||
+ " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n"
|
||||
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--optdim",
|
||||
default='-1',
|
||||
default="-1",
|
||||
required=False,
|
||||
help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \
|
||||
"eg. --optdim=32,64,128,256"
|
||||
help="only optimize the hdim in the list. separated by comma. -1 is the default choice"
|
||||
+ "eg. --optdim=32,64,128,256",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
api_list = args.direction.split(',')
|
||||
filter_list = args.filter.split(',')
|
||||
filter_list.extend([''] * (len(api_list) - len(filter_list)))
|
||||
optdim_list = [int(hdim) for hdim in args.optdim.split(',')]
|
||||
api_list = args.direction.split(",")
|
||||
filter_list = args.filter.split(",")
|
||||
filter_list.extend([""] * (len(api_list) - len(filter_list)))
|
||||
optdim_list = [int(hdim) for hdim in args.optdim.split(",")]
|
||||
|
||||
if args.list_blobs is not None:
|
||||
list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask)
|
||||
list_blobs(
|
||||
args.list_blobs,
|
||||
api_list,
|
||||
filter_list,
|
||||
optdim_list,
|
||||
int(args.receipt),
|
||||
mask_impl=args.mask,
|
||||
)
|
||||
else:
|
||||
write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask)
|
||||
write_blobs(
|
||||
args.output_dir,
|
||||
api_list,
|
||||
filter_list,
|
||||
optdim_list,
|
||||
int(args.receipt),
|
||||
mask_impl=args.mask,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/pool.hpp"
|
||||
#include "ck_tile/ops/pooling.hpp"
|
||||
#include "ck_tile/host/reference/reference_pool.hpp"
|
||||
#include <cstring>
|
||||
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
import pathlib
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import os
|
||||
import copy
|
||||
|
||||
all_files = []
|
||||
for p in sorted(Path("./").rglob("*")):
|
||||
if p.suffix in ['.hpp', '.cpp']:
|
||||
if p.suffix in [".hpp", ".cpp"]:
|
||||
all_files.append(pathlib.PurePath(p))
|
||||
|
||||
|
||||
|
||||
# formatting
|
||||
for x in all_files:
|
||||
subprocess.Popen(f'dos2unix {str(x)}', shell=True)
|
||||
cmd = f'clang-format-18 -style=file -i {str(x)}'
|
||||
#for xp in x.parents:
|
||||
#print(get_file_base(x))
|
||||
subprocess.Popen(f"dos2unix -n {str(x)}", shell=True)
|
||||
cmd = f"clang-format-18 -style=file -i {str(x)}"
|
||||
# for xp in x.parents:
|
||||
# print(get_file_base(x))
|
||||
subprocess.Popen(cmd, shell=True)
|
||||
|
||||
#print(all_files)
|
||||
# print(all_files)
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/host/ranges.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_contraction.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_dropout_randval.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_elementwise.hpp"
|
||||
@@ -36,6 +37,7 @@
|
||||
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
|
||||
#include "ck_tile/host/reference/reference_permute.hpp"
|
||||
#include "ck_tile/host/reference/reference_pool.hpp"
|
||||
#include "ck_tile/host/reference/reference_reduce.hpp"
|
||||
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
@@ -155,6 +157,10 @@ void calculate_reference_multi_dimensional(
|
||||
b_idx.reserve(B_dims.size());
|
||||
e_idx.reserve(E_dims.size());
|
||||
|
||||
auto calculate_total_elements = [](const std::vector<ck_tile::index_t>& dims) {
|
||||
return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<ck_tile::index_t>());
|
||||
};
|
||||
|
||||
for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
|
||||
{
|
||||
ck_tile::index_t temp = g_flat;
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -5,5 +5,9 @@
|
||||
|
||||
#include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp"
|
||||
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
|
||||
#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -9,9 +9,9 @@
|
||||
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
|
||||
#include "ck_tile/ops/pooling/pipeline/pool_default_policy.hpp"
|
||||
#include "ck_tile/ops/pooling/pipeline/pool_problem.hpp"
|
||||
#include "ck_tile/ops/pooling/pipeline/pool_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
@@ -5,39 +5,43 @@ import subprocess
|
||||
import os
|
||||
import copy
|
||||
|
||||
NS = 'ck_tile'
|
||||
OPS = 'ops'
|
||||
REF = 'ref'
|
||||
OPS_COMMON = 'common' #common header will be duplicated into ops/* other module
|
||||
NS = "ck_tile"
|
||||
OPS = "ops"
|
||||
OPS_COMMON = "common" # common header will be duplicated into ops/* other module
|
||||
|
||||
IGNORED_DIRS = ["utility", "ref"]
|
||||
HEADER_COMMON = f"""// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
"""
|
||||
|
||||
|
||||
# aa/bb/cc/file.hpp -> (aa, bb, cc, file.hpp)
|
||||
def get_module(f, level = 0):
|
||||
def get_module(f, level=0):
|
||||
all_parts = f.parts
|
||||
return str(all_parts[level])
|
||||
|
||||
|
||||
all_files = []
|
||||
for p in sorted(Path("./").rglob("*")):
|
||||
if p.suffix == '.hpp':
|
||||
if p.suffix == ".hpp":
|
||||
all_files.append(pathlib.PurePath(p))
|
||||
|
||||
|
||||
class submodule_t:
|
||||
def __init__(self):
|
||||
self.m = dict()
|
||||
|
||||
def push(self, f):
|
||||
if len(f.parents) != 1: # ignore ./xxx.hpp
|
||||
if len(f.parents) != 1: # ignore ./xxx.hpp
|
||||
mod = get_module(f)
|
||||
# ref is supposed to include one header on demand
|
||||
if mod == REF:
|
||||
# Should only be included by demand
|
||||
if mod in IGNORED_DIRS:
|
||||
return
|
||||
if mod == OPS:
|
||||
if mod not in self.m.keys():
|
||||
self.m[mod] = dict()
|
||||
mod2 = get_module(f, 1)
|
||||
if Path(mod2).suffix != '.hpp':
|
||||
if Path(mod2).suffix != ".hpp":
|
||||
# ignore ops/xxx.hpp
|
||||
if mod2 not in self.m[mod].keys():
|
||||
self.m[mod][mod2] = list()
|
||||
@@ -52,14 +56,15 @@ class submodule_t:
|
||||
# print(hpath)
|
||||
if os.path.exists(str(hpath)):
|
||||
os.remove(str(hpath))
|
||||
with hpath.open('w') as f:
|
||||
with hpath.open("w") as f:
|
||||
f.write(HEADER_COMMON)
|
||||
f.write('#pragma once\n')
|
||||
f.write('\n')
|
||||
f.write("#pragma once\n")
|
||||
f.write("\n")
|
||||
for individual_header in include_list:
|
||||
header_path = NS + '/' + str(individual_header)
|
||||
f.write(f'#include \"{header_path}\"\n')
|
||||
header_path = NS + "/" + str(individual_header)
|
||||
f.write(f'#include "{header_path}"\n')
|
||||
# f.write('\n') # otherwise clang-format will complain
|
||||
|
||||
# print(self.m)
|
||||
# restructure common
|
||||
for k, v in self.m.items():
|
||||
@@ -73,21 +78,21 @@ class submodule_t:
|
||||
for k, v in self.m.items():
|
||||
if k == OPS:
|
||||
for km, kv in v.items():
|
||||
gen_header(Path(k) / (f'{km}.hpp'), kv)
|
||||
gen_header(Path(k) / (f"{km}.hpp"), kv)
|
||||
else:
|
||||
gen_header(Path(f'{k}.hpp'), v)
|
||||
gen_header(Path(f"{k}.hpp"), v)
|
||||
|
||||
|
||||
submodule = submodule_t()
|
||||
# formatting
|
||||
for x in all_files:
|
||||
subprocess.Popen(f'dos2unix {str(x)}', shell=True)
|
||||
cmd = f'clang-format-18 -style=file -i {str(x)}'
|
||||
#for xp in x.parents:
|
||||
#print(get_file_base(x))
|
||||
subprocess.Popen(f"dos2unix -n {str(x)}", shell=True)
|
||||
cmd = f"clang-format-18 -style=file -i {str(x)}"
|
||||
# for xp in x.parents:
|
||||
# print(get_file_base(x))
|
||||
subprocess.Popen(cmd, shell=True)
|
||||
submodule.push(x)
|
||||
|
||||
submodule.gen()
|
||||
|
||||
#print(all_files)
|
||||
# print(all_files)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_ALLOCATORS_H_
|
||||
@@ -32,10 +32,10 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
|
||||
/*! \class rapidjson::Allocator
|
||||
\brief Concept for allocating, resizing and freeing memory block.
|
||||
|
||||
|
||||
Note that Malloc() and Realloc() are non-static but Free() is static.
|
||||
|
||||
So if an allocator need to support Free(), it needs to put its pointer in
|
||||
|
||||
So if an allocator need to support Free(), it needs to put its pointer in
|
||||
the header of memory block.
|
||||
|
||||
\code
|
||||
@@ -49,7 +49,8 @@ concept Allocator {
|
||||
|
||||
// Resize a memory block.
|
||||
// \param originalPtr The pointer to current memory block. Null pointer is permitted.
|
||||
// \param originalSize The current size in bytes. (Design issue: since some allocator may not book-keep this, explicitly pass to it can save memory.)
|
||||
// \param originalSize The current size in bytes. (Design issue: since some allocator may not
|
||||
book-keep this, explicitly pass to it can save memory.)
|
||||
// \param newSize the new size in bytes.
|
||||
void* Realloc(void* originalPtr, size_t originalSize, size_t newSize);
|
||||
|
||||
@@ -60,7 +61,6 @@ concept Allocator {
|
||||
\endcode
|
||||
*/
|
||||
|
||||
|
||||
/*! \def RAPIDJSON_ALLOCATOR_DEFAULT_CHUNK_CAPACITY
|
||||
\ingroup RAPIDJSON_CONFIG
|
||||
\brief User-defined kDefaultChunkCapacity definition.
|
||||
@@ -72,7 +72,6 @@ concept Allocator {
|
||||
#define RAPIDJSON_ALLOCATOR_DEFAULT_CHUNK_CAPACITY (64 * 1024)
|
||||
#endif
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// CrtAllocator
|
||||
|
||||
@@ -80,38 +79,38 @@ concept Allocator {
|
||||
/*! This class is just wrapper for standard C library memory routines.
|
||||
\note implements Allocator concept
|
||||
*/
|
||||
class CrtAllocator {
|
||||
public:
|
||||
class CrtAllocator
|
||||
{
|
||||
public:
|
||||
static const bool kNeedFree = true;
|
||||
void* Malloc(size_t size) {
|
||||
if (size) // behavior of malloc(0) is implementation defined.
|
||||
void* Malloc(size_t size)
|
||||
{
|
||||
if(size) // behavior of malloc(0) is implementation defined.
|
||||
return RAPIDJSON_MALLOC(size);
|
||||
else
|
||||
return NULL; // standardize to returning NULL.
|
||||
}
|
||||
void* Realloc(void* originalPtr, size_t originalSize, size_t newSize) {
|
||||
void* Realloc(void* originalPtr, size_t originalSize, size_t newSize)
|
||||
{
|
||||
(void)originalSize;
|
||||
if (newSize == 0) {
|
||||
if(newSize == 0)
|
||||
{
|
||||
RAPIDJSON_FREE(originalPtr);
|
||||
return NULL;
|
||||
}
|
||||
return RAPIDJSON_REALLOC(originalPtr, newSize);
|
||||
}
|
||||
static void Free(void *ptr) RAPIDJSON_NOEXCEPT { RAPIDJSON_FREE(ptr); }
|
||||
static void Free(void* ptr) RAPIDJSON_NOEXCEPT { RAPIDJSON_FREE(ptr); }
|
||||
|
||||
bool operator==(const CrtAllocator&) const RAPIDJSON_NOEXCEPT {
|
||||
return true;
|
||||
}
|
||||
bool operator!=(const CrtAllocator&) const RAPIDJSON_NOEXCEPT {
|
||||
return false;
|
||||
}
|
||||
bool operator==(const CrtAllocator&) const RAPIDJSON_NOEXCEPT { return true; }
|
||||
bool operator!=(const CrtAllocator&) const RAPIDJSON_NOEXCEPT { return false; }
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MemoryPoolAllocator
|
||||
|
||||
//! Default memory allocator used by the parser and DOM.
|
||||
/*! This allocator allocate memory blocks from pre-allocated memory chunks.
|
||||
/*! This allocator allocate memory blocks from pre-allocated memory chunks.
|
||||
|
||||
It does not free memory blocks. And Realloc() only allocate new memory.
|
||||
|
||||
@@ -127,69 +126,82 @@ public:
|
||||
\note implements Allocator concept
|
||||
*/
|
||||
template <typename BaseAllocator = CrtAllocator>
|
||||
class MemoryPoolAllocator {
|
||||
class MemoryPoolAllocator
|
||||
{
|
||||
//! Chunk header for perpending to each chunk.
|
||||
/*! Chunks are stored as a singly linked list.
|
||||
*/
|
||||
struct ChunkHeader {
|
||||
size_t capacity; //!< Capacity of the chunk in bytes (excluding the header itself).
|
||||
size_t size; //!< Current size of allocated memory in bytes.
|
||||
ChunkHeader *next; //!< Next chunk in the linked list.
|
||||
*/
|
||||
struct ChunkHeader
|
||||
{
|
||||
size_t capacity; //!< Capacity of the chunk in bytes (excluding the header itself).
|
||||
size_t size; //!< Current size of allocated memory in bytes.
|
||||
ChunkHeader* next; //!< Next chunk in the linked list.
|
||||
};
|
||||
|
||||
struct SharedData {
|
||||
ChunkHeader *chunkHead; //!< Head of the chunk linked-list. Only the head chunk serves allocation.
|
||||
struct SharedData
|
||||
{
|
||||
ChunkHeader*
|
||||
chunkHead; //!< Head of the chunk linked-list. Only the head chunk serves allocation.
|
||||
BaseAllocator* ownBaseAllocator; //!< base allocator created by this object.
|
||||
size_t refcount;
|
||||
bool ownBuffer;
|
||||
};
|
||||
|
||||
static const size_t SIZEOF_SHARED_DATA = RAPIDJSON_ALIGN(sizeof(SharedData));
|
||||
static const size_t SIZEOF_SHARED_DATA = RAPIDJSON_ALIGN(sizeof(SharedData));
|
||||
static const size_t SIZEOF_CHUNK_HEADER = RAPIDJSON_ALIGN(sizeof(ChunkHeader));
|
||||
|
||||
static inline ChunkHeader *GetChunkHead(SharedData *shared)
|
||||
static inline ChunkHeader* GetChunkHead(SharedData* shared)
|
||||
{
|
||||
return reinterpret_cast<ChunkHeader*>(reinterpret_cast<uint8_t*>(shared) + SIZEOF_SHARED_DATA);
|
||||
return reinterpret_cast<ChunkHeader*>(reinterpret_cast<uint8_t*>(shared) +
|
||||
SIZEOF_SHARED_DATA);
|
||||
}
|
||||
static inline uint8_t *GetChunkBuffer(SharedData *shared)
|
||||
static inline uint8_t* GetChunkBuffer(SharedData* shared)
|
||||
{
|
||||
return reinterpret_cast<uint8_t*>(shared->chunkHead) + SIZEOF_CHUNK_HEADER;
|
||||
}
|
||||
|
||||
static const size_t kDefaultChunkCapacity = RAPIDJSON_ALLOCATOR_DEFAULT_CHUNK_CAPACITY; //!< Default chunk capacity.
|
||||
static const size_t kDefaultChunkCapacity =
|
||||
RAPIDJSON_ALLOCATOR_DEFAULT_CHUNK_CAPACITY; //!< Default chunk capacity.
|
||||
|
||||
public:
|
||||
static const bool kNeedFree = false; //!< Tell users that no need to call Free() with this allocator. (concept Allocator)
|
||||
static const bool kRefCounted = true; //!< Tell users that this allocator is reference counted on copy
|
||||
public:
|
||||
static const bool kNeedFree =
|
||||
false; //!< Tell users that no need to call Free() with this allocator. (concept Allocator)
|
||||
static const bool kRefCounted =
|
||||
true; //!< Tell users that this allocator is reference counted on copy
|
||||
|
||||
//! Constructor with chunkSize.
|
||||
/*! \param chunkSize The size of memory chunk. The default is kDefaultChunkSize.
|
||||
\param baseAllocator The allocator for allocating memory chunks.
|
||||
*/
|
||||
explicit
|
||||
MemoryPoolAllocator(size_t chunkSize = kDefaultChunkCapacity, BaseAllocator* baseAllocator = 0) :
|
||||
chunk_capacity_(chunkSize),
|
||||
baseAllocator_(baseAllocator ? baseAllocator : RAPIDJSON_NEW(BaseAllocator)()),
|
||||
shared_(static_cast<SharedData*>(baseAllocator_ ? baseAllocator_->Malloc(SIZEOF_SHARED_DATA + SIZEOF_CHUNK_HEADER) : 0))
|
||||
explicit MemoryPoolAllocator(size_t chunkSize = kDefaultChunkCapacity,
|
||||
BaseAllocator* baseAllocator = 0)
|
||||
: chunk_capacity_(chunkSize),
|
||||
baseAllocator_(baseAllocator ? baseAllocator : RAPIDJSON_NEW(BaseAllocator)()),
|
||||
shared_(static_cast<SharedData*>(
|
||||
baseAllocator_ ? baseAllocator_->Malloc(SIZEOF_SHARED_DATA + SIZEOF_CHUNK_HEADER)
|
||||
: 0))
|
||||
{
|
||||
RAPIDJSON_ASSERT(baseAllocator_ != 0);
|
||||
RAPIDJSON_ASSERT(shared_ != 0);
|
||||
if (baseAllocator) {
|
||||
if(baseAllocator)
|
||||
{
|
||||
shared_->ownBaseAllocator = 0;
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
shared_->ownBaseAllocator = baseAllocator_;
|
||||
}
|
||||
shared_->chunkHead = GetChunkHead(shared_);
|
||||
shared_->chunkHead = GetChunkHead(shared_);
|
||||
shared_->chunkHead->capacity = 0;
|
||||
shared_->chunkHead->size = 0;
|
||||
shared_->chunkHead->next = 0;
|
||||
shared_->ownBuffer = true;
|
||||
shared_->refcount = 1;
|
||||
shared_->chunkHead->size = 0;
|
||||
shared_->chunkHead->next = 0;
|
||||
shared_->ownBuffer = true;
|
||||
shared_->refcount = 1;
|
||||
}
|
||||
|
||||
//! Constructor with user-supplied buffer.
|
||||
/*! The user buffer will be used firstly. When it is full, memory pool allocates new chunk with chunk size.
|
||||
/*! The user buffer will be used firstly. When it is full, memory pool allocates new chunk with
|
||||
chunk size.
|
||||
|
||||
The user buffer will not be deallocated when this allocator is destructed.
|
||||
|
||||
@@ -198,25 +210,28 @@ public:
|
||||
\param chunkSize The size of memory chunk. The default is kDefaultChunkSize.
|
||||
\param baseAllocator The allocator for allocating memory chunks.
|
||||
*/
|
||||
MemoryPoolAllocator(void *buffer, size_t size, size_t chunkSize = kDefaultChunkCapacity, BaseAllocator* baseAllocator = 0) :
|
||||
chunk_capacity_(chunkSize),
|
||||
baseAllocator_(baseAllocator),
|
||||
shared_(static_cast<SharedData*>(AlignBuffer(buffer, size)))
|
||||
MemoryPoolAllocator(void* buffer,
|
||||
size_t size,
|
||||
size_t chunkSize = kDefaultChunkCapacity,
|
||||
BaseAllocator* baseAllocator = 0)
|
||||
: chunk_capacity_(chunkSize),
|
||||
baseAllocator_(baseAllocator),
|
||||
shared_(static_cast<SharedData*>(AlignBuffer(buffer, size)))
|
||||
{
|
||||
RAPIDJSON_ASSERT(size >= SIZEOF_SHARED_DATA + SIZEOF_CHUNK_HEADER);
|
||||
shared_->chunkHead = GetChunkHead(shared_);
|
||||
shared_->chunkHead = GetChunkHead(shared_);
|
||||
shared_->chunkHead->capacity = size - SIZEOF_SHARED_DATA - SIZEOF_CHUNK_HEADER;
|
||||
shared_->chunkHead->size = 0;
|
||||
shared_->chunkHead->next = 0;
|
||||
shared_->ownBaseAllocator = 0;
|
||||
shared_->ownBuffer = false;
|
||||
shared_->refcount = 1;
|
||||
shared_->chunkHead->size = 0;
|
||||
shared_->chunkHead->next = 0;
|
||||
shared_->ownBaseAllocator = 0;
|
||||
shared_->ownBuffer = false;
|
||||
shared_->refcount = 1;
|
||||
}
|
||||
|
||||
MemoryPoolAllocator(const MemoryPoolAllocator& rhs) RAPIDJSON_NOEXCEPT :
|
||||
chunk_capacity_(rhs.chunk_capacity_),
|
||||
baseAllocator_(rhs.baseAllocator_),
|
||||
shared_(rhs.shared_)
|
||||
MemoryPoolAllocator(const MemoryPoolAllocator& rhs) RAPIDJSON_NOEXCEPT
|
||||
: chunk_capacity_(rhs.chunk_capacity_),
|
||||
baseAllocator_(rhs.baseAllocator_),
|
||||
shared_(rhs.shared_)
|
||||
{
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
|
||||
++shared_->refcount;
|
||||
@@ -226,17 +241,17 @@ public:
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(rhs.shared_->refcount > 0);
|
||||
++rhs.shared_->refcount;
|
||||
this->~MemoryPoolAllocator();
|
||||
baseAllocator_ = rhs.baseAllocator_;
|
||||
baseAllocator_ = rhs.baseAllocator_;
|
||||
chunk_capacity_ = rhs.chunk_capacity_;
|
||||
shared_ = rhs.shared_;
|
||||
shared_ = rhs.shared_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
|
||||
MemoryPoolAllocator(MemoryPoolAllocator&& rhs) RAPIDJSON_NOEXCEPT :
|
||||
chunk_capacity_(rhs.chunk_capacity_),
|
||||
baseAllocator_(rhs.baseAllocator_),
|
||||
shared_(rhs.shared_)
|
||||
MemoryPoolAllocator(MemoryPoolAllocator&& rhs) RAPIDJSON_NOEXCEPT
|
||||
: chunk_capacity_(rhs.chunk_capacity_),
|
||||
baseAllocator_(rhs.baseAllocator_),
|
||||
shared_(rhs.shared_)
|
||||
{
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(rhs.shared_->refcount > 0);
|
||||
rhs.shared_ = 0;
|
||||
@@ -245,40 +260,47 @@ public:
|
||||
{
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(rhs.shared_->refcount > 0);
|
||||
this->~MemoryPoolAllocator();
|
||||
baseAllocator_ = rhs.baseAllocator_;
|
||||
baseAllocator_ = rhs.baseAllocator_;
|
||||
chunk_capacity_ = rhs.chunk_capacity_;
|
||||
shared_ = rhs.shared_;
|
||||
rhs.shared_ = 0;
|
||||
shared_ = rhs.shared_;
|
||||
rhs.shared_ = 0;
|
||||
return *this;
|
||||
}
|
||||
#endif
|
||||
|
||||
//! Destructor.
|
||||
/*! This deallocates all memory chunks, excluding the user-supplied buffer.
|
||||
*/
|
||||
~MemoryPoolAllocator() RAPIDJSON_NOEXCEPT {
|
||||
if (!shared_) {
|
||||
*/
|
||||
~MemoryPoolAllocator() RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
if(!shared_)
|
||||
{
|
||||
// do nothing if moved
|
||||
return;
|
||||
}
|
||||
if (shared_->refcount > 1) {
|
||||
if(shared_->refcount > 1)
|
||||
{
|
||||
--shared_->refcount;
|
||||
return;
|
||||
}
|
||||
Clear();
|
||||
BaseAllocator *a = shared_->ownBaseAllocator;
|
||||
if (shared_->ownBuffer) {
|
||||
BaseAllocator* a = shared_->ownBaseAllocator;
|
||||
if(shared_->ownBuffer)
|
||||
{
|
||||
baseAllocator_->Free(shared_);
|
||||
}
|
||||
RAPIDJSON_DELETE(a);
|
||||
}
|
||||
|
||||
//! Deallocates all memory chunks, excluding the first/user one.
|
||||
void Clear() RAPIDJSON_NOEXCEPT {
|
||||
void Clear() RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
|
||||
for (;;) {
|
||||
for(;;)
|
||||
{
|
||||
ChunkHeader* c = shared_->chunkHead;
|
||||
if (!c->next) {
|
||||
if(!c->next)
|
||||
{
|
||||
break;
|
||||
}
|
||||
shared_->chunkHead = c->next;
|
||||
@@ -289,78 +311,86 @@ public:
|
||||
|
||||
//! Computes the total capacity of allocated memory chunks.
|
||||
/*! \return total capacity in bytes.
|
||||
*/
|
||||
size_t Capacity() const RAPIDJSON_NOEXCEPT {
|
||||
*/
|
||||
size_t Capacity() const RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
|
||||
size_t capacity = 0;
|
||||
for (ChunkHeader* c = shared_->chunkHead; c != 0; c = c->next)
|
||||
for(ChunkHeader* c = shared_->chunkHead; c != 0; c = c->next)
|
||||
capacity += c->capacity;
|
||||
return capacity;
|
||||
}
|
||||
|
||||
//! Computes the memory blocks allocated.
|
||||
/*! \return total used bytes.
|
||||
*/
|
||||
size_t Size() const RAPIDJSON_NOEXCEPT {
|
||||
*/
|
||||
size_t Size() const RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
|
||||
size_t size = 0;
|
||||
for (ChunkHeader* c = shared_->chunkHead; c != 0; c = c->next)
|
||||
for(ChunkHeader* c = shared_->chunkHead; c != 0; c = c->next)
|
||||
size += c->size;
|
||||
return size;
|
||||
}
|
||||
|
||||
//! Whether the allocator is shared.
|
||||
/*! \return true or false.
|
||||
*/
|
||||
bool Shared() const RAPIDJSON_NOEXCEPT {
|
||||
*/
|
||||
bool Shared() const RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
|
||||
return shared_->refcount > 1;
|
||||
}
|
||||
|
||||
//! Allocates a memory block. (concept Allocator)
|
||||
void* Malloc(size_t size) {
|
||||
void* Malloc(size_t size)
|
||||
{
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
|
||||
if (!size)
|
||||
if(!size)
|
||||
return NULL;
|
||||
|
||||
size = RAPIDJSON_ALIGN(size);
|
||||
if (RAPIDJSON_UNLIKELY(shared_->chunkHead->size + size > shared_->chunkHead->capacity))
|
||||
if (!AddChunk(chunk_capacity_ > size ? chunk_capacity_ : size))
|
||||
if(RAPIDJSON_UNLIKELY(shared_->chunkHead->size + size > shared_->chunkHead->capacity))
|
||||
if(!AddChunk(chunk_capacity_ > size ? chunk_capacity_ : size))
|
||||
return NULL;
|
||||
|
||||
void *buffer = GetChunkBuffer(shared_) + shared_->chunkHead->size;
|
||||
void* buffer = GetChunkBuffer(shared_) + shared_->chunkHead->size;
|
||||
shared_->chunkHead->size += size;
|
||||
return buffer;
|
||||
}
|
||||
|
||||
//! Resizes a memory block (concept Allocator)
|
||||
void* Realloc(void* originalPtr, size_t originalSize, size_t newSize) {
|
||||
if (originalPtr == 0)
|
||||
void* Realloc(void* originalPtr, size_t originalSize, size_t newSize)
|
||||
{
|
||||
if(originalPtr == 0)
|
||||
return Malloc(newSize);
|
||||
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
|
||||
if (newSize == 0)
|
||||
if(newSize == 0)
|
||||
return NULL;
|
||||
|
||||
originalSize = RAPIDJSON_ALIGN(originalSize);
|
||||
newSize = RAPIDJSON_ALIGN(newSize);
|
||||
newSize = RAPIDJSON_ALIGN(newSize);
|
||||
|
||||
// Do not shrink if new size is smaller than original
|
||||
if (originalSize >= newSize)
|
||||
if(originalSize >= newSize)
|
||||
return originalPtr;
|
||||
|
||||
// Simply expand it if it is the last allocation and there is sufficient space
|
||||
if (originalPtr == GetChunkBuffer(shared_) + shared_->chunkHead->size - originalSize) {
|
||||
if(originalPtr == GetChunkBuffer(shared_) + shared_->chunkHead->size - originalSize)
|
||||
{
|
||||
size_t increment = static_cast<size_t>(newSize - originalSize);
|
||||
if (shared_->chunkHead->size + increment <= shared_->chunkHead->capacity) {
|
||||
if(shared_->chunkHead->size + increment <= shared_->chunkHead->capacity)
|
||||
{
|
||||
shared_->chunkHead->size += increment;
|
||||
return originalPtr;
|
||||
}
|
||||
}
|
||||
|
||||
// Realloc process: allocate and copy memory, do not free original buffer.
|
||||
if (void* newBuffer = Malloc(newSize)) {
|
||||
if (originalSize)
|
||||
if(void* newBuffer = Malloc(newSize))
|
||||
{
|
||||
if(originalSize)
|
||||
std::memcpy(newBuffer, originalPtr, originalSize);
|
||||
return newBuffer;
|
||||
}
|
||||
@@ -369,31 +399,36 @@ public:
|
||||
}
|
||||
|
||||
//! Frees a memory block (concept Allocator)
|
||||
static void Free(void *ptr) RAPIDJSON_NOEXCEPT { (void)ptr; } // Do nothing
|
||||
static void Free(void* ptr) RAPIDJSON_NOEXCEPT { (void)ptr; } // Do nothing
|
||||
|
||||
//! Compare (equality) with another MemoryPoolAllocator
|
||||
bool operator==(const MemoryPoolAllocator& rhs) const RAPIDJSON_NOEXCEPT {
|
||||
bool operator==(const MemoryPoolAllocator& rhs) const RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(rhs.shared_->refcount > 0);
|
||||
return shared_ == rhs.shared_;
|
||||
}
|
||||
//! Compare (inequality) with another MemoryPoolAllocator
|
||||
bool operator!=(const MemoryPoolAllocator& rhs) const RAPIDJSON_NOEXCEPT {
|
||||
bool operator!=(const MemoryPoolAllocator& rhs) const RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
return !operator==(rhs);
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
//! Creates a new chunk.
|
||||
/*! \param capacity Capacity of the chunk in bytes.
|
||||
\return true if success.
|
||||
*/
|
||||
bool AddChunk(size_t capacity) {
|
||||
if (!baseAllocator_)
|
||||
bool AddChunk(size_t capacity)
|
||||
{
|
||||
if(!baseAllocator_)
|
||||
shared_->ownBaseAllocator = baseAllocator_ = RAPIDJSON_NEW(BaseAllocator)();
|
||||
if (ChunkHeader* chunk = static_cast<ChunkHeader*>(baseAllocator_->Malloc(SIZEOF_CHUNK_HEADER + capacity))) {
|
||||
chunk->capacity = capacity;
|
||||
chunk->size = 0;
|
||||
chunk->next = shared_->chunkHead;
|
||||
if(ChunkHeader* chunk =
|
||||
static_cast<ChunkHeader*>(baseAllocator_->Malloc(SIZEOF_CHUNK_HEADER + capacity)))
|
||||
{
|
||||
chunk->capacity = capacity;
|
||||
chunk->size = 0;
|
||||
chunk->next = shared_->chunkHead;
|
||||
shared_->chunkHead = chunk;
|
||||
return true;
|
||||
}
|
||||
@@ -401,12 +436,13 @@ private:
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline void* AlignBuffer(void* buf, size_t &size)
|
||||
static inline void* AlignBuffer(void* buf, size_t& size)
|
||||
{
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(buf != 0);
|
||||
const uintptr_t mask = sizeof(void*) - 1;
|
||||
const uintptr_t ubuf = reinterpret_cast<uintptr_t>(buf);
|
||||
if (RAPIDJSON_UNLIKELY(ubuf & mask)) {
|
||||
if(RAPIDJSON_UNLIKELY(ubuf & mask))
|
||||
{
|
||||
const uintptr_t abuf = (ubuf + mask) & ~mask;
|
||||
RAPIDJSON_ASSERT(size >= abuf - ubuf);
|
||||
buf = reinterpret_cast<void*>(abuf);
|
||||
@@ -415,37 +451,38 @@ private:
|
||||
return buf;
|
||||
}
|
||||
|
||||
size_t chunk_capacity_; //!< The minimum capacity of chunk when they are allocated.
|
||||
BaseAllocator* baseAllocator_; //!< base allocator for allocating memory chunks.
|
||||
SharedData *shared_; //!< The shared data of the allocator
|
||||
size_t chunk_capacity_; //!< The minimum capacity of chunk when they are allocated.
|
||||
BaseAllocator* baseAllocator_; //!< base allocator for allocating memory chunks.
|
||||
SharedData* shared_; //!< The shared data of the allocator
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
template<typename, typename = void>
|
||||
struct IsRefCounted :
|
||||
public FalseType
|
||||
{ };
|
||||
template<typename T>
|
||||
struct IsRefCounted<T, typename internal::EnableIfCond<T::kRefCounted>::Type> :
|
||||
public TrueType
|
||||
{ };
|
||||
}
|
||||
template <typename, typename = void>
|
||||
struct IsRefCounted : public FalseType
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
struct IsRefCounted<T, typename internal::EnableIfCond<T::kRefCounted>::Type> : public TrueType
|
||||
{
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
template<typename T, typename A>
|
||||
template <typename T, typename A>
|
||||
inline T* Realloc(A& a, T* old_p, size_t old_n, size_t new_n)
|
||||
{
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(old_n <= (std::numeric_limits<size_t>::max)() / sizeof(T) && new_n <= (std::numeric_limits<size_t>::max)() / sizeof(T));
|
||||
RAPIDJSON_NOEXCEPT_ASSERT(old_n <= (std::numeric_limits<size_t>::max)() / sizeof(T) &&
|
||||
new_n <= (std::numeric_limits<size_t>::max)() / sizeof(T));
|
||||
return static_cast<T*>(a.Realloc(old_p, old_n * sizeof(T), new_n * sizeof(T)));
|
||||
}
|
||||
|
||||
template<typename T, typename A>
|
||||
inline T *Malloc(A& a, size_t n = 1)
|
||||
template <typename T, typename A>
|
||||
inline T* Malloc(A& a, size_t n = 1)
|
||||
{
|
||||
return Realloc<T, A>(a, NULL, 0, n);
|
||||
}
|
||||
|
||||
template<typename T, typename A>
|
||||
inline void Free(A& a, T *p, size_t n = 1)
|
||||
template <typename T, typename A>
|
||||
inline void Free(A& a, T* p, size_t n = 1)
|
||||
{
|
||||
static_cast<void>(Realloc<T, A>(a, p, n, 0));
|
||||
}
|
||||
@@ -456,8 +493,7 @@ RAPIDJSON_DIAG_OFF(effc++) // std::allocator can safely be inherited
|
||||
#endif
|
||||
|
||||
template <typename T, typename BaseAllocator = CrtAllocator>
|
||||
class StdAllocator :
|
||||
public std::allocator<T>
|
||||
class StdAllocator : public std::allocator<T>
|
||||
{
|
||||
typedef std::allocator<T> allocator_type;
|
||||
#if RAPIDJSON_HAS_CXX11
|
||||
@@ -466,113 +502,90 @@ class StdAllocator :
|
||||
typedef allocator_type traits_type;
|
||||
#endif
|
||||
|
||||
public:
|
||||
public:
|
||||
typedef BaseAllocator BaseAllocatorType;
|
||||
|
||||
StdAllocator() RAPIDJSON_NOEXCEPT :
|
||||
allocator_type(),
|
||||
baseAllocator_()
|
||||
{ }
|
||||
StdAllocator() RAPIDJSON_NOEXCEPT : allocator_type(), baseAllocator_() {}
|
||||
|
||||
StdAllocator(const StdAllocator& rhs) RAPIDJSON_NOEXCEPT :
|
||||
allocator_type(rhs),
|
||||
baseAllocator_(rhs.baseAllocator_)
|
||||
{ }
|
||||
StdAllocator(const StdAllocator& rhs) RAPIDJSON_NOEXCEPT : allocator_type(rhs),
|
||||
baseAllocator_(rhs.baseAllocator_)
|
||||
{
|
||||
}
|
||||
|
||||
template<typename U>
|
||||
StdAllocator(const StdAllocator<U, BaseAllocator>& rhs) RAPIDJSON_NOEXCEPT :
|
||||
allocator_type(rhs),
|
||||
baseAllocator_(rhs.baseAllocator_)
|
||||
{ }
|
||||
template <typename U>
|
||||
StdAllocator(const StdAllocator<U, BaseAllocator>& rhs) RAPIDJSON_NOEXCEPT
|
||||
: allocator_type(rhs),
|
||||
baseAllocator_(rhs.baseAllocator_)
|
||||
{
|
||||
}
|
||||
|
||||
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
|
||||
StdAllocator(StdAllocator&& rhs) RAPIDJSON_NOEXCEPT :
|
||||
allocator_type(std::move(rhs)),
|
||||
baseAllocator_(std::move(rhs.baseAllocator_))
|
||||
{ }
|
||||
StdAllocator(StdAllocator&& rhs) RAPIDJSON_NOEXCEPT
|
||||
: allocator_type(std::move(rhs)),
|
||||
baseAllocator_(std::move(rhs.baseAllocator_))
|
||||
{
|
||||
}
|
||||
#endif
|
||||
#if RAPIDJSON_HAS_CXX11
|
||||
using propagate_on_container_move_assignment = std::true_type;
|
||||
using propagate_on_container_swap = std::true_type;
|
||||
using propagate_on_container_swap = std::true_type;
|
||||
#endif
|
||||
|
||||
/* implicit */
|
||||
StdAllocator(const BaseAllocator& baseAllocator) RAPIDJSON_NOEXCEPT :
|
||||
allocator_type(),
|
||||
baseAllocator_(baseAllocator)
|
||||
{ }
|
||||
StdAllocator(const BaseAllocator& baseAllocator) RAPIDJSON_NOEXCEPT
|
||||
: allocator_type(),
|
||||
baseAllocator_(baseAllocator)
|
||||
{
|
||||
}
|
||||
|
||||
~StdAllocator() RAPIDJSON_NOEXCEPT
|
||||
{ }
|
||||
~StdAllocator() RAPIDJSON_NOEXCEPT {}
|
||||
|
||||
template<typename U>
|
||||
struct rebind {
|
||||
template <typename U>
|
||||
struct rebind
|
||||
{
|
||||
typedef StdAllocator<U, BaseAllocator> other;
|
||||
};
|
||||
|
||||
typedef typename traits_type::size_type size_type;
|
||||
typedef typename traits_type::difference_type difference_type;
|
||||
typedef typename traits_type::size_type size_type;
|
||||
typedef typename traits_type::difference_type difference_type;
|
||||
|
||||
typedef typename traits_type::value_type value_type;
|
||||
typedef typename traits_type::pointer pointer;
|
||||
typedef typename traits_type::const_pointer const_pointer;
|
||||
typedef typename traits_type::value_type value_type;
|
||||
typedef typename traits_type::pointer pointer;
|
||||
typedef typename traits_type::const_pointer const_pointer;
|
||||
|
||||
#if RAPIDJSON_HAS_CXX11
|
||||
|
||||
typedef typename std::add_lvalue_reference<value_type>::type &reference;
|
||||
typedef typename std::add_lvalue_reference<typename std::add_const<value_type>::type>::type &const_reference;
|
||||
typedef typename std::add_lvalue_reference<value_type>::type& reference;
|
||||
typedef typename std::add_lvalue_reference<typename std::add_const<value_type>::type>::type&
|
||||
const_reference;
|
||||
|
||||
pointer address(reference r) const RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
return std::addressof(r);
|
||||
}
|
||||
const_pointer address(const_reference r) const RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
return std::addressof(r);
|
||||
}
|
||||
pointer address(reference r) const RAPIDJSON_NOEXCEPT { return std::addressof(r); }
|
||||
const_pointer address(const_reference r) const RAPIDJSON_NOEXCEPT { return std::addressof(r); }
|
||||
|
||||
size_type max_size() const RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
return traits_type::max_size(*this);
|
||||
}
|
||||
size_type max_size() const RAPIDJSON_NOEXCEPT { return traits_type::max_size(*this); }
|
||||
|
||||
template <typename ...Args>
|
||||
template <typename... Args>
|
||||
void construct(pointer p, Args&&... args)
|
||||
{
|
||||
traits_type::construct(*this, p, std::forward<Args>(args)...);
|
||||
}
|
||||
void destroy(pointer p)
|
||||
{
|
||||
traits_type::destroy(*this, p);
|
||||
}
|
||||
void destroy(pointer p) { traits_type::destroy(*this, p); }
|
||||
|
||||
#else // !RAPIDJSON_HAS_CXX11
|
||||
|
||||
typedef typename allocator_type::reference reference;
|
||||
typedef typename allocator_type::reference reference;
|
||||
typedef typename allocator_type::const_reference const_reference;
|
||||
|
||||
pointer address(reference r) const RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
return allocator_type::address(r);
|
||||
}
|
||||
pointer address(reference r) const RAPIDJSON_NOEXCEPT { return allocator_type::address(r); }
|
||||
const_pointer address(const_reference r) const RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
return allocator_type::address(r);
|
||||
}
|
||||
|
||||
size_type max_size() const RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
return allocator_type::max_size();
|
||||
}
|
||||
size_type max_size() const RAPIDJSON_NOEXCEPT { return allocator_type::max_size(); }
|
||||
|
||||
void construct(pointer p, const_reference r)
|
||||
{
|
||||
allocator_type::construct(p, r);
|
||||
}
|
||||
void destroy(pointer p)
|
||||
{
|
||||
allocator_type::destroy(p);
|
||||
}
|
||||
void construct(pointer p, const_reference r) { allocator_type::construct(p, r); }
|
||||
void destroy(pointer p) { allocator_type::destroy(p); }
|
||||
|
||||
#endif // !RAPIDJSON_HAS_CXX11
|
||||
|
||||
@@ -587,47 +600,35 @@ public:
|
||||
RAPIDJSON_NAMESPACE::Free<U>(baseAllocator_, p, n);
|
||||
}
|
||||
|
||||
pointer allocate(size_type n = 1, const void* = 0)
|
||||
{
|
||||
return allocate<value_type>(n);
|
||||
}
|
||||
void deallocate(pointer p, size_type n = 1)
|
||||
{
|
||||
deallocate<value_type>(p, n);
|
||||
}
|
||||
pointer allocate(size_type n = 1, const void* = 0) { return allocate<value_type>(n); }
|
||||
void deallocate(pointer p, size_type n = 1) { deallocate<value_type>(p, n); }
|
||||
|
||||
#if RAPIDJSON_HAS_CXX11
|
||||
using is_always_equal = std::is_empty<BaseAllocator>;
|
||||
#endif
|
||||
|
||||
template<typename U>
|
||||
template <typename U>
|
||||
bool operator==(const StdAllocator<U, BaseAllocator>& rhs) const RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
return baseAllocator_ == rhs.baseAllocator_;
|
||||
}
|
||||
template<typename U>
|
||||
template <typename U>
|
||||
bool operator!=(const StdAllocator<U, BaseAllocator>& rhs) const RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
return !operator==(rhs);
|
||||
}
|
||||
|
||||
//! rapidjson Allocator concept
|
||||
static const bool kNeedFree = BaseAllocator::kNeedFree;
|
||||
static const bool kNeedFree = BaseAllocator::kNeedFree;
|
||||
static const bool kRefCounted = internal::IsRefCounted<BaseAllocator>::Value;
|
||||
void* Malloc(size_t size)
|
||||
{
|
||||
return baseAllocator_.Malloc(size);
|
||||
}
|
||||
void* Malloc(size_t size) { return baseAllocator_.Malloc(size); }
|
||||
void* Realloc(void* originalPtr, size_t originalSize, size_t newSize)
|
||||
{
|
||||
return baseAllocator_.Realloc(originalPtr, originalSize, newSize);
|
||||
}
|
||||
static void Free(void *ptr) RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
BaseAllocator::Free(ptr);
|
||||
}
|
||||
static void Free(void* ptr) RAPIDJSON_NOEXCEPT { BaseAllocator::Free(ptr); }
|
||||
|
||||
private:
|
||||
private:
|
||||
template <typename, typename>
|
||||
friend class StdAllocator; // access to StdAllocator<!T>.*
|
||||
|
||||
@@ -636,47 +637,45 @@ private:
|
||||
|
||||
#if !RAPIDJSON_HAS_CXX17 // std::allocator<void> deprecated in C++17
|
||||
template <typename BaseAllocator>
|
||||
class StdAllocator<void, BaseAllocator> :
|
||||
public std::allocator<void>
|
||||
class StdAllocator<void, BaseAllocator> : public std::allocator<void>
|
||||
{
|
||||
typedef std::allocator<void> allocator_type;
|
||||
|
||||
public:
|
||||
public:
|
||||
typedef BaseAllocator BaseAllocatorType;
|
||||
|
||||
StdAllocator() RAPIDJSON_NOEXCEPT :
|
||||
allocator_type(),
|
||||
baseAllocator_()
|
||||
{ }
|
||||
StdAllocator() RAPIDJSON_NOEXCEPT : allocator_type(), baseAllocator_() {}
|
||||
|
||||
StdAllocator(const StdAllocator& rhs) RAPIDJSON_NOEXCEPT :
|
||||
allocator_type(rhs),
|
||||
baseAllocator_(rhs.baseAllocator_)
|
||||
{ }
|
||||
StdAllocator(const StdAllocator& rhs) RAPIDJSON_NOEXCEPT : allocator_type(rhs),
|
||||
baseAllocator_(rhs.baseAllocator_)
|
||||
{
|
||||
}
|
||||
|
||||
template<typename U>
|
||||
StdAllocator(const StdAllocator<U, BaseAllocator>& rhs) RAPIDJSON_NOEXCEPT :
|
||||
allocator_type(rhs),
|
||||
baseAllocator_(rhs.baseAllocator_)
|
||||
{ }
|
||||
template <typename U>
|
||||
StdAllocator(const StdAllocator<U, BaseAllocator>& rhs) RAPIDJSON_NOEXCEPT
|
||||
: allocator_type(rhs),
|
||||
baseAllocator_(rhs.baseAllocator_)
|
||||
{
|
||||
}
|
||||
|
||||
/* implicit */
|
||||
StdAllocator(const BaseAllocator& baseAllocator) RAPIDJSON_NOEXCEPT :
|
||||
allocator_type(),
|
||||
baseAllocator_(baseAllocator)
|
||||
{ }
|
||||
StdAllocator(const BaseAllocator& baseAllocator) RAPIDJSON_NOEXCEPT
|
||||
: allocator_type(),
|
||||
baseAllocator_(baseAllocator)
|
||||
{
|
||||
}
|
||||
|
||||
~StdAllocator() RAPIDJSON_NOEXCEPT
|
||||
{ }
|
||||
~StdAllocator() RAPIDJSON_NOEXCEPT {}
|
||||
|
||||
template<typename U>
|
||||
struct rebind {
|
||||
template <typename U>
|
||||
struct rebind
|
||||
{
|
||||
typedef StdAllocator<U, BaseAllocator> other;
|
||||
};
|
||||
|
||||
typedef typename allocator_type::value_type value_type;
|
||||
|
||||
private:
|
||||
private:
|
||||
template <typename, typename>
|
||||
friend class StdAllocator; // access to StdAllocator<!T>.*
|
||||
|
||||
|
||||
@@ -24,33 +24,39 @@ RAPIDJSON_DIAG_OFF(effc++)
|
||||
|
||||
#if defined(_MSC_VER) && _MSC_VER <= 1800
|
||||
RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(4702) // unreachable code
|
||||
RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated
|
||||
RAPIDJSON_DIAG_OFF(4702) // unreachable code
|
||||
RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated
|
||||
#endif
|
||||
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
|
||||
|
||||
//! Cursor stream wrapper for counting line and column number if error exists.
|
||||
/*!
|
||||
\tparam InputStream Any stream that implements Stream Concept
|
||||
*/
|
||||
template <typename InputStream, typename Encoding = UTF8<> >
|
||||
class CursorStreamWrapper : public GenericStreamWrapper<InputStream, Encoding> {
|
||||
public:
|
||||
template <typename InputStream, typename Encoding = UTF8<>>
|
||||
class CursorStreamWrapper : public GenericStreamWrapper<InputStream, Encoding>
|
||||
{
|
||||
public:
|
||||
typedef typename Encoding::Ch Ch;
|
||||
|
||||
CursorStreamWrapper(InputStream& is):
|
||||
GenericStreamWrapper<InputStream, Encoding>(is), line_(1), col_(0) {}
|
||||
CursorStreamWrapper(InputStream& is)
|
||||
: GenericStreamWrapper<InputStream, Encoding>(is), line_(1), col_(0)
|
||||
{
|
||||
}
|
||||
|
||||
// counting line and column number
|
||||
Ch Take() {
|
||||
Ch Take()
|
||||
{
|
||||
Ch ch = this->is_.Take();
|
||||
if(ch == '\n') {
|
||||
line_ ++;
|
||||
if(ch == '\n')
|
||||
{
|
||||
line_++;
|
||||
col_ = 0;
|
||||
} else {
|
||||
col_ ++;
|
||||
}
|
||||
else
|
||||
{
|
||||
col_++;
|
||||
}
|
||||
return ch;
|
||||
}
|
||||
@@ -60,9 +66,9 @@ public:
|
||||
//! Get the error column number, if error exists.
|
||||
size_t GetColumn() const { return col_; }
|
||||
|
||||
private:
|
||||
size_t line_; //!< Current Line
|
||||
size_t col_; //!< Current Column
|
||||
private:
|
||||
size_t line_; //!< Current Line
|
||||
size_t col_; //!< Current Column
|
||||
};
|
||||
|
||||
#if defined(_MSC_VER) && _MSC_VER <= 1800
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_ENCODEDSTREAM_H_
|
||||
@@ -32,30 +32,43 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
|
||||
//! Input byte stream wrapper with a statically bound encoding.
|
||||
/*!
|
||||
\tparam Encoding The interpretation of encoding of the stream. Either UTF8, UTF16LE, UTF16BE, UTF32LE, UTF32BE.
|
||||
\tparam InputByteStream Type of input byte stream. For example, FileReadStream.
|
||||
\tparam Encoding The interpretation of encoding of the stream. Either UTF8, UTF16LE, UTF16BE,
|
||||
UTF32LE, UTF32BE. \tparam InputByteStream Type of input byte stream. For example, FileReadStream.
|
||||
*/
|
||||
template <typename Encoding, typename InputByteStream>
|
||||
class EncodedInputStream {
|
||||
class EncodedInputStream
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
|
||||
public:
|
||||
|
||||
public:
|
||||
typedef typename Encoding::Ch Ch;
|
||||
|
||||
EncodedInputStream(InputByteStream& is) : is_(is) {
|
||||
current_ = Encoding::TakeBOM(is_);
|
||||
}
|
||||
EncodedInputStream(InputByteStream& is) : is_(is) { current_ = Encoding::TakeBOM(is_); }
|
||||
|
||||
Ch Peek() const { return current_; }
|
||||
Ch Take() { Ch c = current_; current_ = Encoding::Take(is_); return c; }
|
||||
Ch Take()
|
||||
{
|
||||
Ch c = current_;
|
||||
current_ = Encoding::Take(is_);
|
||||
return c;
|
||||
}
|
||||
size_t Tell() const { return is_.Tell(); }
|
||||
|
||||
// Not implemented
|
||||
void Put(Ch) { RAPIDJSON_ASSERT(false); }
|
||||
void Flush() { RAPIDJSON_ASSERT(false); }
|
||||
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
|
||||
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
|
||||
void Flush() { RAPIDJSON_ASSERT(false); }
|
||||
Ch* PutBegin()
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
size_t PutEnd(Ch*)
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
EncodedInputStream(const EncodedInputStream&);
|
||||
EncodedInputStream& operator=(const EncodedInputStream&);
|
||||
|
||||
@@ -65,14 +78,19 @@ private:
|
||||
|
||||
//! Specialized for UTF8 MemoryStream.
|
||||
template <>
|
||||
class EncodedInputStream<UTF8<>, MemoryStream> {
|
||||
public:
|
||||
class EncodedInputStream<UTF8<>, MemoryStream>
|
||||
{
|
||||
public:
|
||||
typedef UTF8<>::Ch Ch;
|
||||
|
||||
EncodedInputStream(MemoryStream& is) : is_(is) {
|
||||
if (static_cast<unsigned char>(is_.Peek()) == 0xEFu) is_.Take();
|
||||
if (static_cast<unsigned char>(is_.Peek()) == 0xBBu) is_.Take();
|
||||
if (static_cast<unsigned char>(is_.Peek()) == 0xBFu) is_.Take();
|
||||
EncodedInputStream(MemoryStream& is) : is_(is)
|
||||
{
|
||||
if(static_cast<unsigned char>(is_.Peek()) == 0xEFu)
|
||||
is_.Take();
|
||||
if(static_cast<unsigned char>(is_.Peek()) == 0xBBu)
|
||||
is_.Take();
|
||||
if(static_cast<unsigned char>(is_.Peek()) == 0xBFu)
|
||||
is_.Take();
|
||||
}
|
||||
Ch Peek() const { return is_.Peek(); }
|
||||
Ch Take() { return is_.Take(); }
|
||||
@@ -80,51 +98,76 @@ public:
|
||||
|
||||
// Not implemented
|
||||
void Put(Ch) {}
|
||||
void Flush() {}
|
||||
void Flush() {}
|
||||
Ch* PutBegin() { return 0; }
|
||||
size_t PutEnd(Ch*) { return 0; }
|
||||
|
||||
MemoryStream& is_;
|
||||
|
||||
private:
|
||||
private:
|
||||
EncodedInputStream(const EncodedInputStream&);
|
||||
EncodedInputStream& operator=(const EncodedInputStream&);
|
||||
};
|
||||
|
||||
//! Output byte stream wrapper with statically bound encoding.
|
||||
/*!
|
||||
\tparam Encoding The interpretation of encoding of the stream. Either UTF8, UTF16LE, UTF16BE, UTF32LE, UTF32BE.
|
||||
\tparam OutputByteStream Type of input byte stream. For example, FileWriteStream.
|
||||
\tparam Encoding The interpretation of encoding of the stream. Either UTF8, UTF16LE, UTF16BE,
|
||||
UTF32LE, UTF32BE. \tparam OutputByteStream Type of input byte stream. For example,
|
||||
FileWriteStream.
|
||||
*/
|
||||
template <typename Encoding, typename OutputByteStream>
|
||||
class EncodedOutputStream {
|
||||
class EncodedOutputStream
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
|
||||
public:
|
||||
|
||||
public:
|
||||
typedef typename Encoding::Ch Ch;
|
||||
|
||||
EncodedOutputStream(OutputByteStream& os, bool putBOM = true) : os_(os) {
|
||||
if (putBOM)
|
||||
EncodedOutputStream(OutputByteStream& os, bool putBOM = true) : os_(os)
|
||||
{
|
||||
if(putBOM)
|
||||
Encoding::PutBOM(os_);
|
||||
}
|
||||
|
||||
void Put(Ch c) { Encoding::Put(os_, c); }
|
||||
void Put(Ch c) { Encoding::Put(os_, c); }
|
||||
void Flush() { os_.Flush(); }
|
||||
|
||||
// Not implemented
|
||||
Ch Peek() const { RAPIDJSON_ASSERT(false); return 0;}
|
||||
Ch Take() { RAPIDJSON_ASSERT(false); return 0;}
|
||||
size_t Tell() const { RAPIDJSON_ASSERT(false); return 0; }
|
||||
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
|
||||
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
|
||||
Ch Peek() const
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
Ch Take()
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
size_t Tell() const
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
Ch* PutBegin()
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
size_t PutEnd(Ch*)
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
EncodedOutputStream(const EncodedOutputStream&);
|
||||
EncodedOutputStream& operator=(const EncodedOutputStream&);
|
||||
|
||||
OutputByteStream& os_;
|
||||
};
|
||||
|
||||
#define RAPIDJSON_ENCODINGS_FUNC(x) UTF8<Ch>::x, UTF16LE<Ch>::x, UTF16BE<Ch>::x, UTF32LE<Ch>::x, UTF32BE<Ch>::x
|
||||
#define RAPIDJSON_ENCODINGS_FUNC(x) \
|
||||
UTF8<Ch>::x, UTF16LE<Ch>::x, UTF16BE<Ch>::x, UTF32LE<Ch>::x, UTF32BE<Ch>::x
|
||||
|
||||
//! Input stream wrapper with dynamically bound encoding and automatic encoding detection.
|
||||
/*!
|
||||
@@ -132,9 +175,11 @@ private:
|
||||
\tparam InputByteStream type of input byte stream to be wrapped.
|
||||
*/
|
||||
template <typename CharType, typename InputByteStream>
|
||||
class AutoUTFInputStream {
|
||||
class AutoUTFInputStream
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
|
||||
public:
|
||||
|
||||
public:
|
||||
typedef CharType Ch;
|
||||
|
||||
//! Constructor.
|
||||
@@ -142,33 +187,49 @@ public:
|
||||
\param is input stream to be wrapped.
|
||||
\param type UTF encoding type if it is not detected from the stream.
|
||||
*/
|
||||
AutoUTFInputStream(InputByteStream& is, UTFType type = kUTF8) : is_(&is), type_(type), hasBOM_(false) {
|
||||
RAPIDJSON_ASSERT(type >= kUTF8 && type <= kUTF32BE);
|
||||
AutoUTFInputStream(InputByteStream& is, UTFType type = kUTF8)
|
||||
: is_(&is), type_(type), hasBOM_(false)
|
||||
{
|
||||
RAPIDJSON_ASSERT(type >= kUTF8 && type <= kUTF32BE);
|
||||
DetectType();
|
||||
static const TakeFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Take) };
|
||||
takeFunc_ = f[type_];
|
||||
current_ = takeFunc_(*is_);
|
||||
static const TakeFunc f[] = {RAPIDJSON_ENCODINGS_FUNC(Take)};
|
||||
takeFunc_ = f[type_];
|
||||
current_ = takeFunc_(*is_);
|
||||
}
|
||||
|
||||
UTFType GetType() const { return type_; }
|
||||
bool HasBOM() const { return hasBOM_; }
|
||||
|
||||
Ch Peek() const { return current_; }
|
||||
Ch Take() { Ch c = current_; current_ = takeFunc_(*is_); return c; }
|
||||
Ch Take()
|
||||
{
|
||||
Ch c = current_;
|
||||
current_ = takeFunc_(*is_);
|
||||
return c;
|
||||
}
|
||||
size_t Tell() const { return is_->Tell(); }
|
||||
|
||||
// Not implemented
|
||||
void Put(Ch) { RAPIDJSON_ASSERT(false); }
|
||||
void Flush() { RAPIDJSON_ASSERT(false); }
|
||||
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
|
||||
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
|
||||
void Flush() { RAPIDJSON_ASSERT(false); }
|
||||
Ch* PutBegin()
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
size_t PutEnd(Ch*)
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
AutoUTFInputStream(const AutoUTFInputStream&);
|
||||
AutoUTFInputStream& operator=(const AutoUTFInputStream&);
|
||||
|
||||
// Detect encoding type with BOM or RFC 4627
|
||||
void DetectType() {
|
||||
void DetectType()
|
||||
{
|
||||
// BOM (Byte Order Mark):
|
||||
// 00 00 FE FF UTF-32BE
|
||||
// FF FE 00 00 UTF-32LE
|
||||
@@ -176,17 +237,52 @@ private:
|
||||
// FF FE UTF-16LE
|
||||
// EF BB BF UTF-8
|
||||
|
||||
const unsigned char* c = reinterpret_cast<const unsigned char *>(is_->Peek4());
|
||||
if (!c)
|
||||
const unsigned char* c = reinterpret_cast<const unsigned char*>(is_->Peek4());
|
||||
if(!c)
|
||||
return;
|
||||
|
||||
unsigned bom = static_cast<unsigned>(c[0] | (c[1] << 8) | (c[2] << 16) | (c[3] << 24));
|
||||
hasBOM_ = false;
|
||||
if (bom == 0xFFFE0000) { type_ = kUTF32BE; hasBOM_ = true; is_->Take(); is_->Take(); is_->Take(); is_->Take(); }
|
||||
else if (bom == 0x0000FEFF) { type_ = kUTF32LE; hasBOM_ = true; is_->Take(); is_->Take(); is_->Take(); is_->Take(); }
|
||||
else if ((bom & 0xFFFF) == 0xFFFE) { type_ = kUTF16BE; hasBOM_ = true; is_->Take(); is_->Take(); }
|
||||
else if ((bom & 0xFFFF) == 0xFEFF) { type_ = kUTF16LE; hasBOM_ = true; is_->Take(); is_->Take(); }
|
||||
else if ((bom & 0xFFFFFF) == 0xBFBBEF) { type_ = kUTF8; hasBOM_ = true; is_->Take(); is_->Take(); is_->Take(); }
|
||||
hasBOM_ = false;
|
||||
if(bom == 0xFFFE0000)
|
||||
{
|
||||
type_ = kUTF32BE;
|
||||
hasBOM_ = true;
|
||||
is_->Take();
|
||||
is_->Take();
|
||||
is_->Take();
|
||||
is_->Take();
|
||||
}
|
||||
else if(bom == 0x0000FEFF)
|
||||
{
|
||||
type_ = kUTF32LE;
|
||||
hasBOM_ = true;
|
||||
is_->Take();
|
||||
is_->Take();
|
||||
is_->Take();
|
||||
is_->Take();
|
||||
}
|
||||
else if((bom & 0xFFFF) == 0xFFFE)
|
||||
{
|
||||
type_ = kUTF16BE;
|
||||
hasBOM_ = true;
|
||||
is_->Take();
|
||||
is_->Take();
|
||||
}
|
||||
else if((bom & 0xFFFF) == 0xFEFF)
|
||||
{
|
||||
type_ = kUTF16LE;
|
||||
hasBOM_ = true;
|
||||
is_->Take();
|
||||
is_->Take();
|
||||
}
|
||||
else if((bom & 0xFFFFFF) == 0xBFBBEF)
|
||||
{
|
||||
type_ = kUTF8;
|
||||
hasBOM_ = true;
|
||||
is_->Take();
|
||||
is_->Take();
|
||||
is_->Take();
|
||||
}
|
||||
|
||||
// RFC 4627: Section 3
|
||||
// "Since the first two characters of a JSON text will always be ASCII
|
||||
@@ -199,21 +295,26 @@ private:
|
||||
// xx 00 xx 00 UTF-16LE
|
||||
// xx xx xx xx UTF-8
|
||||
|
||||
if (!hasBOM_) {
|
||||
if(!hasBOM_)
|
||||
{
|
||||
int pattern = (c[0] ? 1 : 0) | (c[1] ? 2 : 0) | (c[2] ? 4 : 0) | (c[3] ? 8 : 0);
|
||||
switch (pattern) {
|
||||
switch(pattern)
|
||||
{
|
||||
case 0x08: type_ = kUTF32BE; break;
|
||||
case 0x0A: type_ = kUTF16BE; break;
|
||||
case 0x01: type_ = kUTF32LE; break;
|
||||
case 0x05: type_ = kUTF16LE; break;
|
||||
case 0x0F: type_ = kUTF8; break;
|
||||
case 0x0F: type_ = kUTF8; break;
|
||||
default: break; // Use type defined by user.
|
||||
}
|
||||
}
|
||||
|
||||
// Runtime check whether the size of character type is sufficient. It only perform checks with assertion.
|
||||
if (type_ == kUTF16LE || type_ == kUTF16BE) RAPIDJSON_ASSERT(sizeof(Ch) >= 2);
|
||||
if (type_ == kUTF32LE || type_ == kUTF32BE) RAPIDJSON_ASSERT(sizeof(Ch) >= 4);
|
||||
// Runtime check whether the size of character type is sufficient. It only perform checks
|
||||
// with assertion.
|
||||
if(type_ == kUTF16LE || type_ == kUTF16BE)
|
||||
RAPIDJSON_ASSERT(sizeof(Ch) >= 2);
|
||||
if(type_ == kUTF32LE || type_ == kUTF32BE)
|
||||
RAPIDJSON_ASSERT(sizeof(Ch) >= 4);
|
||||
}
|
||||
|
||||
typedef Ch (*TakeFunc)(InputByteStream& is);
|
||||
@@ -230,9 +331,11 @@ private:
|
||||
\tparam OutputByteStream type of output byte stream to be wrapped.
|
||||
*/
|
||||
template <typename CharType, typename OutputByteStream>
|
||||
class AutoUTFOutputStream {
|
||||
class AutoUTFOutputStream
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
|
||||
public:
|
||||
|
||||
public:
|
||||
typedef CharType Ch;
|
||||
|
||||
//! Constructor.
|
||||
@@ -241,39 +344,64 @@ public:
|
||||
\param type UTF encoding type.
|
||||
\param putBOM Whether to write BOM at the beginning of the stream.
|
||||
*/
|
||||
AutoUTFOutputStream(OutputByteStream& os, UTFType type, bool putBOM) : os_(&os), type_(type) {
|
||||
AutoUTFOutputStream(OutputByteStream& os, UTFType type, bool putBOM) : os_(&os), type_(type)
|
||||
{
|
||||
RAPIDJSON_ASSERT(type >= kUTF8 && type <= kUTF32BE);
|
||||
|
||||
// Runtime check whether the size of character type is sufficient. It only perform checks with assertion.
|
||||
if (type_ == kUTF16LE || type_ == kUTF16BE) RAPIDJSON_ASSERT(sizeof(Ch) >= 2);
|
||||
if (type_ == kUTF32LE || type_ == kUTF32BE) RAPIDJSON_ASSERT(sizeof(Ch) >= 4);
|
||||
// Runtime check whether the size of character type is sufficient. It only perform checks
|
||||
// with assertion.
|
||||
if(type_ == kUTF16LE || type_ == kUTF16BE)
|
||||
RAPIDJSON_ASSERT(sizeof(Ch) >= 2);
|
||||
if(type_ == kUTF32LE || type_ == kUTF32BE)
|
||||
RAPIDJSON_ASSERT(sizeof(Ch) >= 4);
|
||||
|
||||
static const PutFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Put) };
|
||||
putFunc_ = f[type_];
|
||||
static const PutFunc f[] = {RAPIDJSON_ENCODINGS_FUNC(Put)};
|
||||
putFunc_ = f[type_];
|
||||
|
||||
if (putBOM)
|
||||
if(putBOM)
|
||||
PutBOM();
|
||||
}
|
||||
|
||||
UTFType GetType() const { return type_; }
|
||||
|
||||
void Put(Ch c) { putFunc_(*os_, c); }
|
||||
void Flush() { os_->Flush(); }
|
||||
void Flush() { os_->Flush(); }
|
||||
|
||||
// Not implemented
|
||||
Ch Peek() const { RAPIDJSON_ASSERT(false); return 0;}
|
||||
Ch Take() { RAPIDJSON_ASSERT(false); return 0;}
|
||||
size_t Tell() const { RAPIDJSON_ASSERT(false); return 0; }
|
||||
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
|
||||
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
|
||||
Ch Peek() const
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
Ch Take()
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
size_t Tell() const
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
Ch* PutBegin()
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
size_t PutEnd(Ch*)
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
AutoUTFOutputStream(const AutoUTFOutputStream&);
|
||||
AutoUTFOutputStream& operator=(const AutoUTFOutputStream&);
|
||||
|
||||
void PutBOM() {
|
||||
void PutBOM()
|
||||
{
|
||||
typedef void (*PutBOMFunc)(OutputByteStream&);
|
||||
static const PutBOMFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(PutBOM) };
|
||||
static const PutBOMFunc f[] = {RAPIDJSON_ENCODINGS_FUNC(PutBOM)};
|
||||
f[type_](*os_);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_ENCODINGS_H_
|
||||
@@ -20,7 +20,7 @@
|
||||
#if defined(_MSC_VER) && !defined(__clang__)
|
||||
RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(4244) // conversion from 'type1' to 'type2', possible loss of data
|
||||
RAPIDJSON_DIAG_OFF(4702) // unreachable code
|
||||
RAPIDJSON_DIAG_OFF(4702) // unreachable code
|
||||
#elif defined(__GNUC__)
|
||||
RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(effc++)
|
||||
@@ -37,7 +37,8 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
|
||||
\code
|
||||
concept Encoding {
|
||||
typename Ch; //! Type of character. A "character" is actually a code unit in unicode's definition.
|
||||
typename Ch; //! Type of character. A "character" is actually a code unit in unicode's
|
||||
definition.
|
||||
|
||||
enum { supportUnicode = 1 }; // or 0 if not supporting unicode
|
||||
|
||||
@@ -92,26 +93,34 @@ concept Encoding {
|
||||
\tparam CharType Code unit for storing 8-bit UTF-8 data. Default is char.
|
||||
\note implements Encoding concept
|
||||
*/
|
||||
template<typename CharType = char>
|
||||
struct UTF8 {
|
||||
template <typename CharType = char>
|
||||
struct UTF8
|
||||
{
|
||||
typedef CharType Ch;
|
||||
|
||||
enum { supportUnicode = 1 };
|
||||
enum
|
||||
{
|
||||
supportUnicode = 1
|
||||
};
|
||||
|
||||
template<typename OutputStream>
|
||||
static void Encode(OutputStream& os, unsigned codepoint) {
|
||||
if (codepoint <= 0x7F)
|
||||
template <typename OutputStream>
|
||||
static void Encode(OutputStream& os, unsigned codepoint)
|
||||
{
|
||||
if(codepoint <= 0x7F)
|
||||
os.Put(static_cast<Ch>(codepoint & 0xFF));
|
||||
else if (codepoint <= 0x7FF) {
|
||||
else if(codepoint <= 0x7FF)
|
||||
{
|
||||
os.Put(static_cast<Ch>(0xC0 | ((codepoint >> 6) & 0xFF)));
|
||||
os.Put(static_cast<Ch>(0x80 | ((codepoint & 0x3F))));
|
||||
}
|
||||
else if (codepoint <= 0xFFFF) {
|
||||
else if(codepoint <= 0xFFFF)
|
||||
{
|
||||
os.Put(static_cast<Ch>(0xE0 | ((codepoint >> 12) & 0xFF)));
|
||||
os.Put(static_cast<Ch>(0x80 | ((codepoint >> 6) & 0x3F)));
|
||||
os.Put(static_cast<Ch>(0x80 | (codepoint & 0x3F)));
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
RAPIDJSON_ASSERT(codepoint <= 0x10FFFF);
|
||||
os.Put(static_cast<Ch>(0xF0 | ((codepoint >> 18) & 0xFF)));
|
||||
os.Put(static_cast<Ch>(0x80 | ((codepoint >> 12) & 0x3F)));
|
||||
@@ -120,20 +129,24 @@ struct UTF8 {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename OutputStream>
|
||||
static void EncodeUnsafe(OutputStream& os, unsigned codepoint) {
|
||||
if (codepoint <= 0x7F)
|
||||
template <typename OutputStream>
|
||||
static void EncodeUnsafe(OutputStream& os, unsigned codepoint)
|
||||
{
|
||||
if(codepoint <= 0x7F)
|
||||
PutUnsafe(os, static_cast<Ch>(codepoint & 0xFF));
|
||||
else if (codepoint <= 0x7FF) {
|
||||
else if(codepoint <= 0x7FF)
|
||||
{
|
||||
PutUnsafe(os, static_cast<Ch>(0xC0 | ((codepoint >> 6) & 0xFF)));
|
||||
PutUnsafe(os, static_cast<Ch>(0x80 | ((codepoint & 0x3F))));
|
||||
}
|
||||
else if (codepoint <= 0xFFFF) {
|
||||
else if(codepoint <= 0xFFFF)
|
||||
{
|
||||
PutUnsafe(os, static_cast<Ch>(0xE0 | ((codepoint >> 12) & 0xFF)));
|
||||
PutUnsafe(os, static_cast<Ch>(0x80 | ((codepoint >> 6) & 0x3F)));
|
||||
PutUnsafe(os, static_cast<Ch>(0x80 | (codepoint & 0x3F)));
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
RAPIDJSON_ASSERT(codepoint <= 0x10FFFF);
|
||||
PutUnsafe(os, static_cast<Ch>(0xF0 | ((codepoint >> 18) & 0xFF)));
|
||||
PutUnsafe(os, static_cast<Ch>(0x80 | ((codepoint >> 12) & 0x3F)));
|
||||
@@ -143,31 +156,66 @@ struct UTF8 {
|
||||
}
|
||||
|
||||
template <typename InputStream>
|
||||
static bool Decode(InputStream& is, unsigned* codepoint) {
|
||||
#define RAPIDJSON_COPY() c = is.Take(); *codepoint = (*codepoint << 6) | (static_cast<unsigned char>(c) & 0x3Fu)
|
||||
static bool Decode(InputStream& is, unsigned* codepoint)
|
||||
{
|
||||
#define RAPIDJSON_COPY() \
|
||||
c = is.Take(); \
|
||||
*codepoint = (*codepoint << 6) | (static_cast<unsigned char>(c) & 0x3Fu)
|
||||
#define RAPIDJSON_TRANS(mask) result &= ((GetRange(static_cast<unsigned char>(c)) & mask) != 0)
|
||||
#define RAPIDJSON_TAIL() RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x70)
|
||||
#define RAPIDJSON_TAIL() \
|
||||
RAPIDJSON_COPY(); \
|
||||
RAPIDJSON_TRANS(0x70)
|
||||
typename InputStream::Ch c = is.Take();
|
||||
if (!(c & 0x80)) {
|
||||
if(!(c & 0x80))
|
||||
{
|
||||
*codepoint = static_cast<unsigned char>(c);
|
||||
return true;
|
||||
}
|
||||
|
||||
unsigned char type = GetRange(static_cast<unsigned char>(c));
|
||||
if (type >= 32) {
|
||||
if(type >= 32)
|
||||
{
|
||||
*codepoint = 0;
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
*codepoint = (0xFFu >> type) & static_cast<unsigned char>(c);
|
||||
}
|
||||
bool result = true;
|
||||
switch (type) {
|
||||
switch(type)
|
||||
{
|
||||
case 2: RAPIDJSON_TAIL(); return result;
|
||||
case 3: RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
|
||||
case 4: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x50); RAPIDJSON_TAIL(); return result;
|
||||
case 5: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x10); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
|
||||
case 6: RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
|
||||
case 10: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x20); RAPIDJSON_TAIL(); return result;
|
||||
case 11: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x60); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
|
||||
case 3:
|
||||
RAPIDJSON_TAIL();
|
||||
RAPIDJSON_TAIL();
|
||||
return result;
|
||||
case 4:
|
||||
RAPIDJSON_COPY();
|
||||
RAPIDJSON_TRANS(0x50);
|
||||
RAPIDJSON_TAIL();
|
||||
return result;
|
||||
case 5:
|
||||
RAPIDJSON_COPY();
|
||||
RAPIDJSON_TRANS(0x10);
|
||||
RAPIDJSON_TAIL();
|
||||
RAPIDJSON_TAIL();
|
||||
return result;
|
||||
case 6:
|
||||
RAPIDJSON_TAIL();
|
||||
RAPIDJSON_TAIL();
|
||||
RAPIDJSON_TAIL();
|
||||
return result;
|
||||
case 10:
|
||||
RAPIDJSON_COPY();
|
||||
RAPIDJSON_TRANS(0x20);
|
||||
RAPIDJSON_TAIL();
|
||||
return result;
|
||||
case 11:
|
||||
RAPIDJSON_COPY();
|
||||
RAPIDJSON_TRANS(0x60);
|
||||
RAPIDJSON_TAIL();
|
||||
RAPIDJSON_TAIL();
|
||||
return result;
|
||||
default: return false;
|
||||
}
|
||||
#undef RAPIDJSON_COPY
|
||||
@@ -176,24 +224,55 @@ struct UTF8 {
|
||||
}
|
||||
|
||||
template <typename InputStream, typename OutputStream>
|
||||
static bool Validate(InputStream& is, OutputStream& os) {
|
||||
#define RAPIDJSON_COPY() if (c != '\0') os.Put(c = is.Take())
|
||||
static bool Validate(InputStream& is, OutputStream& os)
|
||||
{
|
||||
#define RAPIDJSON_COPY() \
|
||||
if(c != '\0') \
|
||||
os.Put(c = is.Take())
|
||||
#define RAPIDJSON_TRANS(mask) result &= ((GetRange(static_cast<unsigned char>(c)) & mask) != 0)
|
||||
#define RAPIDJSON_TAIL() RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x70)
|
||||
#define RAPIDJSON_TAIL() \
|
||||
RAPIDJSON_COPY(); \
|
||||
RAPIDJSON_TRANS(0x70)
|
||||
Ch c = static_cast<Ch>(-1);
|
||||
RAPIDJSON_COPY();
|
||||
if (!(c & 0x80))
|
||||
if(!(c & 0x80))
|
||||
return true;
|
||||
|
||||
bool result = true;
|
||||
switch (GetRange(static_cast<unsigned char>(c))) {
|
||||
switch(GetRange(static_cast<unsigned char>(c)))
|
||||
{
|
||||
case 2: RAPIDJSON_TAIL(); return result;
|
||||
case 3: RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
|
||||
case 4: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x50); RAPIDJSON_TAIL(); return result;
|
||||
case 5: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x10); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
|
||||
case 6: RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
|
||||
case 10: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x20); RAPIDJSON_TAIL(); return result;
|
||||
case 11: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x60); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
|
||||
case 3:
|
||||
RAPIDJSON_TAIL();
|
||||
RAPIDJSON_TAIL();
|
||||
return result;
|
||||
case 4:
|
||||
RAPIDJSON_COPY();
|
||||
RAPIDJSON_TRANS(0x50);
|
||||
RAPIDJSON_TAIL();
|
||||
return result;
|
||||
case 5:
|
||||
RAPIDJSON_COPY();
|
||||
RAPIDJSON_TRANS(0x10);
|
||||
RAPIDJSON_TAIL();
|
||||
RAPIDJSON_TAIL();
|
||||
return result;
|
||||
case 6:
|
||||
RAPIDJSON_TAIL();
|
||||
RAPIDJSON_TAIL();
|
||||
RAPIDJSON_TAIL();
|
||||
return result;
|
||||
case 10:
|
||||
RAPIDJSON_COPY();
|
||||
RAPIDJSON_TRANS(0x20);
|
||||
RAPIDJSON_TAIL();
|
||||
return result;
|
||||
case 11:
|
||||
RAPIDJSON_COPY();
|
||||
RAPIDJSON_TRANS(0x60);
|
||||
RAPIDJSON_TAIL();
|
||||
RAPIDJSON_TAIL();
|
||||
return result;
|
||||
default: return false;
|
||||
}
|
||||
#undef RAPIDJSON_COPY
|
||||
@@ -201,45 +280,62 @@ struct UTF8 {
|
||||
#undef RAPIDJSON_TAIL
|
||||
}
|
||||
|
||||
static unsigned char GetRange(unsigned char c) {
|
||||
static unsigned char GetRange(unsigned char c)
|
||||
{
|
||||
// Referring to DFA of http://bjoern.hoehrmann.de/utf-8/decoder/dfa/
|
||||
// With new mapping 1 -> 0x10, 7 -> 0x20, 9 -> 0x40, such that AND operation can test multiple types.
|
||||
// With new mapping 1 -> 0x10, 7 -> 0x20, 9 -> 0x40, such that AND operation can test
|
||||
// multiple types.
|
||||
static const unsigned char type[] = {
|
||||
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
|
||||
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
|
||||
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
|
||||
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
|
||||
0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,
|
||||
0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,
|
||||
0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,
|
||||
0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,
|
||||
8,8,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
|
||||
10,3,3,3,3,3,3,3,3,3,3,3,3,4,3,3, 11,6,6,6,5,8,8,8,8,8,8,8,8,8,8,8,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
|
||||
0x10, 0x10, 0x10, 0x10, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
|
||||
0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
|
||||
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 8, 8, 2, 2,
|
||||
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||
10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4,
|
||||
3, 3, 11, 6, 6, 6, 5, 8, 8, 8, 8, 8, 8, 8,
|
||||
8, 8, 8, 8,
|
||||
};
|
||||
return type[c];
|
||||
}
|
||||
|
||||
template <typename InputByteStream>
|
||||
static CharType TakeBOM(InputByteStream& is) {
|
||||
static CharType TakeBOM(InputByteStream& is)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
|
||||
typename InputByteStream::Ch c = Take(is);
|
||||
if (static_cast<unsigned char>(c) != 0xEFu) return c;
|
||||
if(static_cast<unsigned char>(c) != 0xEFu)
|
||||
return c;
|
||||
c = is.Take();
|
||||
if (static_cast<unsigned char>(c) != 0xBBu) return c;
|
||||
if(static_cast<unsigned char>(c) != 0xBBu)
|
||||
return c;
|
||||
c = is.Take();
|
||||
if (static_cast<unsigned char>(c) != 0xBFu) return c;
|
||||
if(static_cast<unsigned char>(c) != 0xBFu)
|
||||
return c;
|
||||
c = is.Take();
|
||||
return c;
|
||||
}
|
||||
|
||||
template <typename InputByteStream>
|
||||
static Ch Take(InputByteStream& is) {
|
||||
static Ch Take(InputByteStream& is)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
|
||||
return static_cast<Ch>(is.Take());
|
||||
}
|
||||
|
||||
template <typename OutputByteStream>
|
||||
static void PutBOM(OutputByteStream& os) {
|
||||
static void PutBOM(OutputByteStream& os)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(0xEFu));
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(0xBBu));
|
||||
@@ -247,7 +343,8 @@ struct UTF8 {
|
||||
}
|
||||
|
||||
template <typename OutputByteStream>
|
||||
static void Put(OutputByteStream& os, Ch c) {
|
||||
static void Put(OutputByteStream& os, Ch c)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(c));
|
||||
}
|
||||
@@ -259,27 +356,35 @@ struct UTF8 {
|
||||
//! UTF-16 encoding.
|
||||
/*! http://en.wikipedia.org/wiki/UTF-16
|
||||
http://tools.ietf.org/html/rfc2781
|
||||
\tparam CharType Type for storing 16-bit UTF-16 data. Default is wchar_t. C++11 may use char16_t instead.
|
||||
\note implements Encoding concept
|
||||
\tparam CharType Type for storing 16-bit UTF-16 data. Default is wchar_t. C++11 may use char16_t
|
||||
instead. \note implements Encoding concept
|
||||
|
||||
\note For in-memory access, no need to concern endianness. The code units and code points are represented by CPU's endianness.
|
||||
For streaming, use UTF16LE and UTF16BE, which handle endianness.
|
||||
\note For in-memory access, no need to concern endianness. The code units and code points are
|
||||
represented by CPU's endianness. For streaming, use UTF16LE and UTF16BE, which handle endianness.
|
||||
*/
|
||||
template<typename CharType = wchar_t>
|
||||
struct UTF16 {
|
||||
template <typename CharType = wchar_t>
|
||||
struct UTF16
|
||||
{
|
||||
typedef CharType Ch;
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(Ch) >= 2);
|
||||
|
||||
enum { supportUnicode = 1 };
|
||||
enum
|
||||
{
|
||||
supportUnicode = 1
|
||||
};
|
||||
|
||||
template<typename OutputStream>
|
||||
static void Encode(OutputStream& os, unsigned codepoint) {
|
||||
template <typename OutputStream>
|
||||
static void Encode(OutputStream& os, unsigned codepoint)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 2);
|
||||
if (codepoint <= 0xFFFF) {
|
||||
RAPIDJSON_ASSERT(codepoint < 0xD800 || codepoint > 0xDFFF); // Code point itself cannot be surrogate pair
|
||||
if(codepoint <= 0xFFFF)
|
||||
{
|
||||
RAPIDJSON_ASSERT(codepoint < 0xD800 ||
|
||||
codepoint > 0xDFFF); // Code point itself cannot be surrogate pair
|
||||
os.Put(static_cast<typename OutputStream::Ch>(codepoint));
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
RAPIDJSON_ASSERT(codepoint <= 0x10FFFF);
|
||||
unsigned v = codepoint - 0x10000;
|
||||
os.Put(static_cast<typename OutputStream::Ch>((v >> 10) | 0xD800));
|
||||
@@ -287,15 +392,18 @@ struct UTF16 {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<typename OutputStream>
|
||||
static void EncodeUnsafe(OutputStream& os, unsigned codepoint) {
|
||||
template <typename OutputStream>
|
||||
static void EncodeUnsafe(OutputStream& os, unsigned codepoint)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 2);
|
||||
if (codepoint <= 0xFFFF) {
|
||||
RAPIDJSON_ASSERT(codepoint < 0xD800 || codepoint > 0xDFFF); // Code point itself cannot be surrogate pair
|
||||
if(codepoint <= 0xFFFF)
|
||||
{
|
||||
RAPIDJSON_ASSERT(codepoint < 0xD800 ||
|
||||
codepoint > 0xDFFF); // Code point itself cannot be surrogate pair
|
||||
PutUnsafe(os, static_cast<typename OutputStream::Ch>(codepoint));
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
RAPIDJSON_ASSERT(codepoint <= 0x10FFFF);
|
||||
unsigned v = codepoint - 0x10000;
|
||||
PutUnsafe(os, static_cast<typename OutputStream::Ch>((v >> 10) | 0xD800));
|
||||
@@ -304,16 +412,19 @@ struct UTF16 {
|
||||
}
|
||||
|
||||
template <typename InputStream>
|
||||
static bool Decode(InputStream& is, unsigned* codepoint) {
|
||||
static bool Decode(InputStream& is, unsigned* codepoint)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputStream::Ch) >= 2);
|
||||
typename InputStream::Ch c = is.Take();
|
||||
if (c < 0xD800 || c > 0xDFFF) {
|
||||
if(c < 0xD800 || c > 0xDFFF)
|
||||
{
|
||||
*codepoint = static_cast<unsigned>(c);
|
||||
return true;
|
||||
}
|
||||
else if (c <= 0xDBFF) {
|
||||
else if(c <= 0xDBFF)
|
||||
{
|
||||
*codepoint = (static_cast<unsigned>(c) & 0x3FF) << 10;
|
||||
c = is.Take();
|
||||
c = is.Take();
|
||||
*codepoint |= (static_cast<unsigned>(c) & 0x3FF);
|
||||
*codepoint += 0x10000;
|
||||
return c >= 0xDC00 && c <= 0xDFFF;
|
||||
@@ -322,14 +433,16 @@ struct UTF16 {
|
||||
}
|
||||
|
||||
template <typename InputStream, typename OutputStream>
|
||||
static bool Validate(InputStream& is, OutputStream& os) {
|
||||
static bool Validate(InputStream& is, OutputStream& os)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputStream::Ch) >= 2);
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 2);
|
||||
typename InputStream::Ch c;
|
||||
os.Put(static_cast<typename OutputStream::Ch>(c = is.Take()));
|
||||
if (c < 0xD800 || c > 0xDFFF)
|
||||
if(c < 0xD800 || c > 0xDFFF)
|
||||
return true;
|
||||
else if (c <= 0xDBFF) {
|
||||
else if(c <= 0xDBFF)
|
||||
{
|
||||
os.Put(c = is.Take());
|
||||
return c >= 0xDC00 && c <= 0xDFFF;
|
||||
}
|
||||
@@ -338,17 +451,20 @@ struct UTF16 {
|
||||
};
|
||||
|
||||
//! UTF-16 little endian encoding.
|
||||
template<typename CharType = wchar_t>
|
||||
struct UTF16LE : UTF16<CharType> {
|
||||
template <typename CharType = wchar_t>
|
||||
struct UTF16LE : UTF16<CharType>
|
||||
{
|
||||
template <typename InputByteStream>
|
||||
static CharType TakeBOM(InputByteStream& is) {
|
||||
static CharType TakeBOM(InputByteStream& is)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
|
||||
CharType c = Take(is);
|
||||
return static_cast<uint16_t>(c) == 0xFEFFu ? Take(is) : c;
|
||||
}
|
||||
|
||||
template <typename InputByteStream>
|
||||
static CharType Take(InputByteStream& is) {
|
||||
static CharType Take(InputByteStream& is)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
|
||||
unsigned c = static_cast<uint8_t>(is.Take());
|
||||
c |= static_cast<unsigned>(static_cast<uint8_t>(is.Take())) << 8;
|
||||
@@ -356,14 +472,16 @@ struct UTF16LE : UTF16<CharType> {
|
||||
}
|
||||
|
||||
template <typename OutputByteStream>
|
||||
static void PutBOM(OutputByteStream& os) {
|
||||
static void PutBOM(OutputByteStream& os)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(0xFFu));
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(0xFEu));
|
||||
}
|
||||
|
||||
template <typename OutputByteStream>
|
||||
static void Put(OutputByteStream& os, CharType c) {
|
||||
static void Put(OutputByteStream& os, CharType c)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(static_cast<unsigned>(c) & 0xFFu));
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>((static_cast<unsigned>(c) >> 8) & 0xFFu));
|
||||
@@ -371,17 +489,20 @@ struct UTF16LE : UTF16<CharType> {
|
||||
};
|
||||
|
||||
//! UTF-16 big endian encoding.
|
||||
template<typename CharType = wchar_t>
|
||||
struct UTF16BE : UTF16<CharType> {
|
||||
template <typename CharType = wchar_t>
|
||||
struct UTF16BE : UTF16<CharType>
|
||||
{
|
||||
template <typename InputByteStream>
|
||||
static CharType TakeBOM(InputByteStream& is) {
|
||||
static CharType TakeBOM(InputByteStream& is)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
|
||||
CharType c = Take(is);
|
||||
return static_cast<uint16_t>(c) == 0xFEFFu ? Take(is) : c;
|
||||
}
|
||||
|
||||
template <typename InputByteStream>
|
||||
static CharType Take(InputByteStream& is) {
|
||||
static CharType Take(InputByteStream& is)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
|
||||
unsigned c = static_cast<unsigned>(static_cast<uint8_t>(is.Take())) << 8;
|
||||
c |= static_cast<unsigned>(static_cast<uint8_t>(is.Take()));
|
||||
@@ -389,14 +510,16 @@ struct UTF16BE : UTF16<CharType> {
|
||||
}
|
||||
|
||||
template <typename OutputByteStream>
|
||||
static void PutBOM(OutputByteStream& os) {
|
||||
static void PutBOM(OutputByteStream& os)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(0xFEu));
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(0xFFu));
|
||||
}
|
||||
|
||||
template <typename OutputByteStream>
|
||||
static void Put(OutputByteStream& os, CharType c) {
|
||||
static void Put(OutputByteStream& os, CharType c)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>((static_cast<unsigned>(c) >> 8) & 0xFFu));
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(static_cast<unsigned>(c) & 0xFFu));
|
||||
@@ -406,45 +529,53 @@ struct UTF16BE : UTF16<CharType> {
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// UTF32
|
||||
|
||||
//! UTF-32 encoding.
|
||||
//! UTF-32 encoding.
|
||||
/*! http://en.wikipedia.org/wiki/UTF-32
|
||||
\tparam CharType Type for storing 32-bit UTF-32 data. Default is unsigned. C++11 may use char32_t instead.
|
||||
\note implements Encoding concept
|
||||
\tparam CharType Type for storing 32-bit UTF-32 data. Default is unsigned. C++11 may use
|
||||
char32_t instead. \note implements Encoding concept
|
||||
|
||||
\note For in-memory access, no need to concern endianness. The code units and code points are represented by CPU's endianness.
|
||||
For streaming, use UTF32LE and UTF32BE, which handle endianness.
|
||||
\note For in-memory access, no need to concern endianness. The code units and code points are
|
||||
represented by CPU's endianness. For streaming, use UTF32LE and UTF32BE, which handle endianness.
|
||||
*/
|
||||
template<typename CharType = unsigned>
|
||||
struct UTF32 {
|
||||
template <typename CharType = unsigned>
|
||||
struct UTF32
|
||||
{
|
||||
typedef CharType Ch;
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(Ch) >= 4);
|
||||
|
||||
enum { supportUnicode = 1 };
|
||||
enum
|
||||
{
|
||||
supportUnicode = 1
|
||||
};
|
||||
|
||||
template<typename OutputStream>
|
||||
static void Encode(OutputStream& os, unsigned codepoint) {
|
||||
template <typename OutputStream>
|
||||
static void Encode(OutputStream& os, unsigned codepoint)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 4);
|
||||
RAPIDJSON_ASSERT(codepoint <= 0x10FFFF);
|
||||
os.Put(codepoint);
|
||||
}
|
||||
|
||||
template<typename OutputStream>
|
||||
static void EncodeUnsafe(OutputStream& os, unsigned codepoint) {
|
||||
template <typename OutputStream>
|
||||
static void EncodeUnsafe(OutputStream& os, unsigned codepoint)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 4);
|
||||
RAPIDJSON_ASSERT(codepoint <= 0x10FFFF);
|
||||
PutUnsafe(os, codepoint);
|
||||
}
|
||||
|
||||
template <typename InputStream>
|
||||
static bool Decode(InputStream& is, unsigned* codepoint) {
|
||||
static bool Decode(InputStream& is, unsigned* codepoint)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputStream::Ch) >= 4);
|
||||
Ch c = is.Take();
|
||||
Ch c = is.Take();
|
||||
*codepoint = c;
|
||||
return c <= 0x10FFFF;
|
||||
}
|
||||
|
||||
template <typename InputStream, typename OutputStream>
|
||||
static bool Validate(InputStream& is, OutputStream& os) {
|
||||
static bool Validate(InputStream& is, OutputStream& os)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputStream::Ch) >= 4);
|
||||
Ch c;
|
||||
os.Put(c = is.Take());
|
||||
@@ -453,17 +584,20 @@ struct UTF32 {
|
||||
};
|
||||
|
||||
//! UTF-32 little endian enocoding.
|
||||
template<typename CharType = unsigned>
|
||||
struct UTF32LE : UTF32<CharType> {
|
||||
template <typename CharType = unsigned>
|
||||
struct UTF32LE : UTF32<CharType>
|
||||
{
|
||||
template <typename InputByteStream>
|
||||
static CharType TakeBOM(InputByteStream& is) {
|
||||
static CharType TakeBOM(InputByteStream& is)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
|
||||
CharType c = Take(is);
|
||||
return static_cast<uint32_t>(c) == 0x0000FEFFu ? Take(is) : c;
|
||||
}
|
||||
|
||||
template <typename InputByteStream>
|
||||
static CharType Take(InputByteStream& is) {
|
||||
static CharType Take(InputByteStream& is)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
|
||||
unsigned c = static_cast<uint8_t>(is.Take());
|
||||
c |= static_cast<unsigned>(static_cast<uint8_t>(is.Take())) << 8;
|
||||
@@ -473,7 +607,8 @@ struct UTF32LE : UTF32<CharType> {
|
||||
}
|
||||
|
||||
template <typename OutputByteStream>
|
||||
static void PutBOM(OutputByteStream& os) {
|
||||
static void PutBOM(OutputByteStream& os)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(0xFFu));
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(0xFEu));
|
||||
@@ -482,7 +617,8 @@ struct UTF32LE : UTF32<CharType> {
|
||||
}
|
||||
|
||||
template <typename OutputByteStream>
|
||||
static void Put(OutputByteStream& os, CharType c) {
|
||||
static void Put(OutputByteStream& os, CharType c)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(c & 0xFFu));
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>((c >> 8) & 0xFFu));
|
||||
@@ -492,17 +628,20 @@ struct UTF32LE : UTF32<CharType> {
|
||||
};
|
||||
|
||||
//! UTF-32 big endian encoding.
|
||||
template<typename CharType = unsigned>
|
||||
struct UTF32BE : UTF32<CharType> {
|
||||
template <typename CharType = unsigned>
|
||||
struct UTF32BE : UTF32<CharType>
|
||||
{
|
||||
template <typename InputByteStream>
|
||||
static CharType TakeBOM(InputByteStream& is) {
|
||||
static CharType TakeBOM(InputByteStream& is)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
|
||||
CharType c = Take(is);
|
||||
return static_cast<uint32_t>(c) == 0x0000FEFFu ? Take(is) : c;
|
||||
return static_cast<uint32_t>(c) == 0x0000FEFFu ? Take(is) : c;
|
||||
}
|
||||
|
||||
template <typename InputByteStream>
|
||||
static CharType Take(InputByteStream& is) {
|
||||
static CharType Take(InputByteStream& is)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
|
||||
unsigned c = static_cast<unsigned>(static_cast<uint8_t>(is.Take())) << 24;
|
||||
c |= static_cast<unsigned>(static_cast<uint8_t>(is.Take())) << 16;
|
||||
@@ -512,7 +651,8 @@ struct UTF32BE : UTF32<CharType> {
|
||||
}
|
||||
|
||||
template <typename OutputByteStream>
|
||||
static void PutBOM(OutputByteStream& os) {
|
||||
static void PutBOM(OutputByteStream& os)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(0x00u));
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(0x00u));
|
||||
@@ -521,7 +661,8 @@ struct UTF32BE : UTF32<CharType> {
|
||||
}
|
||||
|
||||
template <typename OutputByteStream>
|
||||
static void Put(OutputByteStream& os, CharType c) {
|
||||
static void Put(OutputByteStream& os, CharType c)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>((c >> 24) & 0xFFu));
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>((c >> 16) & 0xFFu));
|
||||
@@ -538,59 +679,71 @@ struct UTF32BE : UTF32<CharType> {
|
||||
\tparam CharType Code unit for storing 7-bit ASCII data. Default is char.
|
||||
\note implements Encoding concept
|
||||
*/
|
||||
template<typename CharType = char>
|
||||
struct ASCII {
|
||||
template <typename CharType = char>
|
||||
struct ASCII
|
||||
{
|
||||
typedef CharType Ch;
|
||||
|
||||
enum { supportUnicode = 0 };
|
||||
enum
|
||||
{
|
||||
supportUnicode = 0
|
||||
};
|
||||
|
||||
template<typename OutputStream>
|
||||
static void Encode(OutputStream& os, unsigned codepoint) {
|
||||
template <typename OutputStream>
|
||||
static void Encode(OutputStream& os, unsigned codepoint)
|
||||
{
|
||||
RAPIDJSON_ASSERT(codepoint <= 0x7F);
|
||||
os.Put(static_cast<Ch>(codepoint & 0xFF));
|
||||
}
|
||||
|
||||
template<typename OutputStream>
|
||||
static void EncodeUnsafe(OutputStream& os, unsigned codepoint) {
|
||||
template <typename OutputStream>
|
||||
static void EncodeUnsafe(OutputStream& os, unsigned codepoint)
|
||||
{
|
||||
RAPIDJSON_ASSERT(codepoint <= 0x7F);
|
||||
PutUnsafe(os, static_cast<Ch>(codepoint & 0xFF));
|
||||
}
|
||||
|
||||
template <typename InputStream>
|
||||
static bool Decode(InputStream& is, unsigned* codepoint) {
|
||||
uint8_t c = static_cast<uint8_t>(is.Take());
|
||||
static bool Decode(InputStream& is, unsigned* codepoint)
|
||||
{
|
||||
uint8_t c = static_cast<uint8_t>(is.Take());
|
||||
*codepoint = c;
|
||||
return c <= 0X7F;
|
||||
}
|
||||
|
||||
template <typename InputStream, typename OutputStream>
|
||||
static bool Validate(InputStream& is, OutputStream& os) {
|
||||
static bool Validate(InputStream& is, OutputStream& os)
|
||||
{
|
||||
uint8_t c = static_cast<uint8_t>(is.Take());
|
||||
os.Put(static_cast<typename OutputStream::Ch>(c));
|
||||
return c <= 0x7F;
|
||||
}
|
||||
|
||||
template <typename InputByteStream>
|
||||
static CharType TakeBOM(InputByteStream& is) {
|
||||
static CharType TakeBOM(InputByteStream& is)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
|
||||
uint8_t c = static_cast<uint8_t>(Take(is));
|
||||
return static_cast<Ch>(c);
|
||||
}
|
||||
|
||||
template <typename InputByteStream>
|
||||
static Ch Take(InputByteStream& is) {
|
||||
static Ch Take(InputByteStream& is)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
|
||||
return static_cast<Ch>(is.Take());
|
||||
}
|
||||
|
||||
template <typename OutputByteStream>
|
||||
static void PutBOM(OutputByteStream& os) {
|
||||
static void PutBOM(OutputByteStream& os)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
|
||||
(void)os;
|
||||
}
|
||||
|
||||
template <typename OutputByteStream>
|
||||
static void Put(OutputByteStream& os, Ch c) {
|
||||
static void Put(OutputByteStream& os, Ch c)
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
|
||||
os.Put(static_cast<typename OutputByteStream::Ch>(c));
|
||||
}
|
||||
@@ -600,50 +753,61 @@ struct ASCII {
|
||||
// AutoUTF
|
||||
|
||||
//! Runtime-specified UTF encoding type of a stream.
|
||||
enum UTFType {
|
||||
kUTF8 = 0, //!< UTF-8.
|
||||
kUTF16LE = 1, //!< UTF-16 little endian.
|
||||
kUTF16BE = 2, //!< UTF-16 big endian.
|
||||
kUTF32LE = 3, //!< UTF-32 little endian.
|
||||
kUTF32BE = 4 //!< UTF-32 big endian.
|
||||
enum UTFType
|
||||
{
|
||||
kUTF8 = 0, //!< UTF-8.
|
||||
kUTF16LE = 1, //!< UTF-16 little endian.
|
||||
kUTF16BE = 2, //!< UTF-16 big endian.
|
||||
kUTF32LE = 3, //!< UTF-32 little endian.
|
||||
kUTF32BE = 4 //!< UTF-32 big endian.
|
||||
};
|
||||
|
||||
//! Dynamically select encoding according to stream's runtime-specified UTF encoding type.
|
||||
/*! \note This class can be used with AutoUTFInputtStream and AutoUTFOutputStream, which provides GetType().
|
||||
*/
|
||||
template<typename CharType>
|
||||
struct AutoUTF {
|
||||
/*! \note This class can be used with AutoUTFInputtStream and AutoUTFOutputStream, which provides
|
||||
* GetType().
|
||||
*/
|
||||
template <typename CharType>
|
||||
struct AutoUTF
|
||||
{
|
||||
typedef CharType Ch;
|
||||
|
||||
enum { supportUnicode = 1 };
|
||||
enum
|
||||
{
|
||||
supportUnicode = 1
|
||||
};
|
||||
|
||||
#define RAPIDJSON_ENCODINGS_FUNC(x) UTF8<Ch>::x, UTF16LE<Ch>::x, UTF16BE<Ch>::x, UTF32LE<Ch>::x, UTF32BE<Ch>::x
|
||||
#define RAPIDJSON_ENCODINGS_FUNC(x) \
|
||||
UTF8<Ch>::x, UTF16LE<Ch>::x, UTF16BE<Ch>::x, UTF32LE<Ch>::x, UTF32BE<Ch>::x
|
||||
|
||||
template<typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE void Encode(OutputStream& os, unsigned codepoint) {
|
||||
template <typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE void Encode(OutputStream& os, unsigned codepoint)
|
||||
{
|
||||
typedef void (*EncodeFunc)(OutputStream&, unsigned);
|
||||
static const EncodeFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Encode) };
|
||||
static const EncodeFunc f[] = {RAPIDJSON_ENCODINGS_FUNC(Encode)};
|
||||
(*f[os.GetType()])(os, codepoint);
|
||||
}
|
||||
|
||||
template<typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE void EncodeUnsafe(OutputStream& os, unsigned codepoint) {
|
||||
template <typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE void EncodeUnsafe(OutputStream& os, unsigned codepoint)
|
||||
{
|
||||
typedef void (*EncodeFunc)(OutputStream&, unsigned);
|
||||
static const EncodeFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(EncodeUnsafe) };
|
||||
static const EncodeFunc f[] = {RAPIDJSON_ENCODINGS_FUNC(EncodeUnsafe)};
|
||||
(*f[os.GetType()])(os, codepoint);
|
||||
}
|
||||
|
||||
template <typename InputStream>
|
||||
static RAPIDJSON_FORCEINLINE bool Decode(InputStream& is, unsigned* codepoint) {
|
||||
static RAPIDJSON_FORCEINLINE bool Decode(InputStream& is, unsigned* codepoint)
|
||||
{
|
||||
typedef bool (*DecodeFunc)(InputStream&, unsigned*);
|
||||
static const DecodeFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Decode) };
|
||||
static const DecodeFunc f[] = {RAPIDJSON_ENCODINGS_FUNC(Decode)};
|
||||
return (*f[is.GetType()])(is, codepoint);
|
||||
}
|
||||
|
||||
template <typename InputStream, typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os) {
|
||||
static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os)
|
||||
{
|
||||
typedef bool (*ValidateFunc)(InputStream&, OutputStream&);
|
||||
static const ValidateFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Validate) };
|
||||
static const ValidateFunc f[] = {RAPIDJSON_ENCODINGS_FUNC(Validate)};
|
||||
return (*f[is.GetType()])(is, os);
|
||||
}
|
||||
|
||||
@@ -654,56 +818,67 @@ struct AutoUTF {
|
||||
// Transcoder
|
||||
|
||||
//! Encoding conversion.
|
||||
template<typename SourceEncoding, typename TargetEncoding>
|
||||
struct Transcoder {
|
||||
//! Take one Unicode codepoint from source encoding, convert it to target encoding and put it to the output stream.
|
||||
template<typename InputStream, typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE bool Transcode(InputStream& is, OutputStream& os) {
|
||||
template <typename SourceEncoding, typename TargetEncoding>
|
||||
struct Transcoder
|
||||
{
|
||||
//! Take one Unicode codepoint from source encoding, convert it to target encoding and put it to
|
||||
//! the output stream.
|
||||
template <typename InputStream, typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE bool Transcode(InputStream& is, OutputStream& os)
|
||||
{
|
||||
unsigned codepoint;
|
||||
if (!SourceEncoding::Decode(is, &codepoint))
|
||||
if(!SourceEncoding::Decode(is, &codepoint))
|
||||
return false;
|
||||
TargetEncoding::Encode(os, codepoint);
|
||||
return true;
|
||||
}
|
||||
|
||||
template<typename InputStream, typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE bool TranscodeUnsafe(InputStream& is, OutputStream& os) {
|
||||
template <typename InputStream, typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE bool TranscodeUnsafe(InputStream& is, OutputStream& os)
|
||||
{
|
||||
unsigned codepoint;
|
||||
if (!SourceEncoding::Decode(is, &codepoint))
|
||||
if(!SourceEncoding::Decode(is, &codepoint))
|
||||
return false;
|
||||
TargetEncoding::EncodeUnsafe(os, codepoint);
|
||||
return true;
|
||||
}
|
||||
|
||||
//! Validate one Unicode codepoint from an encoded stream.
|
||||
template<typename InputStream, typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os) {
|
||||
return Transcode(is, os); // Since source/target encoding is different, must transcode.
|
||||
template <typename InputStream, typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os)
|
||||
{
|
||||
return Transcode(is, os); // Since source/target encoding is different, must transcode.
|
||||
}
|
||||
};
|
||||
|
||||
// Forward declaration.
|
||||
template<typename Stream>
|
||||
template <typename Stream>
|
||||
inline void PutUnsafe(Stream& stream, typename Stream::Ch c);
|
||||
|
||||
//! Specialization of Transcoder with same source and target encoding.
|
||||
template<typename Encoding>
|
||||
struct Transcoder<Encoding, Encoding> {
|
||||
template<typename InputStream, typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE bool Transcode(InputStream& is, OutputStream& os) {
|
||||
os.Put(is.Take()); // Just copy one code unit. This semantic is different from primary template class.
|
||||
template <typename Encoding>
|
||||
struct Transcoder<Encoding, Encoding>
|
||||
{
|
||||
template <typename InputStream, typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE bool Transcode(InputStream& is, OutputStream& os)
|
||||
{
|
||||
os.Put(is.Take()); // Just copy one code unit. This semantic is different from primary
|
||||
// template class.
|
||||
return true;
|
||||
}
|
||||
|
||||
template<typename InputStream, typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE bool TranscodeUnsafe(InputStream& is, OutputStream& os) {
|
||||
PutUnsafe(os, is.Take()); // Just copy one code unit. This semantic is different from primary template class.
|
||||
|
||||
template <typename InputStream, typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE bool TranscodeUnsafe(InputStream& is, OutputStream& os)
|
||||
{
|
||||
PutUnsafe(os, is.Take()); // Just copy one code unit. This semantic is different from
|
||||
// primary template class.
|
||||
return true;
|
||||
}
|
||||
|
||||
template<typename InputStream, typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os) {
|
||||
return Encoding::Validate(is, os); // source/target encoding are the same
|
||||
|
||||
template <typename InputStream, typename OutputStream>
|
||||
static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os)
|
||||
{
|
||||
return Encoding::Validate(is, os); // source/target encoding are the same
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -19,8 +19,8 @@
|
||||
|
||||
#ifdef __clang__
|
||||
RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(switch-enum)
|
||||
RAPIDJSON_DIAG_OFF(covered-switch-default)
|
||||
RAPIDJSON_DIAG_OFF(switch - enum)
|
||||
RAPIDJSON_DIAG_OFF(covered - switch - default)
|
||||
#endif
|
||||
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
@@ -33,35 +33,51 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
\note User can make a copy of this function for localization.
|
||||
Using switch-case is safer for future modification of error codes.
|
||||
*/
|
||||
inline const RAPIDJSON_ERROR_CHARTYPE* GetParseError_En(ParseErrorCode parseErrorCode) {
|
||||
switch (parseErrorCode) {
|
||||
case kParseErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
|
||||
inline const RAPIDJSON_ERROR_CHARTYPE* GetParseError_En(ParseErrorCode parseErrorCode)
|
||||
{
|
||||
switch(parseErrorCode)
|
||||
{
|
||||
case kParseErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
|
||||
|
||||
case kParseErrorDocumentEmpty: return RAPIDJSON_ERROR_STRING("The document is empty.");
|
||||
case kParseErrorDocumentRootNotSingular: return RAPIDJSON_ERROR_STRING("The document root must not be followed by other values.");
|
||||
case kParseErrorDocumentEmpty: return RAPIDJSON_ERROR_STRING("The document is empty.");
|
||||
case kParseErrorDocumentRootNotSingular:
|
||||
return RAPIDJSON_ERROR_STRING("The document root must not be followed by other values.");
|
||||
|
||||
case kParseErrorValueInvalid: return RAPIDJSON_ERROR_STRING("Invalid value.");
|
||||
case kParseErrorValueInvalid: return RAPIDJSON_ERROR_STRING("Invalid value.");
|
||||
|
||||
case kParseErrorObjectMissName: return RAPIDJSON_ERROR_STRING("Missing a name for object member.");
|
||||
case kParseErrorObjectMissColon: return RAPIDJSON_ERROR_STRING("Missing a colon after a name of object member.");
|
||||
case kParseErrorObjectMissCommaOrCurlyBracket: return RAPIDJSON_ERROR_STRING("Missing a comma or '}' after an object member.");
|
||||
case kParseErrorObjectMissName:
|
||||
return RAPIDJSON_ERROR_STRING("Missing a name for object member.");
|
||||
case kParseErrorObjectMissColon:
|
||||
return RAPIDJSON_ERROR_STRING("Missing a colon after a name of object member.");
|
||||
case kParseErrorObjectMissCommaOrCurlyBracket:
|
||||
return RAPIDJSON_ERROR_STRING("Missing a comma or '}' after an object member.");
|
||||
|
||||
case kParseErrorArrayMissCommaOrSquareBracket: return RAPIDJSON_ERROR_STRING("Missing a comma or ']' after an array element.");
|
||||
case kParseErrorArrayMissCommaOrSquareBracket:
|
||||
return RAPIDJSON_ERROR_STRING("Missing a comma or ']' after an array element.");
|
||||
|
||||
case kParseErrorStringUnicodeEscapeInvalidHex: return RAPIDJSON_ERROR_STRING("Incorrect hex digit after \\u escape in string.");
|
||||
case kParseErrorStringUnicodeSurrogateInvalid: return RAPIDJSON_ERROR_STRING("The surrogate pair in string is invalid.");
|
||||
case kParseErrorStringEscapeInvalid: return RAPIDJSON_ERROR_STRING("Invalid escape character in string.");
|
||||
case kParseErrorStringMissQuotationMark: return RAPIDJSON_ERROR_STRING("Missing a closing quotation mark in string.");
|
||||
case kParseErrorStringInvalidEncoding: return RAPIDJSON_ERROR_STRING("Invalid encoding in string.");
|
||||
case kParseErrorStringUnicodeEscapeInvalidHex:
|
||||
return RAPIDJSON_ERROR_STRING("Incorrect hex digit after \\u escape in string.");
|
||||
case kParseErrorStringUnicodeSurrogateInvalid:
|
||||
return RAPIDJSON_ERROR_STRING("The surrogate pair in string is invalid.");
|
||||
case kParseErrorStringEscapeInvalid:
|
||||
return RAPIDJSON_ERROR_STRING("Invalid escape character in string.");
|
||||
case kParseErrorStringMissQuotationMark:
|
||||
return RAPIDJSON_ERROR_STRING("Missing a closing quotation mark in string.");
|
||||
case kParseErrorStringInvalidEncoding:
|
||||
return RAPIDJSON_ERROR_STRING("Invalid encoding in string.");
|
||||
|
||||
case kParseErrorNumberTooBig: return RAPIDJSON_ERROR_STRING("Number too big to be stored in double.");
|
||||
case kParseErrorNumberMissFraction: return RAPIDJSON_ERROR_STRING("Miss fraction part in number.");
|
||||
case kParseErrorNumberMissExponent: return RAPIDJSON_ERROR_STRING("Miss exponent in number.");
|
||||
case kParseErrorNumberTooBig:
|
||||
return RAPIDJSON_ERROR_STRING("Number too big to be stored in double.");
|
||||
case kParseErrorNumberMissFraction:
|
||||
return RAPIDJSON_ERROR_STRING("Miss fraction part in number.");
|
||||
case kParseErrorNumberMissExponent: return RAPIDJSON_ERROR_STRING("Miss exponent in number.");
|
||||
|
||||
case kParseErrorTermination: return RAPIDJSON_ERROR_STRING("Terminate parsing due to Handler error.");
|
||||
case kParseErrorUnspecificSyntaxError: return RAPIDJSON_ERROR_STRING("Unspecific syntax error.");
|
||||
case kParseErrorTermination:
|
||||
return RAPIDJSON_ERROR_STRING("Terminate parsing due to Handler error.");
|
||||
case kParseErrorUnspecificSyntaxError:
|
||||
return RAPIDJSON_ERROR_STRING("Unspecific syntax error.");
|
||||
|
||||
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
|
||||
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,46 +89,102 @@ inline const RAPIDJSON_ERROR_CHARTYPE* GetParseError_En(ParseErrorCode parseErro
|
||||
\note User can make a copy of this function for localization.
|
||||
Using switch-case is safer for future modification of error codes.
|
||||
*/
|
||||
inline const RAPIDJSON_ERROR_CHARTYPE* GetValidateError_En(ValidateErrorCode validateErrorCode) {
|
||||
switch (validateErrorCode) {
|
||||
case kValidateErrors: return RAPIDJSON_ERROR_STRING("One or more validation errors have occurred");
|
||||
case kValidateErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
|
||||
inline const RAPIDJSON_ERROR_CHARTYPE* GetValidateError_En(ValidateErrorCode validateErrorCode)
|
||||
{
|
||||
switch(validateErrorCode)
|
||||
{
|
||||
case kValidateErrors:
|
||||
return RAPIDJSON_ERROR_STRING("One or more validation errors have occurred");
|
||||
case kValidateErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
|
||||
|
||||
case kValidateErrorMultipleOf: return RAPIDJSON_ERROR_STRING("Number '%actual' is not a multiple of the 'multipleOf' value '%expected'.");
|
||||
case kValidateErrorMaximum: return RAPIDJSON_ERROR_STRING("Number '%actual' is greater than the 'maximum' value '%expected'.");
|
||||
case kValidateErrorExclusiveMaximum: return RAPIDJSON_ERROR_STRING("Number '%actual' is greater than or equal to the 'exclusiveMaximum' value '%expected'.");
|
||||
case kValidateErrorMinimum: return RAPIDJSON_ERROR_STRING("Number '%actual' is less than the 'minimum' value '%expected'.");
|
||||
case kValidateErrorExclusiveMinimum: return RAPIDJSON_ERROR_STRING("Number '%actual' is less than or equal to the 'exclusiveMinimum' value '%expected'.");
|
||||
case kValidateErrorMultipleOf:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Number '%actual' is not a multiple of the 'multipleOf' value '%expected'.");
|
||||
case kValidateErrorMaximum:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Number '%actual' is greater than the 'maximum' value '%expected'.");
|
||||
case kValidateErrorExclusiveMaximum:
|
||||
return RAPIDJSON_ERROR_STRING("Number '%actual' is greater than or equal to the "
|
||||
"'exclusiveMaximum' value '%expected'.");
|
||||
case kValidateErrorMinimum:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Number '%actual' is less than the 'minimum' value '%expected'.");
|
||||
case kValidateErrorExclusiveMinimum:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Number '%actual' is less than or equal to the 'exclusiveMinimum' value '%expected'.");
|
||||
|
||||
case kValidateErrorMaxLength: return RAPIDJSON_ERROR_STRING("String '%actual' is longer than the 'maxLength' value '%expected'.");
|
||||
case kValidateErrorMinLength: return RAPIDJSON_ERROR_STRING("String '%actual' is shorter than the 'minLength' value '%expected'.");
|
||||
case kValidateErrorPattern: return RAPIDJSON_ERROR_STRING("String '%actual' does not match the 'pattern' regular expression.");
|
||||
case kValidateErrorMaxLength:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"String '%actual' is longer than the 'maxLength' value '%expected'.");
|
||||
case kValidateErrorMinLength:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"String '%actual' is shorter than the 'minLength' value '%expected'.");
|
||||
case kValidateErrorPattern:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"String '%actual' does not match the 'pattern' regular expression.");
|
||||
|
||||
case kValidateErrorMaxItems: return RAPIDJSON_ERROR_STRING("Array of length '%actual' is longer than the 'maxItems' value '%expected'.");
|
||||
case kValidateErrorMinItems: return RAPIDJSON_ERROR_STRING("Array of length '%actual' is shorter than the 'minItems' value '%expected'.");
|
||||
case kValidateErrorUniqueItems: return RAPIDJSON_ERROR_STRING("Array has duplicate items at indices '%duplicates' but 'uniqueItems' is true.");
|
||||
case kValidateErrorAdditionalItems: return RAPIDJSON_ERROR_STRING("Array has an additional item at index '%disallowed' that is not allowed by the schema.");
|
||||
case kValidateErrorMaxItems:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Array of length '%actual' is longer than the 'maxItems' value '%expected'.");
|
||||
case kValidateErrorMinItems:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Array of length '%actual' is shorter than the 'minItems' value '%expected'.");
|
||||
case kValidateErrorUniqueItems:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Array has duplicate items at indices '%duplicates' but 'uniqueItems' is true.");
|
||||
case kValidateErrorAdditionalItems:
|
||||
return RAPIDJSON_ERROR_STRING("Array has an additional item at index '%disallowed' that is "
|
||||
"not allowed by the schema.");
|
||||
|
||||
case kValidateErrorMaxProperties: return RAPIDJSON_ERROR_STRING("Object has '%actual' members which is more than 'maxProperties' value '%expected'.");
|
||||
case kValidateErrorMinProperties: return RAPIDJSON_ERROR_STRING("Object has '%actual' members which is less than 'minProperties' value '%expected'.");
|
||||
case kValidateErrorRequired: return RAPIDJSON_ERROR_STRING("Object is missing the following members required by the schema: '%missing'.");
|
||||
case kValidateErrorAdditionalProperties: return RAPIDJSON_ERROR_STRING("Object has an additional member '%disallowed' that is not allowed by the schema.");
|
||||
case kValidateErrorPatternProperties: return RAPIDJSON_ERROR_STRING("Object has 'patternProperties' that are not allowed by the schema.");
|
||||
case kValidateErrorDependencies: return RAPIDJSON_ERROR_STRING("Object has missing property or schema dependencies, refer to following errors.");
|
||||
case kValidateErrorMaxProperties:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Object has '%actual' members which is more than 'maxProperties' value '%expected'.");
|
||||
case kValidateErrorMinProperties:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Object has '%actual' members which is less than 'minProperties' value '%expected'.");
|
||||
case kValidateErrorRequired:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Object is missing the following members required by the schema: '%missing'.");
|
||||
case kValidateErrorAdditionalProperties:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Object has an additional member '%disallowed' that is not allowed by the schema.");
|
||||
case kValidateErrorPatternProperties:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Object has 'patternProperties' that are not allowed by the schema.");
|
||||
case kValidateErrorDependencies:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Object has missing property or schema dependencies, refer to following errors.");
|
||||
|
||||
case kValidateErrorEnum: return RAPIDJSON_ERROR_STRING("Property has a value that is not one of its allowed enumerated values.");
|
||||
case kValidateErrorType: return RAPIDJSON_ERROR_STRING("Property has a type '%actual' that is not in the following list: '%expected'.");
|
||||
case kValidateErrorEnum:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Property has a value that is not one of its allowed enumerated values.");
|
||||
case kValidateErrorType:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Property has a type '%actual' that is not in the following list: '%expected'.");
|
||||
|
||||
case kValidateErrorOneOf: return RAPIDJSON_ERROR_STRING("Property did not match any of the sub-schemas specified by 'oneOf', refer to following errors.");
|
||||
case kValidateErrorOneOfMatch: return RAPIDJSON_ERROR_STRING("Property matched more than one of the sub-schemas specified by 'oneOf', indices '%matches'.");
|
||||
case kValidateErrorAllOf: return RAPIDJSON_ERROR_STRING("Property did not match all of the sub-schemas specified by 'allOf', refer to following errors.");
|
||||
case kValidateErrorAnyOf: return RAPIDJSON_ERROR_STRING("Property did not match any of the sub-schemas specified by 'anyOf', refer to following errors.");
|
||||
case kValidateErrorNot: return RAPIDJSON_ERROR_STRING("Property matched the sub-schema specified by 'not'.");
|
||||
case kValidateErrorOneOf:
|
||||
return RAPIDJSON_ERROR_STRING("Property did not match any of the sub-schemas specified by "
|
||||
"'oneOf', refer to following errors.");
|
||||
case kValidateErrorOneOfMatch:
|
||||
return RAPIDJSON_ERROR_STRING("Property matched more than one of the sub-schemas specified "
|
||||
"by 'oneOf', indices '%matches'.");
|
||||
case kValidateErrorAllOf:
|
||||
return RAPIDJSON_ERROR_STRING("Property did not match all of the sub-schemas specified by "
|
||||
"'allOf', refer to following errors.");
|
||||
case kValidateErrorAnyOf:
|
||||
return RAPIDJSON_ERROR_STRING("Property did not match any of the sub-schemas specified by "
|
||||
"'anyOf', refer to following errors.");
|
||||
case kValidateErrorNot:
|
||||
return RAPIDJSON_ERROR_STRING("Property matched the sub-schema specified by 'not'.");
|
||||
|
||||
case kValidateErrorReadOnly: return RAPIDJSON_ERROR_STRING("Property is read-only but has been provided when validation is for writing.");
|
||||
case kValidateErrorWriteOnly: return RAPIDJSON_ERROR_STRING("Property is write-only but has been provided when validation is for reading.");
|
||||
case kValidateErrorReadOnly:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Property is read-only but has been provided when validation is for writing.");
|
||||
case kValidateErrorWriteOnly:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Property is write-only but has been provided when validation is for reading.");
|
||||
|
||||
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
|
||||
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,27 +196,46 @@ inline const RAPIDJSON_ERROR_CHARTYPE* GetValidateError_En(ValidateErrorCode val
|
||||
\note User can make a copy of this function for localization.
|
||||
Using switch-case is safer for future modification of error codes.
|
||||
*/
|
||||
inline const RAPIDJSON_ERROR_CHARTYPE* GetSchemaError_En(SchemaErrorCode schemaErrorCode) {
|
||||
switch (schemaErrorCode) {
|
||||
case kSchemaErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
|
||||
inline const RAPIDJSON_ERROR_CHARTYPE* GetSchemaError_En(SchemaErrorCode schemaErrorCode)
|
||||
{
|
||||
switch(schemaErrorCode)
|
||||
{
|
||||
case kSchemaErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
|
||||
|
||||
case kSchemaErrorStartUnknown: return RAPIDJSON_ERROR_STRING("Pointer '%value' to start of schema does not resolve to a location in the document.");
|
||||
case kSchemaErrorRefPlainName: return RAPIDJSON_ERROR_STRING("$ref fragment '%value' must be a JSON pointer.");
|
||||
case kSchemaErrorRefInvalid: return RAPIDJSON_ERROR_STRING("$ref must not be an empty string.");
|
||||
case kSchemaErrorRefPointerInvalid: return RAPIDJSON_ERROR_STRING("$ref fragment '%value' is not a valid JSON pointer at offset '%offset'.");
|
||||
case kSchemaErrorRefUnknown: return RAPIDJSON_ERROR_STRING("$ref '%value' does not resolve to a location in the target document.");
|
||||
case kSchemaErrorRefCyclical: return RAPIDJSON_ERROR_STRING("$ref '%value' is cyclical.");
|
||||
case kSchemaErrorRefNoRemoteProvider: return RAPIDJSON_ERROR_STRING("$ref is remote but there is no remote provider.");
|
||||
case kSchemaErrorRefNoRemoteSchema: return RAPIDJSON_ERROR_STRING("$ref '%value' is remote but the remote provider did not return a schema.");
|
||||
case kSchemaErrorRegexInvalid: return RAPIDJSON_ERROR_STRING("Invalid regular expression '%value' in 'pattern' or 'patternProperties'.");
|
||||
case kSchemaErrorSpecUnknown: return RAPIDJSON_ERROR_STRING("JSON schema draft or OpenAPI version is not recognized.");
|
||||
case kSchemaErrorSpecUnsupported: return RAPIDJSON_ERROR_STRING("JSON schema draft or OpenAPI version is not supported.");
|
||||
case kSchemaErrorSpecIllegal: return RAPIDJSON_ERROR_STRING("Both JSON schema draft and OpenAPI version found in document.");
|
||||
case kSchemaErrorReadOnlyAndWriteOnly: return RAPIDJSON_ERROR_STRING("Property must not be both 'readOnly' and 'writeOnly'.");
|
||||
case kSchemaErrorStartUnknown:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Pointer '%value' to start of schema does not resolve to a location in the document.");
|
||||
case kSchemaErrorRefPlainName:
|
||||
return RAPIDJSON_ERROR_STRING("$ref fragment '%value' must be a JSON pointer.");
|
||||
case kSchemaErrorRefInvalid: return RAPIDJSON_ERROR_STRING("$ref must not be an empty string.");
|
||||
case kSchemaErrorRefPointerInvalid:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"$ref fragment '%value' is not a valid JSON pointer at offset '%offset'.");
|
||||
case kSchemaErrorRefUnknown:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"$ref '%value' does not resolve to a location in the target document.");
|
||||
case kSchemaErrorRefCyclical: return RAPIDJSON_ERROR_STRING("$ref '%value' is cyclical.");
|
||||
case kSchemaErrorRefNoRemoteProvider:
|
||||
return RAPIDJSON_ERROR_STRING("$ref is remote but there is no remote provider.");
|
||||
case kSchemaErrorRefNoRemoteSchema:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"$ref '%value' is remote but the remote provider did not return a schema.");
|
||||
case kSchemaErrorRegexInvalid:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Invalid regular expression '%value' in 'pattern' or 'patternProperties'.");
|
||||
case kSchemaErrorSpecUnknown:
|
||||
return RAPIDJSON_ERROR_STRING("JSON schema draft or OpenAPI version is not recognized.");
|
||||
case kSchemaErrorSpecUnsupported:
|
||||
return RAPIDJSON_ERROR_STRING("JSON schema draft or OpenAPI version is not supported.");
|
||||
case kSchemaErrorSpecIllegal:
|
||||
return RAPIDJSON_ERROR_STRING(
|
||||
"Both JSON schema draft and OpenAPI version found in document.");
|
||||
case kSchemaErrorReadOnlyAndWriteOnly:
|
||||
return RAPIDJSON_ERROR_STRING("Property must not be both 'readOnly' and 'writeOnly'.");
|
||||
|
||||
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
|
||||
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//! Maps error code of pointer parse into error message.
|
||||
/*!
|
||||
@@ -154,16 +245,22 @@ inline const RAPIDJSON_ERROR_CHARTYPE* GetValidateError_En(ValidateErrorCode val
|
||||
\note User can make a copy of this function for localization.
|
||||
Using switch-case is safer for future modification of error codes.
|
||||
*/
|
||||
inline const RAPIDJSON_ERROR_CHARTYPE* GetPointerParseError_En(PointerParseErrorCode pointerParseErrorCode) {
|
||||
switch (pointerParseErrorCode) {
|
||||
case kPointerParseErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
|
||||
inline const RAPIDJSON_ERROR_CHARTYPE*
|
||||
GetPointerParseError_En(PointerParseErrorCode pointerParseErrorCode)
|
||||
{
|
||||
switch(pointerParseErrorCode)
|
||||
{
|
||||
case kPointerParseErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
|
||||
|
||||
case kPointerParseErrorTokenMustBeginWithSolidus: return RAPIDJSON_ERROR_STRING("A token must begin with a '/'.");
|
||||
case kPointerParseErrorInvalidEscape: return RAPIDJSON_ERROR_STRING("Invalid escape.");
|
||||
case kPointerParseErrorInvalidPercentEncoding: return RAPIDJSON_ERROR_STRING("Invalid percent encoding in URI fragment.");
|
||||
case kPointerParseErrorCharacterMustPercentEncode: return RAPIDJSON_ERROR_STRING("A character must be percent encoded in a URI fragment.");
|
||||
case kPointerParseErrorTokenMustBeginWithSolidus:
|
||||
return RAPIDJSON_ERROR_STRING("A token must begin with a '/'.");
|
||||
case kPointerParseErrorInvalidEscape: return RAPIDJSON_ERROR_STRING("Invalid escape.");
|
||||
case kPointerParseErrorInvalidPercentEncoding:
|
||||
return RAPIDJSON_ERROR_STRING("Invalid percent encoding in URI fragment.");
|
||||
case kPointerParseErrorCharacterMustPercentEncode:
|
||||
return RAPIDJSON_ERROR_STRING("A character must be percent encoded in a URI fragment.");
|
||||
|
||||
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
|
||||
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -61,32 +61,33 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
/*! \ingroup RAPIDJSON_ERRORS
|
||||
\see GenericReader::Parse, GenericReader::GetParseErrorCode
|
||||
*/
|
||||
enum ParseErrorCode {
|
||||
kParseErrorNone = 0, //!< No error.
|
||||
enum ParseErrorCode
|
||||
{
|
||||
kParseErrorNone = 0, //!< No error.
|
||||
|
||||
kParseErrorDocumentEmpty, //!< The document is empty.
|
||||
kParseErrorDocumentRootNotSingular, //!< The document root must not follow by other values.
|
||||
kParseErrorDocumentEmpty, //!< The document is empty.
|
||||
kParseErrorDocumentRootNotSingular, //!< The document root must not follow by other values.
|
||||
|
||||
kParseErrorValueInvalid, //!< Invalid value.
|
||||
kParseErrorValueInvalid, //!< Invalid value.
|
||||
|
||||
kParseErrorObjectMissName, //!< Missing a name for object member.
|
||||
kParseErrorObjectMissColon, //!< Missing a colon after a name of object member.
|
||||
kParseErrorObjectMissCommaOrCurlyBracket, //!< Missing a comma or '}' after an object member.
|
||||
kParseErrorObjectMissName, //!< Missing a name for object member.
|
||||
kParseErrorObjectMissColon, //!< Missing a colon after a name of object member.
|
||||
kParseErrorObjectMissCommaOrCurlyBracket, //!< Missing a comma or '}' after an object member.
|
||||
|
||||
kParseErrorArrayMissCommaOrSquareBracket, //!< Missing a comma or ']' after an array element.
|
||||
kParseErrorArrayMissCommaOrSquareBracket, //!< Missing a comma or ']' after an array element.
|
||||
|
||||
kParseErrorStringUnicodeEscapeInvalidHex, //!< Incorrect hex digit after \\u escape in string.
|
||||
kParseErrorStringUnicodeSurrogateInvalid, //!< The surrogate pair in string is invalid.
|
||||
kParseErrorStringEscapeInvalid, //!< Invalid escape character in string.
|
||||
kParseErrorStringMissQuotationMark, //!< Missing a closing quotation mark in string.
|
||||
kParseErrorStringInvalidEncoding, //!< Invalid encoding in string.
|
||||
kParseErrorStringUnicodeEscapeInvalidHex, //!< Incorrect hex digit after \\u escape in string.
|
||||
kParseErrorStringUnicodeSurrogateInvalid, //!< The surrogate pair in string is invalid.
|
||||
kParseErrorStringEscapeInvalid, //!< Invalid escape character in string.
|
||||
kParseErrorStringMissQuotationMark, //!< Missing a closing quotation mark in string.
|
||||
kParseErrorStringInvalidEncoding, //!< Invalid encoding in string.
|
||||
|
||||
kParseErrorNumberTooBig, //!< Number too big to be stored in double.
|
||||
kParseErrorNumberMissFraction, //!< Miss fraction part in number.
|
||||
kParseErrorNumberMissExponent, //!< Miss exponent in number.
|
||||
kParseErrorNumberTooBig, //!< Number too big to be stored in double.
|
||||
kParseErrorNumberMissFraction, //!< Miss fraction part in number.
|
||||
kParseErrorNumberMissExponent, //!< Miss exponent in number.
|
||||
|
||||
kParseErrorTermination, //!< Parsing was terminated.
|
||||
kParseErrorUnspecificSyntaxError //!< Unspecific syntax error.
|
||||
kParseErrorTermination, //!< Parsing was terminated.
|
||||
kParseErrorUnspecificSyntaxError //!< Unspecific syntax error.
|
||||
};
|
||||
|
||||
//! Result of parsing (wraps ParseErrorCode)
|
||||
@@ -103,10 +104,12 @@ enum ParseErrorCode {
|
||||
\endcode
|
||||
\see GenericReader::Parse, GenericDocument::Parse
|
||||
*/
|
||||
struct ParseResult {
|
||||
struct ParseResult
|
||||
{
|
||||
//!! Unspecified boolean type
|
||||
typedef bool (ParseResult::*BooleanType)() const;
|
||||
public:
|
||||
|
||||
public:
|
||||
//! Default constructor, no error.
|
||||
ParseResult() : code_(kParseErrorNone), offset_(0) {}
|
||||
//! Constructor to set an error.
|
||||
@@ -124,18 +127,25 @@ public:
|
||||
|
||||
bool operator==(const ParseResult& that) const { return code_ == that.code_; }
|
||||
bool operator==(ParseErrorCode code) const { return code_ == code; }
|
||||
friend bool operator==(ParseErrorCode code, const ParseResult & err) { return code == err.code_; }
|
||||
friend bool operator==(ParseErrorCode code, const ParseResult& err)
|
||||
{
|
||||
return code == err.code_;
|
||||
}
|
||||
|
||||
bool operator!=(const ParseResult& that) const { return !(*this == that); }
|
||||
bool operator!=(ParseErrorCode code) const { return !(*this == code); }
|
||||
friend bool operator!=(ParseErrorCode code, const ParseResult & err) { return err != code; }
|
||||
friend bool operator!=(ParseErrorCode code, const ParseResult& err) { return err != code; }
|
||||
|
||||
//! Reset error code.
|
||||
void Clear() { Set(kParseErrorNone); }
|
||||
//! Update error code and offset.
|
||||
void Set(ParseErrorCode code, size_t offset = 0) { code_ = code; offset_ = offset; }
|
||||
void Set(ParseErrorCode code, size_t offset = 0)
|
||||
{
|
||||
code_ = code;
|
||||
offset_ = offset;
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
ParseErrorCode code_;
|
||||
size_t offset_;
|
||||
};
|
||||
@@ -159,43 +169,49 @@ typedef const RAPIDJSON_ERROR_CHARTYPE* (*GetParseErrorFunc)(ParseErrorCode);
|
||||
/*! \ingroup RAPIDJSON_ERRORS
|
||||
\see GenericSchemaValidator
|
||||
*/
|
||||
enum ValidateErrorCode {
|
||||
kValidateErrors = -1, //!< Top level error code when kValidateContinueOnErrorsFlag set.
|
||||
kValidateErrorNone = 0, //!< No error.
|
||||
enum ValidateErrorCode
|
||||
{
|
||||
kValidateErrors = -1, //!< Top level error code when kValidateContinueOnErrorsFlag set.
|
||||
kValidateErrorNone = 0, //!< No error.
|
||||
|
||||
kValidateErrorMultipleOf, //!< Number is not a multiple of the 'multipleOf' value.
|
||||
kValidateErrorMaximum, //!< Number is greater than the 'maximum' value.
|
||||
kValidateErrorExclusiveMaximum, //!< Number is greater than or equal to the 'maximum' value.
|
||||
kValidateErrorMinimum, //!< Number is less than the 'minimum' value.
|
||||
kValidateErrorExclusiveMinimum, //!< Number is less than or equal to the 'minimum' value.
|
||||
kValidateErrorMultipleOf, //!< Number is not a multiple of the 'multipleOf' value.
|
||||
kValidateErrorMaximum, //!< Number is greater than the 'maximum' value.
|
||||
kValidateErrorExclusiveMaximum, //!< Number is greater than or equal to the 'maximum' value.
|
||||
kValidateErrorMinimum, //!< Number is less than the 'minimum' value.
|
||||
kValidateErrorExclusiveMinimum, //!< Number is less than or equal to the 'minimum' value.
|
||||
|
||||
kValidateErrorMaxLength, //!< String is longer than the 'maxLength' value.
|
||||
kValidateErrorMinLength, //!< String is longer than the 'maxLength' value.
|
||||
kValidateErrorPattern, //!< String does not match the 'pattern' regular expression.
|
||||
kValidateErrorMaxLength, //!< String is longer than the 'maxLength' value.
|
||||
kValidateErrorMinLength, //!< String is longer than the 'maxLength' value.
|
||||
kValidateErrorPattern, //!< String does not match the 'pattern' regular expression.
|
||||
|
||||
kValidateErrorMaxItems, //!< Array is longer than the 'maxItems' value.
|
||||
kValidateErrorMinItems, //!< Array is shorter than the 'minItems' value.
|
||||
kValidateErrorUniqueItems, //!< Array has duplicate items but 'uniqueItems' is true.
|
||||
kValidateErrorAdditionalItems, //!< Array has additional items that are not allowed by the schema.
|
||||
kValidateErrorMaxItems, //!< Array is longer than the 'maxItems' value.
|
||||
kValidateErrorMinItems, //!< Array is shorter than the 'minItems' value.
|
||||
kValidateErrorUniqueItems, //!< Array has duplicate items but 'uniqueItems' is true.
|
||||
kValidateErrorAdditionalItems, //!< Array has additional items that are not allowed by the
|
||||
//!< schema.
|
||||
|
||||
kValidateErrorMaxProperties, //!< Object has more members than 'maxProperties' value.
|
||||
kValidateErrorMinProperties, //!< Object has less members than 'minProperties' value.
|
||||
kValidateErrorRequired, //!< Object is missing one or more members required by the schema.
|
||||
kValidateErrorAdditionalProperties, //!< Object has additional members that are not allowed by the schema.
|
||||
kValidateErrorPatternProperties, //!< See other errors.
|
||||
kValidateErrorDependencies, //!< Object has missing property or schema dependencies.
|
||||
kValidateErrorMaxProperties, //!< Object has more members than 'maxProperties' value.
|
||||
kValidateErrorMinProperties, //!< Object has less members than 'minProperties' value.
|
||||
kValidateErrorRequired, //!< Object is missing one or more members required by the schema.
|
||||
kValidateErrorAdditionalProperties, //!< Object has additional members that are not allowed by
|
||||
//!< the schema.
|
||||
kValidateErrorPatternProperties, //!< See other errors.
|
||||
kValidateErrorDependencies, //!< Object has missing property or schema dependencies.
|
||||
|
||||
kValidateErrorEnum, //!< Property has a value that is not one of its allowed enumerated values.
|
||||
kValidateErrorType, //!< Property has a type that is not allowed by the schema.
|
||||
kValidateErrorEnum, //!< Property has a value that is not one of its allowed enumerated values.
|
||||
kValidateErrorType, //!< Property has a type that is not allowed by the schema.
|
||||
|
||||
kValidateErrorOneOf, //!< Property did not match any of the sub-schemas specified by 'oneOf'.
|
||||
kValidateErrorOneOfMatch, //!< Property matched more than one of the sub-schemas specified by 'oneOf'.
|
||||
kValidateErrorAllOf, //!< Property did not match all of the sub-schemas specified by 'allOf'.
|
||||
kValidateErrorAnyOf, //!< Property did not match any of the sub-schemas specified by 'anyOf'.
|
||||
kValidateErrorNot, //!< Property matched the sub-schema specified by 'not'.
|
||||
kValidateErrorOneOf, //!< Property did not match any of the sub-schemas specified by 'oneOf'.
|
||||
kValidateErrorOneOfMatch, //!< Property matched more than one of the sub-schemas specified by
|
||||
//!< 'oneOf'.
|
||||
kValidateErrorAllOf, //!< Property did not match all of the sub-schemas specified by 'allOf'.
|
||||
kValidateErrorAnyOf, //!< Property did not match any of the sub-schemas specified by 'anyOf'.
|
||||
kValidateErrorNot, //!< Property matched the sub-schema specified by 'not'.
|
||||
|
||||
kValidateErrorReadOnly, //!< Property is read-only but has been provided when validation is for writing
|
||||
kValidateErrorWriteOnly //!< Property is write-only but has been provided when validation is for reading
|
||||
kValidateErrorReadOnly, //!< Property is read-only but has been provided when validation is for
|
||||
//!< writing
|
||||
kValidateErrorWriteOnly //!< Property is write-only but has been provided when validation is for
|
||||
//!< reading
|
||||
};
|
||||
|
||||
//! Function pointer type of GetValidateError().
|
||||
@@ -217,22 +233,25 @@ typedef const RAPIDJSON_ERROR_CHARTYPE* (*GetValidateErrorFunc)(ValidateErrorCod
|
||||
/*! \ingroup RAPIDJSON_ERRORS
|
||||
\see GenericSchemaValidator
|
||||
*/
|
||||
enum SchemaErrorCode {
|
||||
kSchemaErrorNone = 0, //!< No error.
|
||||
enum SchemaErrorCode
|
||||
{
|
||||
kSchemaErrorNone = 0, //!< No error.
|
||||
|
||||
kSchemaErrorStartUnknown, //!< Pointer to start of schema does not resolve to a location in the document
|
||||
kSchemaErrorRefPlainName, //!< $ref fragment must be a JSON pointer
|
||||
kSchemaErrorRefInvalid, //!< $ref must not be an empty string
|
||||
kSchemaErrorRefPointerInvalid, //!< $ref fragment is not a valid JSON pointer at offset
|
||||
kSchemaErrorRefUnknown, //!< $ref does not resolve to a location in the target document
|
||||
kSchemaErrorRefCyclical, //!< $ref is cyclical
|
||||
kSchemaErrorRefNoRemoteProvider, //!< $ref is remote but there is no remote provider
|
||||
kSchemaErrorRefNoRemoteSchema, //!< $ref is remote but the remote provider did not return a schema
|
||||
kSchemaErrorRegexInvalid, //!< Invalid regular expression in 'pattern' or 'patternProperties'
|
||||
kSchemaErrorSpecUnknown, //!< JSON schema draft or OpenAPI version is not recognized
|
||||
kSchemaErrorSpecUnsupported, //!< JSON schema draft or OpenAPI version is not supported
|
||||
kSchemaErrorSpecIllegal, //!< Both JSON schema draft and OpenAPI version found in document
|
||||
kSchemaErrorReadOnlyAndWriteOnly //!< Property must not be both 'readOnly' and 'writeOnly'
|
||||
kSchemaErrorStartUnknown, //!< Pointer to start of schema does not resolve to a location in the
|
||||
//!< document
|
||||
kSchemaErrorRefPlainName, //!< $ref fragment must be a JSON pointer
|
||||
kSchemaErrorRefInvalid, //!< $ref must not be an empty string
|
||||
kSchemaErrorRefPointerInvalid, //!< $ref fragment is not a valid JSON pointer at offset
|
||||
kSchemaErrorRefUnknown, //!< $ref does not resolve to a location in the target document
|
||||
kSchemaErrorRefCyclical, //!< $ref is cyclical
|
||||
kSchemaErrorRefNoRemoteProvider, //!< $ref is remote but there is no remote provider
|
||||
kSchemaErrorRefNoRemoteSchema, //!< $ref is remote but the remote provider did not return a
|
||||
//!< schema
|
||||
kSchemaErrorRegexInvalid, //!< Invalid regular expression in 'pattern' or 'patternProperties'
|
||||
kSchemaErrorSpecUnknown, //!< JSON schema draft or OpenAPI version is not recognized
|
||||
kSchemaErrorSpecUnsupported, //!< JSON schema draft or OpenAPI version is not supported
|
||||
kSchemaErrorSpecIllegal, //!< Both JSON schema draft and OpenAPI version found in document
|
||||
kSchemaErrorReadOnlyAndWriteOnly //!< Property must not be both 'readOnly' and 'writeOnly'
|
||||
};
|
||||
|
||||
//! Function pointer type of GetSchemaError().
|
||||
@@ -254,13 +273,15 @@ typedef const RAPIDJSON_ERROR_CHARTYPE* (*GetSchemaErrorFunc)(SchemaErrorCode);
|
||||
/*! \ingroup RAPIDJSON_ERRORS
|
||||
\see GenericPointer::GenericPointer, GenericPointer::GetParseErrorCode
|
||||
*/
|
||||
enum PointerParseErrorCode {
|
||||
kPointerParseErrorNone = 0, //!< The parse is successful
|
||||
enum PointerParseErrorCode
|
||||
{
|
||||
kPointerParseErrorNone = 0, //!< The parse is successful
|
||||
|
||||
kPointerParseErrorTokenMustBeginWithSolidus, //!< A token must begin with a '/'
|
||||
kPointerParseErrorInvalidEscape, //!< Invalid escape
|
||||
kPointerParseErrorInvalidPercentEncoding, //!< Invalid percent encoding in URI fragment
|
||||
kPointerParseErrorCharacterMustPercentEncode //!< A character must percent encoded in URI fragment
|
||||
kPointerParseErrorTokenMustBeginWithSolidus, //!< A token must begin with a '/'
|
||||
kPointerParseErrorInvalidEscape, //!< Invalid escape
|
||||
kPointerParseErrorInvalidPercentEncoding, //!< Invalid percent encoding in URI fragment
|
||||
kPointerParseErrorCharacterMustPercentEncode //!< A character must percent encoded in URI
|
||||
//!< fragment
|
||||
};
|
||||
|
||||
//! Function pointer type of GetPointerParseError().
|
||||
@@ -275,7 +296,6 @@ enum PointerParseErrorCode {
|
||||
*/
|
||||
typedef const RAPIDJSON_ERROR_CHARTYPE* (*GetPointerParseErrorFunc)(PointerParseErrorCode);
|
||||
|
||||
|
||||
RAPIDJSON_NAMESPACE_END
|
||||
|
||||
#ifdef __clang__
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_FILEREADSTREAM_H_
|
||||
@@ -21,8 +21,8 @@
|
||||
#ifdef __clang__
|
||||
RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(padded)
|
||||
RAPIDJSON_DIAG_OFF(unreachable-code)
|
||||
RAPIDJSON_DIAG_OFF(missing-noreturn)
|
||||
RAPIDJSON_DIAG_OFF(unreachable - code)
|
||||
RAPIDJSON_DIAG_OFF(missing - noreturn)
|
||||
#endif
|
||||
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
@@ -31,9 +31,10 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
/*!
|
||||
\note implements Stream concept
|
||||
*/
|
||||
class FileReadStream {
|
||||
public:
|
||||
typedef char Ch; //!< Character type (byte).
|
||||
class FileReadStream
|
||||
{
|
||||
public:
|
||||
typedef char Ch; //!< Character type (byte).
|
||||
|
||||
//! Constructor.
|
||||
/*!
|
||||
@@ -41,38 +42,61 @@ public:
|
||||
\param buffer user-supplied buffer.
|
||||
\param bufferSize size of buffer in bytes. Must >=4 bytes.
|
||||
*/
|
||||
FileReadStream(std::FILE* fp, char* buffer, size_t bufferSize) : fp_(fp), buffer_(buffer), bufferSize_(bufferSize), bufferLast_(0), current_(buffer_), readCount_(0), count_(0), eof_(false) {
|
||||
FileReadStream(std::FILE* fp, char* buffer, size_t bufferSize)
|
||||
: fp_(fp),
|
||||
buffer_(buffer),
|
||||
bufferSize_(bufferSize),
|
||||
bufferLast_(0),
|
||||
current_(buffer_),
|
||||
readCount_(0),
|
||||
count_(0),
|
||||
eof_(false)
|
||||
{
|
||||
RAPIDJSON_ASSERT(fp_ != 0);
|
||||
RAPIDJSON_ASSERT(bufferSize >= 4);
|
||||
Read();
|
||||
}
|
||||
|
||||
Ch Peek() const { return *current_; }
|
||||
Ch Take() { Ch c = *current_; Read(); return c; }
|
||||
Ch Take()
|
||||
{
|
||||
Ch c = *current_;
|
||||
Read();
|
||||
return c;
|
||||
}
|
||||
size_t Tell() const { return count_ + static_cast<size_t>(current_ - buffer_); }
|
||||
|
||||
// Not implemented
|
||||
void Put(Ch) { RAPIDJSON_ASSERT(false); }
|
||||
void Flush() { RAPIDJSON_ASSERT(false); }
|
||||
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
|
||||
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
|
||||
|
||||
// For encoding detection only.
|
||||
const Ch* Peek4() const {
|
||||
return (current_ + 4 - !eof_ <= bufferLast_) ? current_ : 0;
|
||||
void Flush() { RAPIDJSON_ASSERT(false); }
|
||||
Ch* PutBegin()
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
size_t PutEnd(Ch*)
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
void Read() {
|
||||
if (current_ < bufferLast_)
|
||||
++current_;
|
||||
else if (!eof_) {
|
||||
count_ += readCount_;
|
||||
readCount_ = std::fread(buffer_, 1, bufferSize_, fp_);
|
||||
bufferLast_ = buffer_ + readCount_ - 1;
|
||||
current_ = buffer_;
|
||||
// For encoding detection only.
|
||||
const Ch* Peek4() const { return (current_ + 4 - !eof_ <= bufferLast_) ? current_ : 0; }
|
||||
|
||||
if (readCount_ < bufferSize_) {
|
||||
private:
|
||||
void Read()
|
||||
{
|
||||
if(current_ < bufferLast_)
|
||||
++current_;
|
||||
else if(!eof_)
|
||||
{
|
||||
count_ += readCount_;
|
||||
readCount_ = std::fread(buffer_, 1, bufferSize_, fp_);
|
||||
bufferLast_ = buffer_ + readCount_ - 1;
|
||||
current_ = buffer_;
|
||||
|
||||
if(readCount_ < bufferSize_)
|
||||
{
|
||||
buffer_[readCount_] = '\0';
|
||||
++bufferLast_;
|
||||
eof_ = true;
|
||||
@@ -81,12 +105,12 @@ private:
|
||||
}
|
||||
|
||||
std::FILE* fp_;
|
||||
Ch *buffer_;
|
||||
Ch* buffer_;
|
||||
size_t bufferSize_;
|
||||
Ch *bufferLast_;
|
||||
Ch *current_;
|
||||
Ch* bufferLast_;
|
||||
Ch* current_;
|
||||
size_t readCount_;
|
||||
size_t count_; //!< Number of characters read
|
||||
size_t count_; //!< Number of characters read
|
||||
bool eof_;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_FILEWRITESTREAM_H_
|
||||
@@ -20,7 +20,7 @@
|
||||
|
||||
#ifdef __clang__
|
||||
RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(unreachable-code)
|
||||
RAPIDJSON_DIAG_OFF(unreachable - code)
|
||||
#endif
|
||||
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
@@ -29,24 +29,30 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
/*!
|
||||
\note implements Stream concept
|
||||
*/
|
||||
class FileWriteStream {
|
||||
public:
|
||||
typedef char Ch; //!< Character type. Only support char.
|
||||
class FileWriteStream
|
||||
{
|
||||
public:
|
||||
typedef char Ch; //!< Character type. Only support char.
|
||||
|
||||
FileWriteStream(std::FILE* fp, char* buffer, size_t bufferSize) : fp_(fp), buffer_(buffer), bufferEnd_(buffer + bufferSize), current_(buffer_) {
|
||||
FileWriteStream(std::FILE* fp, char* buffer, size_t bufferSize)
|
||||
: fp_(fp), buffer_(buffer), bufferEnd_(buffer + bufferSize), current_(buffer_)
|
||||
{
|
||||
RAPIDJSON_ASSERT(fp_ != 0);
|
||||
}
|
||||
|
||||
void Put(char c) {
|
||||
if (current_ >= bufferEnd_)
|
||||
void Put(char c)
|
||||
{
|
||||
if(current_ >= bufferEnd_)
|
||||
Flush();
|
||||
|
||||
*current_++ = c;
|
||||
}
|
||||
|
||||
void PutN(char c, size_t n) {
|
||||
void PutN(char c, size_t n)
|
||||
{
|
||||
size_t avail = static_cast<size_t>(bufferEnd_ - current_);
|
||||
while (n > avail) {
|
||||
while(n > avail)
|
||||
{
|
||||
std::memset(current_, c, avail);
|
||||
current_ += avail;
|
||||
Flush();
|
||||
@@ -54,16 +60,20 @@ public:
|
||||
avail = static_cast<size_t>(bufferEnd_ - current_);
|
||||
}
|
||||
|
||||
if (n > 0) {
|
||||
if(n > 0)
|
||||
{
|
||||
std::memset(current_, c, n);
|
||||
current_ += n;
|
||||
}
|
||||
}
|
||||
|
||||
void Flush() {
|
||||
if (current_ != buffer_) {
|
||||
void Flush()
|
||||
{
|
||||
if(current_ != buffer_)
|
||||
{
|
||||
size_t result = std::fwrite(buffer_, 1, static_cast<size_t>(current_ - buffer_), fp_);
|
||||
if (result < static_cast<size_t>(current_ - buffer_)) {
|
||||
if(result < static_cast<size_t>(current_ - buffer_))
|
||||
{
|
||||
// failure deliberately ignored at this time
|
||||
// added to avoid warn_unused_result build errors
|
||||
}
|
||||
@@ -72,26 +82,47 @@ public:
|
||||
}
|
||||
|
||||
// Not implemented
|
||||
char Peek() const { RAPIDJSON_ASSERT(false); return 0; }
|
||||
char Take() { RAPIDJSON_ASSERT(false); return 0; }
|
||||
size_t Tell() const { RAPIDJSON_ASSERT(false); return 0; }
|
||||
char* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
|
||||
size_t PutEnd(char*) { RAPIDJSON_ASSERT(false); return 0; }
|
||||
char Peek() const
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
char Take()
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
size_t Tell() const
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
char* PutBegin()
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
size_t PutEnd(char*)
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
// Prohibit copy constructor & assignment operator.
|
||||
FileWriteStream(const FileWriteStream&);
|
||||
FileWriteStream& operator=(const FileWriteStream&);
|
||||
|
||||
std::FILE* fp_;
|
||||
char *buffer_;
|
||||
char *bufferEnd_;
|
||||
char *current_;
|
||||
char* buffer_;
|
||||
char* bufferEnd_;
|
||||
char* current_;
|
||||
};
|
||||
|
||||
//! Implement specialized version of PutN() with memset() for better performance.
|
||||
template<>
|
||||
inline void PutN(FileWriteStream& stream, char c, size_t n) {
|
||||
template <>
|
||||
inline void PutN(FileWriteStream& stream, char c, size_t n)
|
||||
{
|
||||
stream.PutN(c, n);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_FWD_H_
|
||||
@@ -21,17 +21,26 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
|
||||
// encodings.h
|
||||
|
||||
template<typename CharType> struct UTF8;
|
||||
template<typename CharType> struct UTF16;
|
||||
template<typename CharType> struct UTF16BE;
|
||||
template<typename CharType> struct UTF16LE;
|
||||
template<typename CharType> struct UTF32;
|
||||
template<typename CharType> struct UTF32BE;
|
||||
template<typename CharType> struct UTF32LE;
|
||||
template<typename CharType> struct ASCII;
|
||||
template<typename CharType> struct AutoUTF;
|
||||
template <typename CharType>
|
||||
struct UTF8;
|
||||
template <typename CharType>
|
||||
struct UTF16;
|
||||
template <typename CharType>
|
||||
struct UTF16BE;
|
||||
template <typename CharType>
|
||||
struct UTF16LE;
|
||||
template <typename CharType>
|
||||
struct UTF32;
|
||||
template <typename CharType>
|
||||
struct UTF32BE;
|
||||
template <typename CharType>
|
||||
struct UTF32LE;
|
||||
template <typename CharType>
|
||||
struct ASCII;
|
||||
template <typename CharType>
|
||||
struct AutoUTF;
|
||||
|
||||
template<typename SourceEncoding, typename TargetEncoding>
|
||||
template <typename SourceEncoding, typename TargetEncoding>
|
||||
struct Transcoder;
|
||||
|
||||
// allocators.h
|
||||
@@ -46,12 +55,12 @@ class MemoryPoolAllocator;
|
||||
template <typename Encoding>
|
||||
struct GenericStringStream;
|
||||
|
||||
typedef GenericStringStream<UTF8<char> > StringStream;
|
||||
typedef GenericStringStream<UTF8<char>> StringStream;
|
||||
|
||||
template <typename Encoding>
|
||||
struct GenericInsituStringStream;
|
||||
|
||||
typedef GenericInsituStringStream<UTF8<char> > InsituStringStream;
|
||||
typedef GenericInsituStringStream<UTF8<char>> InsituStringStream;
|
||||
|
||||
// stringbuffer.h
|
||||
|
||||
@@ -81,7 +90,7 @@ struct MemoryStream;
|
||||
|
||||
// reader.h
|
||||
|
||||
template<typename Encoding, typename Derived>
|
||||
template <typename Encoding, typename Derived>
|
||||
struct BaseReaderHandler;
|
||||
|
||||
template <typename SourceEncoding, typename TargetEncoding, typename StackAllocator>
|
||||
@@ -91,29 +100,37 @@ typedef GenericReader<UTF8<char>, UTF8<char>, CrtAllocator> Reader;
|
||||
|
||||
// writer.h
|
||||
|
||||
template<typename OutputStream, typename SourceEncoding, typename TargetEncoding, typename StackAllocator, unsigned writeFlags>
|
||||
template <typename OutputStream,
|
||||
typename SourceEncoding,
|
||||
typename TargetEncoding,
|
||||
typename StackAllocator,
|
||||
unsigned writeFlags>
|
||||
class Writer;
|
||||
|
||||
// prettywriter.h
|
||||
|
||||
template<typename OutputStream, typename SourceEncoding, typename TargetEncoding, typename StackAllocator, unsigned writeFlags>
|
||||
template <typename OutputStream,
|
||||
typename SourceEncoding,
|
||||
typename TargetEncoding,
|
||||
typename StackAllocator,
|
||||
unsigned writeFlags>
|
||||
class PrettyWriter;
|
||||
|
||||
// document.h
|
||||
|
||||
template <typename Encoding, typename Allocator>
|
||||
template <typename Encoding, typename Allocator>
|
||||
class GenericMember;
|
||||
|
||||
template <bool Const, typename Encoding, typename Allocator>
|
||||
class GenericMemberIterator;
|
||||
|
||||
template<typename CharType>
|
||||
template <typename CharType>
|
||||
struct GenericStringRef;
|
||||
|
||||
template <typename Encoding, typename Allocator>
|
||||
template <typename Encoding, typename Allocator>
|
||||
class GenericValue;
|
||||
|
||||
typedef GenericValue<UTF8<char>, MemoryPoolAllocator<CrtAllocator> > Value;
|
||||
typedef GenericValue<UTF8<char>, MemoryPoolAllocator<CrtAllocator>> Value;
|
||||
|
||||
template <typename Encoding, typename Allocator, typename StackAllocator>
|
||||
class GenericDocument;
|
||||
@@ -138,13 +155,11 @@ class GenericSchemaDocument;
|
||||
typedef GenericSchemaDocument<Value, CrtAllocator> SchemaDocument;
|
||||
typedef IGenericRemoteSchemaDocumentProvider<SchemaDocument> IRemoteSchemaDocumentProvider;
|
||||
|
||||
template <
|
||||
typename SchemaDocumentType,
|
||||
typename OutputHandler,
|
||||
typename StateAllocator>
|
||||
template <typename SchemaDocumentType, typename OutputHandler, typename StateAllocator>
|
||||
class GenericSchemaValidator;
|
||||
|
||||
typedef GenericSchemaValidator<SchemaDocument, BaseReaderHandler<UTF8<char>, void>, CrtAllocator> SchemaValidator;
|
||||
typedef GenericSchemaValidator<SchemaDocument, BaseReaderHandler<UTF8<char>, void>, CrtAllocator>
|
||||
SchemaValidator;
|
||||
|
||||
RAPIDJSON_NAMESPACE_END
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_BIGINTEGER_H_
|
||||
@@ -22,132 +22,153 @@
|
||||
#if !defined(_ARM64EC_)
|
||||
#pragma intrinsic(_umul128)
|
||||
#else
|
||||
#pragma comment(lib,"softintrin")
|
||||
#pragma comment(lib, "softintrin")
|
||||
#endif
|
||||
#endif
|
||||
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
namespace internal {
|
||||
|
||||
class BigInteger {
|
||||
public:
|
||||
class BigInteger
|
||||
{
|
||||
public:
|
||||
typedef uint64_t Type;
|
||||
|
||||
BigInteger(const BigInteger& rhs) : count_(rhs.count_) {
|
||||
BigInteger(const BigInteger& rhs) : count_(rhs.count_)
|
||||
{
|
||||
std::memcpy(digits_, rhs.digits_, count_ * sizeof(Type));
|
||||
}
|
||||
|
||||
explicit BigInteger(uint64_t u) : count_(1) {
|
||||
digits_[0] = u;
|
||||
}
|
||||
explicit BigInteger(uint64_t u) : count_(1) { digits_[0] = u; }
|
||||
|
||||
template<typename Ch>
|
||||
BigInteger(const Ch* decimals, size_t length) : count_(1) {
|
||||
template <typename Ch>
|
||||
BigInteger(const Ch* decimals, size_t length) : count_(1)
|
||||
{
|
||||
RAPIDJSON_ASSERT(length > 0);
|
||||
digits_[0] = 0;
|
||||
size_t i = 0;
|
||||
const size_t kMaxDigitPerIteration = 19; // 2^64 = 18446744073709551616 > 10^19
|
||||
while (length >= kMaxDigitPerIteration) {
|
||||
digits_[0] = 0;
|
||||
size_t i = 0;
|
||||
const size_t kMaxDigitPerIteration = 19; // 2^64 = 18446744073709551616 > 10^19
|
||||
while(length >= kMaxDigitPerIteration)
|
||||
{
|
||||
AppendDecimal64(decimals + i, decimals + i + kMaxDigitPerIteration);
|
||||
length -= kMaxDigitPerIteration;
|
||||
i += kMaxDigitPerIteration;
|
||||
}
|
||||
|
||||
if (length > 0)
|
||||
if(length > 0)
|
||||
AppendDecimal64(decimals + i, decimals + i + length);
|
||||
}
|
||||
|
||||
BigInteger& operator=(const BigInteger &rhs)
|
||||
|
||||
BigInteger& operator=(const BigInteger& rhs)
|
||||
{
|
||||
if (this != &rhs) {
|
||||
if(this != &rhs)
|
||||
{
|
||||
count_ = rhs.count_;
|
||||
std::memcpy(digits_, rhs.digits_, count_ * sizeof(Type));
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
BigInteger& operator=(uint64_t u) {
|
||||
digits_[0] = u;
|
||||
count_ = 1;
|
||||
|
||||
BigInteger& operator=(uint64_t u)
|
||||
{
|
||||
digits_[0] = u;
|
||||
count_ = 1;
|
||||
return *this;
|
||||
}
|
||||
|
||||
BigInteger& operator+=(uint64_t u) {
|
||||
BigInteger& operator+=(uint64_t u)
|
||||
{
|
||||
Type backup = digits_[0];
|
||||
digits_[0] += u;
|
||||
for (size_t i = 0; i < count_ - 1; i++) {
|
||||
if (digits_[i] >= backup)
|
||||
for(size_t i = 0; i < count_ - 1; i++)
|
||||
{
|
||||
if(digits_[i] >= backup)
|
||||
return *this; // no carry
|
||||
backup = digits_[i + 1];
|
||||
digits_[i + 1] += 1;
|
||||
}
|
||||
|
||||
// Last carry
|
||||
if (digits_[count_ - 1] < backup)
|
||||
if(digits_[count_ - 1] < backup)
|
||||
PushBack(1);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
BigInteger& operator*=(uint64_t u) {
|
||||
if (u == 0) return *this = 0;
|
||||
if (u == 1) return *this;
|
||||
if (*this == 1) return *this = u;
|
||||
BigInteger& operator*=(uint64_t u)
|
||||
{
|
||||
if(u == 0)
|
||||
return *this = 0;
|
||||
if(u == 1)
|
||||
return *this;
|
||||
if(*this == 1)
|
||||
return *this = u;
|
||||
|
||||
uint64_t k = 0;
|
||||
for (size_t i = 0; i < count_; i++) {
|
||||
for(size_t i = 0; i < count_; i++)
|
||||
{
|
||||
uint64_t hi;
|
||||
digits_[i] = MulAdd64(digits_[i], u, k, &hi);
|
||||
k = hi;
|
||||
k = hi;
|
||||
}
|
||||
|
||||
if (k > 0)
|
||||
|
||||
if(k > 0)
|
||||
PushBack(k);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
BigInteger& operator*=(uint32_t u) {
|
||||
if (u == 0) return *this = 0;
|
||||
if (u == 1) return *this;
|
||||
if (*this == 1) return *this = u;
|
||||
BigInteger& operator*=(uint32_t u)
|
||||
{
|
||||
if(u == 0)
|
||||
return *this = 0;
|
||||
if(u == 1)
|
||||
return *this;
|
||||
if(*this == 1)
|
||||
return *this = u;
|
||||
|
||||
uint64_t k = 0;
|
||||
for (size_t i = 0; i < count_; i++) {
|
||||
const uint64_t c = digits_[i] >> 32;
|
||||
const uint64_t d = digits_[i] & 0xFFFFFFFF;
|
||||
for(size_t i = 0; i < count_; i++)
|
||||
{
|
||||
const uint64_t c = digits_[i] >> 32;
|
||||
const uint64_t d = digits_[i] & 0xFFFFFFFF;
|
||||
const uint64_t uc = u * c;
|
||||
const uint64_t ud = u * d;
|
||||
const uint64_t p0 = ud + k;
|
||||
const uint64_t p1 = uc + (p0 >> 32);
|
||||
digits_[i] = (p0 & 0xFFFFFFFF) | (p1 << 32);
|
||||
k = p1 >> 32;
|
||||
digits_[i] = (p0 & 0xFFFFFFFF) | (p1 << 32);
|
||||
k = p1 >> 32;
|
||||
}
|
||||
|
||||
if (k > 0)
|
||||
|
||||
if(k > 0)
|
||||
PushBack(k);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
BigInteger& operator<<=(size_t shift) {
|
||||
if (IsZero() || shift == 0) return *this;
|
||||
BigInteger& operator<<=(size_t shift)
|
||||
{
|
||||
if(IsZero() || shift == 0)
|
||||
return *this;
|
||||
|
||||
size_t offset = shift / kTypeBit;
|
||||
size_t offset = shift / kTypeBit;
|
||||
size_t interShift = shift % kTypeBit;
|
||||
RAPIDJSON_ASSERT(count_ + offset <= kCapacity);
|
||||
|
||||
if (interShift == 0) {
|
||||
if(interShift == 0)
|
||||
{
|
||||
std::memmove(digits_ + offset, digits_, count_ * sizeof(Type));
|
||||
count_ += offset;
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
digits_[count_] = 0;
|
||||
for (size_t i = count_; i > 0; i--)
|
||||
digits_[i + offset] = (digits_[i] << interShift) | (digits_[i - 1] >> (kTypeBit - interShift));
|
||||
for(size_t i = count_; i > 0; i--)
|
||||
digits_[i + offset] =
|
||||
(digits_[i] << interShift) | (digits_[i - 1] >> (kTypeBit - interShift));
|
||||
digits_[offset] = digits_[0] << interShift;
|
||||
count_ += offset;
|
||||
if (digits_[count_])
|
||||
if(digits_[count_])
|
||||
count_++;
|
||||
}
|
||||
|
||||
@@ -156,96 +177,121 @@ public:
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool operator==(const BigInteger& rhs) const {
|
||||
return count_ == rhs.count_ && std::memcmp(digits_, rhs.digits_, count_ * sizeof(Type)) == 0;
|
||||
bool operator==(const BigInteger& rhs) const
|
||||
{
|
||||
return count_ == rhs.count_ &&
|
||||
std::memcmp(digits_, rhs.digits_, count_ * sizeof(Type)) == 0;
|
||||
}
|
||||
|
||||
bool operator==(const Type rhs) const {
|
||||
return count_ == 1 && digits_[0] == rhs;
|
||||
}
|
||||
bool operator==(const Type rhs) const { return count_ == 1 && digits_[0] == rhs; }
|
||||
|
||||
BigInteger& MultiplyPow5(unsigned exp) {
|
||||
static const uint32_t kPow5[12] = {
|
||||
5,
|
||||
5 * 5,
|
||||
5 * 5 * 5,
|
||||
5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5
|
||||
};
|
||||
if (exp == 0) return *this;
|
||||
for (; exp >= 27; exp -= 27) *this *= RAPIDJSON_UINT64_C2(0X6765C793, 0XFA10079D); // 5^27
|
||||
for (; exp >= 13; exp -= 13) *this *= static_cast<uint32_t>(1220703125u); // 5^13
|
||||
if (exp > 0) *this *= kPow5[exp - 1];
|
||||
BigInteger& MultiplyPow5(unsigned exp)
|
||||
{
|
||||
static const uint32_t kPow5[12] = {5,
|
||||
5 * 5,
|
||||
5 * 5 * 5,
|
||||
5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
|
||||
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5};
|
||||
if(exp == 0)
|
||||
return *this;
|
||||
for(; exp >= 27; exp -= 27)
|
||||
*this *= RAPIDJSON_UINT64_C2(0X6765C793, 0XFA10079D); // 5^27
|
||||
for(; exp >= 13; exp -= 13)
|
||||
*this *= static_cast<uint32_t>(1220703125u); // 5^13
|
||||
if(exp > 0)
|
||||
*this *= kPow5[exp - 1];
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Compute absolute difference of this and rhs.
|
||||
// Assume this != rhs
|
||||
bool Difference(const BigInteger& rhs, BigInteger* out) const {
|
||||
bool Difference(const BigInteger& rhs, BigInteger* out) const
|
||||
{
|
||||
int cmp = Compare(rhs);
|
||||
RAPIDJSON_ASSERT(cmp != 0);
|
||||
const BigInteger *a, *b; // Makes a > b
|
||||
const BigInteger *a, *b; // Makes a > b
|
||||
bool ret;
|
||||
if (cmp < 0) { a = &rhs; b = this; ret = true; }
|
||||
else { a = this; b = &rhs; ret = false; }
|
||||
if(cmp < 0)
|
||||
{
|
||||
a = &rhs;
|
||||
b = this;
|
||||
ret = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
a = this;
|
||||
b = &rhs;
|
||||
ret = false;
|
||||
}
|
||||
|
||||
Type borrow = 0;
|
||||
for (size_t i = 0; i < a->count_; i++) {
|
||||
for(size_t i = 0; i < a->count_; i++)
|
||||
{
|
||||
Type d = a->digits_[i] - borrow;
|
||||
if (i < b->count_)
|
||||
if(i < b->count_)
|
||||
d -= b->digits_[i];
|
||||
borrow = (d > a->digits_[i]) ? 1 : 0;
|
||||
borrow = (d > a->digits_[i]) ? 1 : 0;
|
||||
out->digits_[i] = d;
|
||||
if (d != 0)
|
||||
if(d != 0)
|
||||
out->count_ = i + 1;
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
int Compare(const BigInteger& rhs) const {
|
||||
if (count_ != rhs.count_)
|
||||
int Compare(const BigInteger& rhs) const
|
||||
{
|
||||
if(count_ != rhs.count_)
|
||||
return count_ < rhs.count_ ? -1 : 1;
|
||||
|
||||
for (size_t i = count_; i-- > 0;)
|
||||
if (digits_[i] != rhs.digits_[i])
|
||||
for(size_t i = count_; i-- > 0;)
|
||||
if(digits_[i] != rhs.digits_[i])
|
||||
return digits_[i] < rhs.digits_[i] ? -1 : 1;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
size_t GetCount() const { return count_; }
|
||||
Type GetDigit(size_t index) const { RAPIDJSON_ASSERT(index < count_); return digits_[index]; }
|
||||
Type GetDigit(size_t index) const
|
||||
{
|
||||
RAPIDJSON_ASSERT(index < count_);
|
||||
return digits_[index];
|
||||
}
|
||||
bool IsZero() const { return count_ == 1 && digits_[0] == 0; }
|
||||
|
||||
private:
|
||||
template<typename Ch>
|
||||
void AppendDecimal64(const Ch* begin, const Ch* end) {
|
||||
private:
|
||||
template <typename Ch>
|
||||
void AppendDecimal64(const Ch* begin, const Ch* end)
|
||||
{
|
||||
uint64_t u = ParseUint64(begin, end);
|
||||
if (IsZero())
|
||||
if(IsZero())
|
||||
*this = u;
|
||||
else {
|
||||
else
|
||||
{
|
||||
unsigned exp = static_cast<unsigned>(end - begin);
|
||||
(MultiplyPow5(exp) <<= exp) += u; // *this = *this * 10^exp + u
|
||||
(MultiplyPow5(exp) <<= exp) += u; // *this = *this * 10^exp + u
|
||||
}
|
||||
}
|
||||
|
||||
void PushBack(Type digit) {
|
||||
void PushBack(Type digit)
|
||||
{
|
||||
RAPIDJSON_ASSERT(count_ < kCapacity);
|
||||
digits_[count_++] = digit;
|
||||
}
|
||||
|
||||
template<typename Ch>
|
||||
static uint64_t ParseUint64(const Ch* begin, const Ch* end) {
|
||||
template <typename Ch>
|
||||
static uint64_t ParseUint64(const Ch* begin, const Ch* end)
|
||||
{
|
||||
uint64_t r = 0;
|
||||
for (const Ch* p = begin; p != end; ++p) {
|
||||
for(const Ch* p = begin; p != end; ++p)
|
||||
{
|
||||
RAPIDJSON_ASSERT(*p >= Ch('0') && *p <= Ch('9'));
|
||||
r = r * 10u + static_cast<unsigned>(*p - Ch('0'));
|
||||
}
|
||||
@@ -253,13 +299,15 @@ private:
|
||||
}
|
||||
|
||||
// Assume a * b + k < 2^128
|
||||
static uint64_t MulAdd64(uint64_t a, uint64_t b, uint64_t k, uint64_t* outHigh) {
|
||||
static uint64_t MulAdd64(uint64_t a, uint64_t b, uint64_t k, uint64_t* outHigh)
|
||||
{
|
||||
#if defined(_MSC_VER) && defined(_M_AMD64)
|
||||
uint64_t low = _umul128(a, b, outHigh) + k;
|
||||
if (low < k)
|
||||
if(low < k)
|
||||
(*outHigh)++;
|
||||
return low;
|
||||
#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && defined(__x86_64__)
|
||||
#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && \
|
||||
defined(__x86_64__)
|
||||
__extension__ typedef unsigned __int128 uint128;
|
||||
uint128 p = static_cast<uint128>(a) * static_cast<uint128>(b);
|
||||
p += k;
|
||||
@@ -270,22 +318,22 @@ private:
|
||||
uint64_t x0 = a0 * b0, x1 = a0 * b1, x2 = a1 * b0, x3 = a1 * b1;
|
||||
x1 += (x0 >> 32); // can't give carry
|
||||
x1 += x2;
|
||||
if (x1 < x2)
|
||||
if(x1 < x2)
|
||||
x3 += (static_cast<uint64_t>(1) << 32);
|
||||
uint64_t lo = (x1 << 32) + (x0 & 0xFFFFFFFF);
|
||||
uint64_t hi = x3 + (x1 >> 32);
|
||||
|
||||
lo += k;
|
||||
if (lo < k)
|
||||
if(lo < k)
|
||||
hi++;
|
||||
*outHigh = hi;
|
||||
return lo;
|
||||
#endif
|
||||
}
|
||||
|
||||
static const size_t kBitCount = 3328; // 64bit * 54 > 10^1000
|
||||
static const size_t kBitCount = 3328; // 64bit * 54 > 10^1000
|
||||
static const size_t kCapacity = kBitCount / sizeof(Type);
|
||||
static const size_t kTypeBit = sizeof(Type) * 8;
|
||||
static const size_t kTypeBit = sizeof(Type) * 8;
|
||||
|
||||
Type digits_[kCapacity];
|
||||
size_t count_;
|
||||
|
||||
@@ -29,7 +29,8 @@
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
namespace internal {
|
||||
|
||||
inline uint32_t clzll(uint64_t x) {
|
||||
inline uint32_t clzll(uint64_t x)
|
||||
{
|
||||
// Passing 0 to __builtin_clzll is UB in GCC and results in an
|
||||
// infinite loop in the software implementation.
|
||||
RAPIDJSON_ASSERT(x != 0);
|
||||
@@ -40,7 +41,7 @@ inline uint32_t clzll(uint64_t x) {
|
||||
_BitScanReverse64(&r, x);
|
||||
#else
|
||||
// Scan the high 32 bits.
|
||||
if (_BitScanReverse(&r, static_cast<uint32_t>(x >> 32)))
|
||||
if(_BitScanReverse(&r, static_cast<uint32_t>(x >> 32)))
|
||||
return 63 - (r + 32);
|
||||
|
||||
// Scan the low 32 bits.
|
||||
@@ -48,13 +49,14 @@ inline uint32_t clzll(uint64_t x) {
|
||||
#endif // _WIN64
|
||||
|
||||
return 63 - r;
|
||||
#elif (defined(__GNUC__) && __GNUC__ >= 4) || RAPIDJSON_HAS_BUILTIN(__builtin_clzll)
|
||||
#elif(defined(__GNUC__) && __GNUC__ >= 4) || RAPIDJSON_HAS_BUILTIN(__builtin_clzll)
|
||||
// __builtin_clzll wrapper
|
||||
return static_cast<uint32_t>(__builtin_clzll(x));
|
||||
#else
|
||||
// naive version
|
||||
uint32_t r = 0;
|
||||
while (!(x & (static_cast<uint64_t>(1) << 63))) {
|
||||
while(!(x & (static_cast<uint64_t>(1) << 63)))
|
||||
{
|
||||
x <<= 1;
|
||||
++r;
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@
|
||||
#if !defined(_ARM64EC_)
|
||||
#pragma intrinsic(_umul128)
|
||||
#else
|
||||
#pragma comment(lib,"softintrin")
|
||||
#pragma comment(lib, "softintrin")
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -45,72 +45,80 @@ RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(padded)
|
||||
#endif
|
||||
|
||||
struct DiyFp {
|
||||
struct DiyFp
|
||||
{
|
||||
DiyFp() : f(), e() {}
|
||||
|
||||
DiyFp(uint64_t fp, int exp) : f(fp), e(exp) {}
|
||||
|
||||
explicit DiyFp(double d) {
|
||||
union {
|
||||
explicit DiyFp(double d)
|
||||
{
|
||||
union
|
||||
{
|
||||
double d;
|
||||
uint64_t u64;
|
||||
} u = { d };
|
||||
} u = {d};
|
||||
|
||||
int biased_e = static_cast<int>((u.u64 & kDpExponentMask) >> kDpSignificandSize);
|
||||
int biased_e = static_cast<int>((u.u64 & kDpExponentMask) >> kDpSignificandSize);
|
||||
uint64_t significand = (u.u64 & kDpSignificandMask);
|
||||
if (biased_e != 0) {
|
||||
if(biased_e != 0)
|
||||
{
|
||||
f = significand + kDpHiddenBit;
|
||||
e = biased_e - kDpExponentBias;
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
f = significand;
|
||||
e = kDpMinExponent + 1;
|
||||
}
|
||||
}
|
||||
|
||||
DiyFp operator-(const DiyFp& rhs) const {
|
||||
return DiyFp(f - rhs.f, e);
|
||||
}
|
||||
DiyFp operator-(const DiyFp& rhs) const { return DiyFp(f - rhs.f, e); }
|
||||
|
||||
DiyFp operator*(const DiyFp& rhs) const {
|
||||
DiyFp operator*(const DiyFp& rhs) const
|
||||
{
|
||||
#if defined(_MSC_VER) && defined(_M_AMD64)
|
||||
uint64_t h;
|
||||
uint64_t l = _umul128(f, rhs.f, &h);
|
||||
if (l & (uint64_t(1) << 63)) // rounding
|
||||
if(l & (uint64_t(1) << 63)) // rounding
|
||||
h++;
|
||||
return DiyFp(h, e + rhs.e + 64);
|
||||
#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && defined(__x86_64__)
|
||||
#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && \
|
||||
defined(__x86_64__)
|
||||
__extension__ typedef unsigned __int128 uint128;
|
||||
uint128 p = static_cast<uint128>(f) * static_cast<uint128>(rhs.f);
|
||||
uint128 p = static_cast<uint128>(f) * static_cast<uint128>(rhs.f);
|
||||
uint64_t h = static_cast<uint64_t>(p >> 64);
|
||||
uint64_t l = static_cast<uint64_t>(p);
|
||||
if (l & (uint64_t(1) << 63)) // rounding
|
||||
if(l & (uint64_t(1) << 63)) // rounding
|
||||
h++;
|
||||
return DiyFp(h, e + rhs.e + 64);
|
||||
#else
|
||||
const uint64_t M32 = 0xFFFFFFFF;
|
||||
const uint64_t a = f >> 32;
|
||||
const uint64_t b = f & M32;
|
||||
const uint64_t c = rhs.f >> 32;
|
||||
const uint64_t d = rhs.f & M32;
|
||||
const uint64_t ac = a * c;
|
||||
const uint64_t bc = b * c;
|
||||
const uint64_t ad = a * d;
|
||||
const uint64_t bd = b * d;
|
||||
uint64_t tmp = (bd >> 32) + (ad & M32) + (bc & M32);
|
||||
tmp += 1U << 31; /// mult_round
|
||||
const uint64_t a = f >> 32;
|
||||
const uint64_t b = f & M32;
|
||||
const uint64_t c = rhs.f >> 32;
|
||||
const uint64_t d = rhs.f & M32;
|
||||
const uint64_t ac = a * c;
|
||||
const uint64_t bc = b * c;
|
||||
const uint64_t ad = a * d;
|
||||
const uint64_t bd = b * d;
|
||||
uint64_t tmp = (bd >> 32) + (ad & M32) + (bc & M32);
|
||||
tmp += 1U << 31; /// mult_round
|
||||
return DiyFp(ac + (ad >> 32) + (bc >> 32) + (tmp >> 32), e + rhs.e + 64);
|
||||
#endif
|
||||
}
|
||||
|
||||
DiyFp Normalize() const {
|
||||
DiyFp Normalize() const
|
||||
{
|
||||
int s = static_cast<int>(clzll(f));
|
||||
return DiyFp(f << s, e - s);
|
||||
}
|
||||
|
||||
DiyFp NormalizeBoundary() const {
|
||||
DiyFp NormalizeBoundary() const
|
||||
{
|
||||
DiyFp res = *this;
|
||||
while (!(res.f & (kDpHiddenBit << 1))) {
|
||||
while(!(res.f & (kDpHiddenBit << 1)))
|
||||
{
|
||||
res.f <<= 1;
|
||||
res.e--;
|
||||
}
|
||||
@@ -119,50 +127,57 @@ struct DiyFp {
|
||||
return res;
|
||||
}
|
||||
|
||||
void NormalizedBoundaries(DiyFp* minus, DiyFp* plus) const {
|
||||
void NormalizedBoundaries(DiyFp* minus, DiyFp* plus) const
|
||||
{
|
||||
DiyFp pl = DiyFp((f << 1) + 1, e - 1).NormalizeBoundary();
|
||||
DiyFp mi = (f == kDpHiddenBit) ? DiyFp((f << 2) - 1, e - 2) : DiyFp((f << 1) - 1, e - 1);
|
||||
mi.f <<= mi.e - pl.e;
|
||||
mi.e = pl.e;
|
||||
*plus = pl;
|
||||
mi.e = pl.e;
|
||||
*plus = pl;
|
||||
*minus = mi;
|
||||
}
|
||||
|
||||
double ToDouble() const {
|
||||
union {
|
||||
double ToDouble() const
|
||||
{
|
||||
union
|
||||
{
|
||||
double d;
|
||||
uint64_t u64;
|
||||
}u;
|
||||
} u;
|
||||
RAPIDJSON_ASSERT(f <= kDpHiddenBit + kDpSignificandMask);
|
||||
if (e < kDpDenormalExponent) {
|
||||
if(e < kDpDenormalExponent)
|
||||
{
|
||||
// Underflow.
|
||||
return 0.0;
|
||||
}
|
||||
if (e >= kDpMaxExponent) {
|
||||
if(e >= kDpMaxExponent)
|
||||
{
|
||||
// Overflow.
|
||||
return std::numeric_limits<double>::infinity();
|
||||
}
|
||||
const uint64_t be = (e == kDpDenormalExponent && (f & kDpHiddenBit) == 0) ? 0 :
|
||||
static_cast<uint64_t>(e + kDpExponentBias);
|
||||
u.u64 = (f & kDpSignificandMask) | (be << kDpSignificandSize);
|
||||
const uint64_t be = (e == kDpDenormalExponent && (f & kDpHiddenBit) == 0)
|
||||
? 0
|
||||
: static_cast<uint64_t>(e + kDpExponentBias);
|
||||
u.u64 = (f & kDpSignificandMask) | (be << kDpSignificandSize);
|
||||
return u.d;
|
||||
}
|
||||
|
||||
static const int kDiySignificandSize = 64;
|
||||
static const int kDpSignificandSize = 52;
|
||||
static const int kDpExponentBias = 0x3FF + kDpSignificandSize;
|
||||
static const int kDpMaxExponent = 0x7FF - kDpExponentBias;
|
||||
static const int kDpMinExponent = -kDpExponentBias;
|
||||
static const int kDpDenormalExponent = -kDpExponentBias + 1;
|
||||
static const uint64_t kDpExponentMask = RAPIDJSON_UINT64_C2(0x7FF00000, 0x00000000);
|
||||
static const int kDiySignificandSize = 64;
|
||||
static const int kDpSignificandSize = 52;
|
||||
static const int kDpExponentBias = 0x3FF + kDpSignificandSize;
|
||||
static const int kDpMaxExponent = 0x7FF - kDpExponentBias;
|
||||
static const int kDpMinExponent = -kDpExponentBias;
|
||||
static const int kDpDenormalExponent = -kDpExponentBias + 1;
|
||||
static const uint64_t kDpExponentMask = RAPIDJSON_UINT64_C2(0x7FF00000, 0x00000000);
|
||||
static const uint64_t kDpSignificandMask = RAPIDJSON_UINT64_C2(0x000FFFFF, 0xFFFFFFFF);
|
||||
static const uint64_t kDpHiddenBit = RAPIDJSON_UINT64_C2(0x00100000, 0x00000000);
|
||||
static const uint64_t kDpHiddenBit = RAPIDJSON_UINT64_C2(0x00100000, 0x00000000);
|
||||
|
||||
uint64_t f;
|
||||
int e;
|
||||
};
|
||||
|
||||
inline DiyFp GetCachedPowerByIndex(size_t index) {
|
||||
inline DiyFp GetCachedPowerByIndex(size_t index)
|
||||
{
|
||||
// 10^-348, 10^-340, ..., 10^340
|
||||
static const uint64_t kCachedPowers_F[] = {
|
||||
RAPIDJSON_UINT64_C2(0xfa8fd5a0, 0x081c0288), RAPIDJSON_UINT64_C2(0xbaaee17f, 0xa23ebf76),
|
||||
@@ -208,41 +223,40 @@ inline DiyFp GetCachedPowerByIndex(size_t index) {
|
||||
RAPIDJSON_UINT64_C2(0x80444b5e, 0x7aa7cf85), RAPIDJSON_UINT64_C2(0xbf21e440, 0x03acdd2d),
|
||||
RAPIDJSON_UINT64_C2(0x8e679c2f, 0x5e44ff8f), RAPIDJSON_UINT64_C2(0xd433179d, 0x9c8cb841),
|
||||
RAPIDJSON_UINT64_C2(0x9e19db92, 0xb4e31ba9), RAPIDJSON_UINT64_C2(0xeb96bf6e, 0xbadf77d9),
|
||||
RAPIDJSON_UINT64_C2(0xaf87023b, 0x9bf0ee6b)
|
||||
};
|
||||
RAPIDJSON_UINT64_C2(0xaf87023b, 0x9bf0ee6b)};
|
||||
static const int16_t kCachedPowers_E[] = {
|
||||
-1220, -1193, -1166, -1140, -1113, -1087, -1060, -1034, -1007, -980,
|
||||
-954, -927, -901, -874, -847, -821, -794, -768, -741, -715,
|
||||
-688, -661, -635, -608, -582, -555, -529, -502, -475, -449,
|
||||
-422, -396, -369, -343, -316, -289, -263, -236, -210, -183,
|
||||
-157, -130, -103, -77, -50, -24, 3, 30, 56, 83,
|
||||
109, 136, 162, 189, 216, 242, 269, 295, 322, 348,
|
||||
375, 402, 428, 455, 481, 508, 534, 561, 588, 614,
|
||||
641, 667, 694, 720, 747, 774, 800, 827, 853, 880,
|
||||
907, 933, 960, 986, 1013, 1039, 1066
|
||||
};
|
||||
-1220, -1193, -1166, -1140, -1113, -1087, -1060, -1034, -1007, -980, -954, -927, -901,
|
||||
-874, -847, -821, -794, -768, -741, -715, -688, -661, -635, -608, -582, -555,
|
||||
-529, -502, -475, -449, -422, -396, -369, -343, -316, -289, -263, -236, -210,
|
||||
-183, -157, -130, -103, -77, -50, -24, 3, 30, 56, 83, 109, 136,
|
||||
162, 189, 216, 242, 269, 295, 322, 348, 375, 402, 428, 455, 481,
|
||||
508, 534, 561, 588, 614, 641, 667, 694, 720, 747, 774, 800, 827,
|
||||
853, 880, 907, 933, 960, 986, 1013, 1039, 1066};
|
||||
RAPIDJSON_ASSERT(index < 87);
|
||||
return DiyFp(kCachedPowers_F[index], kCachedPowers_E[index]);
|
||||
}
|
||||
|
||||
inline DiyFp GetCachedPower(int e, int* K) {
|
||||
inline DiyFp GetCachedPower(int e, int* K)
|
||||
{
|
||||
|
||||
//int k = static_cast<int>(ceil((-61 - e) * 0.30102999566398114)) + 374;
|
||||
double dk = (-61 - e) * 0.30102999566398114 + 347; // dk must be positive, so can do ceiling in positive
|
||||
// int k = static_cast<int>(ceil((-61 - e) * 0.30102999566398114)) + 374;
|
||||
double dk =
|
||||
(-61 - e) * 0.30102999566398114 + 347; // dk must be positive, so can do ceiling in positive
|
||||
int k = static_cast<int>(dk);
|
||||
if (dk - k > 0.0)
|
||||
if(dk - k > 0.0)
|
||||
k++;
|
||||
|
||||
unsigned index = static_cast<unsigned>((k >> 3) + 1);
|
||||
*K = -(-348 + static_cast<int>(index << 3)); // decimal exponent no need lookup table
|
||||
*K = -(-348 + static_cast<int>(index << 3)); // decimal exponent no need lookup table
|
||||
|
||||
return GetCachedPowerByIndex(index);
|
||||
}
|
||||
|
||||
inline DiyFp GetCachedPower10(int exp, int *outExp) {
|
||||
inline DiyFp GetCachedPower10(int exp, int* outExp)
|
||||
{
|
||||
RAPIDJSON_ASSERT(exp >= -348);
|
||||
unsigned index = static_cast<unsigned>(exp + 348) / 8u;
|
||||
*outExp = -348 + static_cast<int>(index) * 8;
|
||||
*outExp = -348 + static_cast<int>(index) * 8;
|
||||
return GetCachedPowerByIndex(index);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
// This is a C++ header-only implementation of Grisu2 algorithm from the publication:
|
||||
@@ -29,66 +29,126 @@ namespace internal {
|
||||
#ifdef __GNUC__
|
||||
RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(effc++)
|
||||
RAPIDJSON_DIAG_OFF(array-bounds) // some gcc versions generate wrong warnings https://gcc.gnu.org/bugzilla/show_bug.cgi?id=59124
|
||||
RAPIDJSON_DIAG_OFF(array - bounds) // some gcc versions generate wrong warnings
|
||||
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=59124
|
||||
#endif
|
||||
|
||||
inline void GrisuRound(char* buffer, int len, uint64_t delta, uint64_t rest, uint64_t ten_kappa, uint64_t wp_w) {
|
||||
while (rest < wp_w && delta - rest >= ten_kappa &&
|
||||
(rest + ten_kappa < wp_w || /// closer
|
||||
wp_w - rest > rest + ten_kappa - wp_w)) {
|
||||
inline void
|
||||
GrisuRound(char* buffer, int len, uint64_t delta, uint64_t rest, uint64_t ten_kappa, uint64_t wp_w)
|
||||
{
|
||||
while(rest < wp_w && delta - rest >= ten_kappa &&
|
||||
(rest + ten_kappa < wp_w || /// closer
|
||||
wp_w - rest > rest + ten_kappa - wp_w))
|
||||
{
|
||||
buffer[len - 1]--;
|
||||
rest += ten_kappa;
|
||||
}
|
||||
}
|
||||
|
||||
inline int CountDecimalDigit32(uint32_t n) {
|
||||
inline int CountDecimalDigit32(uint32_t n)
|
||||
{
|
||||
// Simple pure C++ implementation was faster than __builtin_clz version in this situation.
|
||||
if (n < 10) return 1;
|
||||
if (n < 100) return 2;
|
||||
if (n < 1000) return 3;
|
||||
if (n < 10000) return 4;
|
||||
if (n < 100000) return 5;
|
||||
if (n < 1000000) return 6;
|
||||
if (n < 10000000) return 7;
|
||||
if (n < 100000000) return 8;
|
||||
if(n < 10)
|
||||
return 1;
|
||||
if(n < 100)
|
||||
return 2;
|
||||
if(n < 1000)
|
||||
return 3;
|
||||
if(n < 10000)
|
||||
return 4;
|
||||
if(n < 100000)
|
||||
return 5;
|
||||
if(n < 1000000)
|
||||
return 6;
|
||||
if(n < 10000000)
|
||||
return 7;
|
||||
if(n < 100000000)
|
||||
return 8;
|
||||
// Will not reach 10 digits in DigitGen()
|
||||
//if (n < 1000000000) return 9;
|
||||
//return 10;
|
||||
// if (n < 1000000000) return 9;
|
||||
// return 10;
|
||||
return 9;
|
||||
}
|
||||
|
||||
inline void DigitGen(const DiyFp& W, const DiyFp& Mp, uint64_t delta, char* buffer, int* len, int* K) {
|
||||
static const uint64_t kPow10[] = { 1ULL, 10ULL, 100ULL, 1000ULL, 10000ULL, 100000ULL, 1000000ULL, 10000000ULL, 100000000ULL,
|
||||
1000000000ULL, 10000000000ULL, 100000000000ULL, 1000000000000ULL,
|
||||
10000000000000ULL, 100000000000000ULL, 1000000000000000ULL,
|
||||
10000000000000000ULL, 100000000000000000ULL, 1000000000000000000ULL,
|
||||
10000000000000000000ULL };
|
||||
inline void
|
||||
DigitGen(const DiyFp& W, const DiyFp& Mp, uint64_t delta, char* buffer, int* len, int* K)
|
||||
{
|
||||
static const uint64_t kPow10[] = {1ULL,
|
||||
10ULL,
|
||||
100ULL,
|
||||
1000ULL,
|
||||
10000ULL,
|
||||
100000ULL,
|
||||
1000000ULL,
|
||||
10000000ULL,
|
||||
100000000ULL,
|
||||
1000000000ULL,
|
||||
10000000000ULL,
|
||||
100000000000ULL,
|
||||
1000000000000ULL,
|
||||
10000000000000ULL,
|
||||
100000000000000ULL,
|
||||
1000000000000000ULL,
|
||||
10000000000000000ULL,
|
||||
100000000000000000ULL,
|
||||
1000000000000000000ULL,
|
||||
10000000000000000000ULL};
|
||||
const DiyFp one(uint64_t(1) << -Mp.e, Mp.e);
|
||||
const DiyFp wp_w = Mp - W;
|
||||
uint32_t p1 = static_cast<uint32_t>(Mp.f >> -one.e);
|
||||
uint64_t p2 = Mp.f & (one.f - 1);
|
||||
int kappa = CountDecimalDigit32(p1); // kappa in [0, 9]
|
||||
*len = 0;
|
||||
uint32_t p1 = static_cast<uint32_t>(Mp.f >> -one.e);
|
||||
uint64_t p2 = Mp.f & (one.f - 1);
|
||||
int kappa = CountDecimalDigit32(p1); // kappa in [0, 9]
|
||||
*len = 0;
|
||||
|
||||
while (kappa > 0) {
|
||||
while(kappa > 0)
|
||||
{
|
||||
uint32_t d = 0;
|
||||
switch (kappa) {
|
||||
case 9: d = p1 / 100000000; p1 %= 100000000; break;
|
||||
case 8: d = p1 / 10000000; p1 %= 10000000; break;
|
||||
case 7: d = p1 / 1000000; p1 %= 1000000; break;
|
||||
case 6: d = p1 / 100000; p1 %= 100000; break;
|
||||
case 5: d = p1 / 10000; p1 %= 10000; break;
|
||||
case 4: d = p1 / 1000; p1 %= 1000; break;
|
||||
case 3: d = p1 / 100; p1 %= 100; break;
|
||||
case 2: d = p1 / 10; p1 %= 10; break;
|
||||
case 1: d = p1; p1 = 0; break;
|
||||
default:;
|
||||
switch(kappa)
|
||||
{
|
||||
case 9:
|
||||
d = p1 / 100000000;
|
||||
p1 %= 100000000;
|
||||
break;
|
||||
case 8:
|
||||
d = p1 / 10000000;
|
||||
p1 %= 10000000;
|
||||
break;
|
||||
case 7:
|
||||
d = p1 / 1000000;
|
||||
p1 %= 1000000;
|
||||
break;
|
||||
case 6:
|
||||
d = p1 / 100000;
|
||||
p1 %= 100000;
|
||||
break;
|
||||
case 5:
|
||||
d = p1 / 10000;
|
||||
p1 %= 10000;
|
||||
break;
|
||||
case 4:
|
||||
d = p1 / 1000;
|
||||
p1 %= 1000;
|
||||
break;
|
||||
case 3:
|
||||
d = p1 / 100;
|
||||
p1 %= 100;
|
||||
break;
|
||||
case 2:
|
||||
d = p1 / 10;
|
||||
p1 %= 10;
|
||||
break;
|
||||
case 1:
|
||||
d = p1;
|
||||
p1 = 0;
|
||||
break;
|
||||
default:;
|
||||
}
|
||||
if (d || *len)
|
||||
if(d || *len)
|
||||
buffer[(*len)++] = static_cast<char>('0' + static_cast<char>(d));
|
||||
kappa--;
|
||||
uint64_t tmp = (static_cast<uint64_t>(p1) << -one.e) + p2;
|
||||
if (tmp <= delta) {
|
||||
if(tmp <= delta)
|
||||
{
|
||||
*K += kappa;
|
||||
GrisuRound(buffer, *len, delta, tmp, kPow10[kappa] << -one.e, wp_w.f);
|
||||
return;
|
||||
@@ -96,15 +156,17 @@ inline void DigitGen(const DiyFp& W, const DiyFp& Mp, uint64_t delta, char* buff
|
||||
}
|
||||
|
||||
// kappa = 0
|
||||
for (;;) {
|
||||
for(;;)
|
||||
{
|
||||
p2 *= 10;
|
||||
delta *= 10;
|
||||
char d = static_cast<char>(p2 >> -one.e);
|
||||
if (d || *len)
|
||||
if(d || *len)
|
||||
buffer[(*len)++] = static_cast<char>('0' + d);
|
||||
p2 &= one.f - 1;
|
||||
kappa--;
|
||||
if (p2 < delta) {
|
||||
if(p2 < delta)
|
||||
{
|
||||
*K += kappa;
|
||||
int index = -kappa;
|
||||
GrisuRound(buffer, *len, delta, p2, one.f, wp_w.f * (index < 20 ? kPow10[index] : 0));
|
||||
@@ -113,37 +175,42 @@ inline void DigitGen(const DiyFp& W, const DiyFp& Mp, uint64_t delta, char* buff
|
||||
}
|
||||
}
|
||||
|
||||
inline void Grisu2(double value, char* buffer, int* length, int* K) {
|
||||
inline void Grisu2(double value, char* buffer, int* length, int* K)
|
||||
{
|
||||
const DiyFp v(value);
|
||||
DiyFp w_m, w_p;
|
||||
v.NormalizedBoundaries(&w_m, &w_p);
|
||||
|
||||
const DiyFp c_mk = GetCachedPower(w_p.e, K);
|
||||
const DiyFp W = v.Normalize() * c_mk;
|
||||
DiyFp Wp = w_p * c_mk;
|
||||
DiyFp Wm = w_m * c_mk;
|
||||
const DiyFp W = v.Normalize() * c_mk;
|
||||
DiyFp Wp = w_p * c_mk;
|
||||
DiyFp Wm = w_m * c_mk;
|
||||
Wm.f++;
|
||||
Wp.f--;
|
||||
DigitGen(W, Wp, Wp.f - Wm.f, buffer, length, K);
|
||||
}
|
||||
|
||||
inline char* WriteExponent(int K, char* buffer) {
|
||||
if (K < 0) {
|
||||
inline char* WriteExponent(int K, char* buffer)
|
||||
{
|
||||
if(K < 0)
|
||||
{
|
||||
*buffer++ = '-';
|
||||
K = -K;
|
||||
K = -K;
|
||||
}
|
||||
|
||||
if (K >= 100) {
|
||||
if(K >= 100)
|
||||
{
|
||||
*buffer++ = static_cast<char>('0' + static_cast<char>(K / 100));
|
||||
K %= 100;
|
||||
const char* d = GetDigitsLut() + K * 2;
|
||||
*buffer++ = d[0];
|
||||
*buffer++ = d[1];
|
||||
*buffer++ = d[0];
|
||||
*buffer++ = d[1];
|
||||
}
|
||||
else if (K >= 10) {
|
||||
else if(K >= 10)
|
||||
{
|
||||
const char* d = GetDigitsLut() + K * 2;
|
||||
*buffer++ = d[0];
|
||||
*buffer++ = d[1];
|
||||
*buffer++ = d[0];
|
||||
*buffer++ = d[1];
|
||||
}
|
||||
else
|
||||
*buffer++ = static_cast<char>('0' + static_cast<char>(K));
|
||||
@@ -151,87 +218,100 @@ inline char* WriteExponent(int K, char* buffer) {
|
||||
return buffer;
|
||||
}
|
||||
|
||||
inline char* Prettify(char* buffer, int length, int k, int maxDecimalPlaces) {
|
||||
const int kk = length + k; // 10^(kk-1) <= v < 10^kk
|
||||
inline char* Prettify(char* buffer, int length, int k, int maxDecimalPlaces)
|
||||
{
|
||||
const int kk = length + k; // 10^(kk-1) <= v < 10^kk
|
||||
|
||||
if (0 <= k && kk <= 21) {
|
||||
if(0 <= k && kk <= 21)
|
||||
{
|
||||
// 1234e7 -> 12340000000
|
||||
for (int i = length; i < kk; i++)
|
||||
for(int i = length; i < kk; i++)
|
||||
buffer[i] = '0';
|
||||
buffer[kk] = '.';
|
||||
buffer[kk] = '.';
|
||||
buffer[kk + 1] = '0';
|
||||
return &buffer[kk + 2];
|
||||
}
|
||||
else if (0 < kk && kk <= 21) {
|
||||
else if(0 < kk && kk <= 21)
|
||||
{
|
||||
// 1234e-2 -> 12.34
|
||||
std::memmove(&buffer[kk + 1], &buffer[kk], static_cast<size_t>(length - kk));
|
||||
buffer[kk] = '.';
|
||||
if (0 > k + maxDecimalPlaces) {
|
||||
if(0 > k + maxDecimalPlaces)
|
||||
{
|
||||
// When maxDecimalPlaces = 2, 1.2345 -> 1.23, 1.102 -> 1.1
|
||||
// Remove extra trailing zeros (at least one) after truncation.
|
||||
for (int i = kk + maxDecimalPlaces; i > kk + 1; i--)
|
||||
if (buffer[i] != '0')
|
||||
for(int i = kk + maxDecimalPlaces; i > kk + 1; i--)
|
||||
if(buffer[i] != '0')
|
||||
return &buffer[i + 1];
|
||||
return &buffer[kk + 2]; // Reserve one zero
|
||||
}
|
||||
else
|
||||
return &buffer[length + 1];
|
||||
}
|
||||
else if (-6 < kk && kk <= 0) {
|
||||
else if(-6 < kk && kk <= 0)
|
||||
{
|
||||
// 1234e-6 -> 0.001234
|
||||
const int offset = 2 - kk;
|
||||
std::memmove(&buffer[offset], &buffer[0], static_cast<size_t>(length));
|
||||
buffer[0] = '0';
|
||||
buffer[1] = '.';
|
||||
for (int i = 2; i < offset; i++)
|
||||
for(int i = 2; i < offset; i++)
|
||||
buffer[i] = '0';
|
||||
if (length - kk > maxDecimalPlaces) {
|
||||
if(length - kk > maxDecimalPlaces)
|
||||
{
|
||||
// When maxDecimalPlaces = 2, 0.123 -> 0.12, 0.102 -> 0.1
|
||||
// Remove extra trailing zeros (at least one) after truncation.
|
||||
for (int i = maxDecimalPlaces + 1; i > 2; i--)
|
||||
if (buffer[i] != '0')
|
||||
for(int i = maxDecimalPlaces + 1; i > 2; i--)
|
||||
if(buffer[i] != '0')
|
||||
return &buffer[i + 1];
|
||||
return &buffer[3]; // Reserve one zero
|
||||
}
|
||||
else
|
||||
return &buffer[length + offset];
|
||||
}
|
||||
else if (kk < -maxDecimalPlaces) {
|
||||
else if(kk < -maxDecimalPlaces)
|
||||
{
|
||||
// Truncate to zero
|
||||
buffer[0] = '0';
|
||||
buffer[1] = '.';
|
||||
buffer[2] = '0';
|
||||
return &buffer[3];
|
||||
}
|
||||
else if (length == 1) {
|
||||
else if(length == 1)
|
||||
{
|
||||
// 1e30
|
||||
buffer[1] = 'e';
|
||||
return WriteExponent(kk - 1, &buffer[2]);
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
// 1234e30 -> 1.234e33
|
||||
std::memmove(&buffer[2], &buffer[1], static_cast<size_t>(length - 1));
|
||||
buffer[1] = '.';
|
||||
buffer[1] = '.';
|
||||
buffer[length + 1] = 'e';
|
||||
return WriteExponent(kk - 1, &buffer[0 + length + 2]);
|
||||
}
|
||||
}
|
||||
|
||||
inline char* dtoa(double value, char* buffer, int maxDecimalPlaces = 324) {
|
||||
inline char* dtoa(double value, char* buffer, int maxDecimalPlaces = 324)
|
||||
{
|
||||
RAPIDJSON_ASSERT(maxDecimalPlaces >= 1);
|
||||
Double d(value);
|
||||
if (d.IsZero()) {
|
||||
if (d.Sign())
|
||||
*buffer++ = '-'; // -0.0, Issue #289
|
||||
if(d.IsZero())
|
||||
{
|
||||
if(d.Sign())
|
||||
*buffer++ = '-'; // -0.0, Issue #289
|
||||
buffer[0] = '0';
|
||||
buffer[1] = '.';
|
||||
buffer[2] = '0';
|
||||
return &buffer[3];
|
||||
}
|
||||
else {
|
||||
if (value < 0) {
|
||||
else
|
||||
{
|
||||
if(value < 0)
|
||||
{
|
||||
*buffer++ = '-';
|
||||
value = -value;
|
||||
value = -value;
|
||||
}
|
||||
int length, K;
|
||||
Grisu2(value, buffer, &length, &K);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_IEEE754_
|
||||
@@ -20,8 +20,9 @@
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
namespace internal {
|
||||
|
||||
class Double {
|
||||
public:
|
||||
class Double
|
||||
{
|
||||
public:
|
||||
Double() {}
|
||||
Double(double d) : d_(d) {}
|
||||
Double(uint64_t u) : u_(u) {}
|
||||
@@ -29,14 +30,18 @@ public:
|
||||
double Value() const { return d_; }
|
||||
uint64_t Uint64Value() const { return u_; }
|
||||
|
||||
double NextPositiveDouble() const {
|
||||
double NextPositiveDouble() const
|
||||
{
|
||||
RAPIDJSON_ASSERT(!Sign());
|
||||
return Double(u_ + 1).Value();
|
||||
}
|
||||
|
||||
bool Sign() const { return (u_ & kSignMask) != 0; }
|
||||
uint64_t Significand() const { return u_ & kSignificandMask; }
|
||||
int Exponent() const { return static_cast<int>(((u_ & kExponentMask) >> kSignificandSize) - kExponentBias); }
|
||||
int Exponent() const
|
||||
{
|
||||
return static_cast<int>(((u_ & kExponentMask) >> kSignificandSize) - kExponentBias);
|
||||
}
|
||||
|
||||
bool IsNan() const { return (u_ & kExponentMask) == kExponentMask && Significand() != 0; }
|
||||
bool IsInf() const { return (u_ & kExponentMask) == kExponentMask && Significand() == 0; }
|
||||
@@ -44,29 +49,37 @@ public:
|
||||
bool IsNormal() const { return (u_ & kExponentMask) != 0 || Significand() == 0; }
|
||||
bool IsZero() const { return (u_ & (kExponentMask | kSignificandMask)) == 0; }
|
||||
|
||||
uint64_t IntegerSignificand() const { return IsNormal() ? Significand() | kHiddenBit : Significand(); }
|
||||
int IntegerExponent() const { return (IsNormal() ? Exponent() : kDenormalExponent) - kSignificandSize; }
|
||||
uint64_t IntegerSignificand() const
|
||||
{
|
||||
return IsNormal() ? Significand() | kHiddenBit : Significand();
|
||||
}
|
||||
int IntegerExponent() const
|
||||
{
|
||||
return (IsNormal() ? Exponent() : kDenormalExponent) - kSignificandSize;
|
||||
}
|
||||
uint64_t ToBias() const { return (u_ & kSignMask) ? ~u_ + 1 : u_ | kSignMask; }
|
||||
|
||||
static int EffectiveSignificandSize(int order) {
|
||||
if (order >= -1021)
|
||||
static int EffectiveSignificandSize(int order)
|
||||
{
|
||||
if(order >= -1021)
|
||||
return 53;
|
||||
else if (order <= -1074)
|
||||
else if(order <= -1074)
|
||||
return 0;
|
||||
else
|
||||
return order + 1074;
|
||||
}
|
||||
|
||||
private:
|
||||
static const int kSignificandSize = 52;
|
||||
static const int kExponentBias = 0x3FF;
|
||||
static const int kDenormalExponent = 1 - kExponentBias;
|
||||
static const uint64_t kSignMask = RAPIDJSON_UINT64_C2(0x80000000, 0x00000000);
|
||||
static const uint64_t kExponentMask = RAPIDJSON_UINT64_C2(0x7FF00000, 0x00000000);
|
||||
private:
|
||||
static const int kSignificandSize = 52;
|
||||
static const int kExponentBias = 0x3FF;
|
||||
static const int kDenormalExponent = 1 - kExponentBias;
|
||||
static const uint64_t kSignMask = RAPIDJSON_UINT64_C2(0x80000000, 0x00000000);
|
||||
static const uint64_t kExponentMask = RAPIDJSON_UINT64_C2(0x7FF00000, 0x00000000);
|
||||
static const uint64_t kSignificandMask = RAPIDJSON_UINT64_C2(0x000FFFFF, 0xFFFFFFFF);
|
||||
static const uint64_t kHiddenBit = RAPIDJSON_UINT64_C2(0x00100000, 0x00000000);
|
||||
static const uint64_t kHiddenBit = RAPIDJSON_UINT64_C2(0x00100000, 0x00000000);
|
||||
|
||||
union {
|
||||
union
|
||||
{
|
||||
double d_;
|
||||
uint64_t u_;
|
||||
};
|
||||
|
||||
@@ -20,40 +20,45 @@
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
namespace internal {
|
||||
|
||||
inline const char* GetDigitsLut() {
|
||||
inline const char* GetDigitsLut()
|
||||
{
|
||||
static const char cDigitsLut[200] = {
|
||||
'0','0','0','1','0','2','0','3','0','4','0','5','0','6','0','7','0','8','0','9',
|
||||
'1','0','1','1','1','2','1','3','1','4','1','5','1','6','1','7','1','8','1','9',
|
||||
'2','0','2','1','2','2','2','3','2','4','2','5','2','6','2','7','2','8','2','9',
|
||||
'3','0','3','1','3','2','3','3','3','4','3','5','3','6','3','7','3','8','3','9',
|
||||
'4','0','4','1','4','2','4','3','4','4','4','5','4','6','4','7','4','8','4','9',
|
||||
'5','0','5','1','5','2','5','3','5','4','5','5','5','6','5','7','5','8','5','9',
|
||||
'6','0','6','1','6','2','6','3','6','4','6','5','6','6','6','7','6','8','6','9',
|
||||
'7','0','7','1','7','2','7','3','7','4','7','5','7','6','7','7','7','8','7','9',
|
||||
'8','0','8','1','8','2','8','3','8','4','8','5','8','6','8','7','8','8','8','9',
|
||||
'9','0','9','1','9','2','9','3','9','4','9','5','9','6','9','7','9','8','9','9'
|
||||
};
|
||||
'0', '0', '0', '1', '0', '2', '0', '3', '0', '4', '0', '5', '0', '6', '0', '7', '0',
|
||||
'8', '0', '9', '1', '0', '1', '1', '1', '2', '1', '3', '1', '4', '1', '5', '1', '6',
|
||||
'1', '7', '1', '8', '1', '9', '2', '0', '2', '1', '2', '2', '2', '3', '2', '4', '2',
|
||||
'5', '2', '6', '2', '7', '2', '8', '2', '9', '3', '0', '3', '1', '3', '2', '3', '3',
|
||||
'3', '4', '3', '5', '3', '6', '3', '7', '3', '8', '3', '9', '4', '0', '4', '1', '4',
|
||||
'2', '4', '3', '4', '4', '4', '5', '4', '6', '4', '7', '4', '8', '4', '9', '5', '0',
|
||||
'5', '1', '5', '2', '5', '3', '5', '4', '5', '5', '5', '6', '5', '7', '5', '8', '5',
|
||||
'9', '6', '0', '6', '1', '6', '2', '6', '3', '6', '4', '6', '5', '6', '6', '6', '7',
|
||||
'6', '8', '6', '9', '7', '0', '7', '1', '7', '2', '7', '3', '7', '4', '7', '5', '7',
|
||||
'6', '7', '7', '7', '8', '7', '9', '8', '0', '8', '1', '8', '2', '8', '3', '8', '4',
|
||||
'8', '5', '8', '6', '8', '7', '8', '8', '8', '9', '9', '0', '9', '1', '9', '2', '9',
|
||||
'3', '9', '4', '9', '5', '9', '6', '9', '7', '9', '8', '9', '9'};
|
||||
return cDigitsLut;
|
||||
}
|
||||
|
||||
inline char* u32toa(uint32_t value, char* buffer) {
|
||||
inline char* u32toa(uint32_t value, char* buffer)
|
||||
{
|
||||
RAPIDJSON_ASSERT(buffer != 0);
|
||||
|
||||
const char* cDigitsLut = GetDigitsLut();
|
||||
|
||||
if (value < 10000) {
|
||||
if(value < 10000)
|
||||
{
|
||||
const uint32_t d1 = (value / 100) << 1;
|
||||
const uint32_t d2 = (value % 100) << 1;
|
||||
|
||||
if (value >= 1000)
|
||||
if(value >= 1000)
|
||||
*buffer++ = cDigitsLut[d1];
|
||||
if (value >= 100)
|
||||
if(value >= 100)
|
||||
*buffer++ = cDigitsLut[d1 + 1];
|
||||
if (value >= 10)
|
||||
if(value >= 10)
|
||||
*buffer++ = cDigitsLut[d2];
|
||||
*buffer++ = cDigitsLut[d2 + 1];
|
||||
}
|
||||
else if (value < 100000000) {
|
||||
else if(value < 100000000)
|
||||
{
|
||||
// value = bbbbcccc
|
||||
const uint32_t b = value / 10000;
|
||||
const uint32_t c = value % 10000;
|
||||
@@ -64,11 +69,11 @@ inline char* u32toa(uint32_t value, char* buffer) {
|
||||
const uint32_t d3 = (c / 100) << 1;
|
||||
const uint32_t d4 = (c % 100) << 1;
|
||||
|
||||
if (value >= 10000000)
|
||||
if(value >= 10000000)
|
||||
*buffer++ = cDigitsLut[d1];
|
||||
if (value >= 1000000)
|
||||
if(value >= 1000000)
|
||||
*buffer++ = cDigitsLut[d1 + 1];
|
||||
if (value >= 100000)
|
||||
if(value >= 100000)
|
||||
*buffer++ = cDigitsLut[d2];
|
||||
*buffer++ = cDigitsLut[d2 + 1];
|
||||
|
||||
@@ -77,16 +82,18 @@ inline char* u32toa(uint32_t value, char* buffer) {
|
||||
*buffer++ = cDigitsLut[d4];
|
||||
*buffer++ = cDigitsLut[d4 + 1];
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
// value = aabbbbcccc in decimal
|
||||
|
||||
const uint32_t a = value / 100000000; // 1 to 42
|
||||
value %= 100000000;
|
||||
|
||||
if (a >= 10) {
|
||||
if(a >= 10)
|
||||
{
|
||||
const unsigned i = a << 1;
|
||||
*buffer++ = cDigitsLut[i];
|
||||
*buffer++ = cDigitsLut[i + 1];
|
||||
*buffer++ = cDigitsLut[i];
|
||||
*buffer++ = cDigitsLut[i + 1];
|
||||
}
|
||||
else
|
||||
*buffer++ = static_cast<char>('0' + static_cast<char>(a));
|
||||
@@ -112,45 +119,51 @@ inline char* u32toa(uint32_t value, char* buffer) {
|
||||
return buffer;
|
||||
}
|
||||
|
||||
inline char* i32toa(int32_t value, char* buffer) {
|
||||
inline char* i32toa(int32_t value, char* buffer)
|
||||
{
|
||||
RAPIDJSON_ASSERT(buffer != 0);
|
||||
uint32_t u = static_cast<uint32_t>(value);
|
||||
if (value < 0) {
|
||||
if(value < 0)
|
||||
{
|
||||
*buffer++ = '-';
|
||||
u = ~u + 1;
|
||||
u = ~u + 1;
|
||||
}
|
||||
|
||||
return u32toa(u, buffer);
|
||||
}
|
||||
|
||||
inline char* u64toa(uint64_t value, char* buffer) {
|
||||
inline char* u64toa(uint64_t value, char* buffer)
|
||||
{
|
||||
RAPIDJSON_ASSERT(buffer != 0);
|
||||
const char* cDigitsLut = GetDigitsLut();
|
||||
const uint64_t kTen8 = 100000000;
|
||||
const uint64_t kTen9 = kTen8 * 10;
|
||||
const uint64_t kTen10 = kTen8 * 100;
|
||||
const uint64_t kTen11 = kTen8 * 1000;
|
||||
const uint64_t kTen12 = kTen8 * 10000;
|
||||
const uint64_t kTen13 = kTen8 * 100000;
|
||||
const uint64_t kTen14 = kTen8 * 1000000;
|
||||
const uint64_t kTen15 = kTen8 * 10000000;
|
||||
const uint64_t kTen16 = kTen8 * kTen8;
|
||||
const uint64_t kTen8 = 100000000;
|
||||
const uint64_t kTen9 = kTen8 * 10;
|
||||
const uint64_t kTen10 = kTen8 * 100;
|
||||
const uint64_t kTen11 = kTen8 * 1000;
|
||||
const uint64_t kTen12 = kTen8 * 10000;
|
||||
const uint64_t kTen13 = kTen8 * 100000;
|
||||
const uint64_t kTen14 = kTen8 * 1000000;
|
||||
const uint64_t kTen15 = kTen8 * 10000000;
|
||||
const uint64_t kTen16 = kTen8 * kTen8;
|
||||
|
||||
if (value < kTen8) {
|
||||
if(value < kTen8)
|
||||
{
|
||||
uint32_t v = static_cast<uint32_t>(value);
|
||||
if (v < 10000) {
|
||||
if(v < 10000)
|
||||
{
|
||||
const uint32_t d1 = (v / 100) << 1;
|
||||
const uint32_t d2 = (v % 100) << 1;
|
||||
|
||||
if (v >= 1000)
|
||||
if(v >= 1000)
|
||||
*buffer++ = cDigitsLut[d1];
|
||||
if (v >= 100)
|
||||
if(v >= 100)
|
||||
*buffer++ = cDigitsLut[d1 + 1];
|
||||
if (v >= 10)
|
||||
if(v >= 10)
|
||||
*buffer++ = cDigitsLut[d2];
|
||||
*buffer++ = cDigitsLut[d2 + 1];
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
// value = bbbbcccc
|
||||
const uint32_t b = v / 10000;
|
||||
const uint32_t c = v % 10000;
|
||||
@@ -161,11 +174,11 @@ inline char* u64toa(uint64_t value, char* buffer) {
|
||||
const uint32_t d3 = (c / 100) << 1;
|
||||
const uint32_t d4 = (c % 100) << 1;
|
||||
|
||||
if (value >= 10000000)
|
||||
if(value >= 10000000)
|
||||
*buffer++ = cDigitsLut[d1];
|
||||
if (value >= 1000000)
|
||||
if(value >= 1000000)
|
||||
*buffer++ = cDigitsLut[d1 + 1];
|
||||
if (value >= 100000)
|
||||
if(value >= 100000)
|
||||
*buffer++ = cDigitsLut[d2];
|
||||
*buffer++ = cDigitsLut[d2 + 1];
|
||||
|
||||
@@ -175,7 +188,8 @@ inline char* u64toa(uint64_t value, char* buffer) {
|
||||
*buffer++ = cDigitsLut[d4 + 1];
|
||||
}
|
||||
}
|
||||
else if (value < kTen16) {
|
||||
else if(value < kTen16)
|
||||
{
|
||||
const uint32_t v0 = static_cast<uint32_t>(value / kTen8);
|
||||
const uint32_t v1 = static_cast<uint32_t>(value % kTen8);
|
||||
|
||||
@@ -197,19 +211,19 @@ inline char* u64toa(uint64_t value, char* buffer) {
|
||||
const uint32_t d7 = (c1 / 100) << 1;
|
||||
const uint32_t d8 = (c1 % 100) << 1;
|
||||
|
||||
if (value >= kTen15)
|
||||
if(value >= kTen15)
|
||||
*buffer++ = cDigitsLut[d1];
|
||||
if (value >= kTen14)
|
||||
if(value >= kTen14)
|
||||
*buffer++ = cDigitsLut[d1 + 1];
|
||||
if (value >= kTen13)
|
||||
if(value >= kTen13)
|
||||
*buffer++ = cDigitsLut[d2];
|
||||
if (value >= kTen12)
|
||||
if(value >= kTen12)
|
||||
*buffer++ = cDigitsLut[d2 + 1];
|
||||
if (value >= kTen11)
|
||||
if(value >= kTen11)
|
||||
*buffer++ = cDigitsLut[d3];
|
||||
if (value >= kTen10)
|
||||
if(value >= kTen10)
|
||||
*buffer++ = cDigitsLut[d3 + 1];
|
||||
if (value >= kTen9)
|
||||
if(value >= kTen9)
|
||||
*buffer++ = cDigitsLut[d4];
|
||||
|
||||
*buffer++ = cDigitsLut[d4 + 1];
|
||||
@@ -222,31 +236,35 @@ inline char* u64toa(uint64_t value, char* buffer) {
|
||||
*buffer++ = cDigitsLut[d8];
|
||||
*buffer++ = cDigitsLut[d8 + 1];
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
const uint32_t a = static_cast<uint32_t>(value / kTen16); // 1 to 1844
|
||||
value %= kTen16;
|
||||
|
||||
if (a < 10)
|
||||
if(a < 10)
|
||||
*buffer++ = static_cast<char>('0' + static_cast<char>(a));
|
||||
else if (a < 100) {
|
||||
else if(a < 100)
|
||||
{
|
||||
const uint32_t i = a << 1;
|
||||
*buffer++ = cDigitsLut[i];
|
||||
*buffer++ = cDigitsLut[i + 1];
|
||||
*buffer++ = cDigitsLut[i];
|
||||
*buffer++ = cDigitsLut[i + 1];
|
||||
}
|
||||
else if (a < 1000) {
|
||||
else if(a < 1000)
|
||||
{
|
||||
*buffer++ = static_cast<char>('0' + static_cast<char>(a / 100));
|
||||
|
||||
const uint32_t i = (a % 100) << 1;
|
||||
*buffer++ = cDigitsLut[i];
|
||||
*buffer++ = cDigitsLut[i + 1];
|
||||
*buffer++ = cDigitsLut[i];
|
||||
*buffer++ = cDigitsLut[i + 1];
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
const uint32_t i = (a / 100) << 1;
|
||||
const uint32_t j = (a % 100) << 1;
|
||||
*buffer++ = cDigitsLut[i];
|
||||
*buffer++ = cDigitsLut[i + 1];
|
||||
*buffer++ = cDigitsLut[j];
|
||||
*buffer++ = cDigitsLut[j + 1];
|
||||
*buffer++ = cDigitsLut[i];
|
||||
*buffer++ = cDigitsLut[i + 1];
|
||||
*buffer++ = cDigitsLut[j];
|
||||
*buffer++ = cDigitsLut[j + 1];
|
||||
}
|
||||
|
||||
const uint32_t v0 = static_cast<uint32_t>(value / kTen8);
|
||||
@@ -291,12 +309,14 @@ inline char* u64toa(uint64_t value, char* buffer) {
|
||||
return buffer;
|
||||
}
|
||||
|
||||
inline char* i64toa(int64_t value, char* buffer) {
|
||||
inline char* i64toa(int64_t value, char* buffer)
|
||||
{
|
||||
RAPIDJSON_ASSERT(buffer != 0);
|
||||
uint64_t u = static_cast<uint64_t>(value);
|
||||
if (value < 0) {
|
||||
if(value < 0)
|
||||
{
|
||||
*buffer++ = '-';
|
||||
u = ~u + 1;
|
||||
u = ~u + 1;
|
||||
}
|
||||
|
||||
return u64toa(u, buffer);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_INTERNAL_META_H_
|
||||
@@ -36,140 +36,253 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
namespace internal {
|
||||
|
||||
// Helper to wrap/convert arbitrary types to void, useful for arbitrary type matching
|
||||
template <typename T> struct Void { typedef void Type; };
|
||||
template <typename T>
|
||||
struct Void
|
||||
{
|
||||
typedef void Type;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// BoolType, TrueType, FalseType
|
||||
//
|
||||
template <bool Cond> struct BoolType {
|
||||
template <bool Cond>
|
||||
struct BoolType
|
||||
{
|
||||
static const bool Value = Cond;
|
||||
typedef BoolType Type;
|
||||
};
|
||||
typedef BoolType<true> TrueType;
|
||||
typedef BoolType<false> FalseType;
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// SelectIf, BoolExpr, NotExpr, AndExpr, OrExpr
|
||||
//
|
||||
|
||||
template <bool C> struct SelectIfImpl { template <typename T1, typename T2> struct Apply { typedef T1 Type; }; };
|
||||
template <> struct SelectIfImpl<false> { template <typename T1, typename T2> struct Apply { typedef T2 Type; }; };
|
||||
template <bool C, typename T1, typename T2> struct SelectIfCond : SelectIfImpl<C>::template Apply<T1,T2> {};
|
||||
template <typename C, typename T1, typename T2> struct SelectIf : SelectIfCond<C::Value, T1, T2> {};
|
||||
template <bool C>
|
||||
struct SelectIfImpl
|
||||
{
|
||||
template <typename T1, typename T2>
|
||||
struct Apply
|
||||
{
|
||||
typedef T1 Type;
|
||||
};
|
||||
};
|
||||
template <>
|
||||
struct SelectIfImpl<false>
|
||||
{
|
||||
template <typename T1, typename T2>
|
||||
struct Apply
|
||||
{
|
||||
typedef T2 Type;
|
||||
};
|
||||
};
|
||||
template <bool C, typename T1, typename T2>
|
||||
struct SelectIfCond : SelectIfImpl<C>::template Apply<T1, T2>
|
||||
{
|
||||
};
|
||||
template <typename C, typename T1, typename T2>
|
||||
struct SelectIf : SelectIfCond<C::Value, T1, T2>
|
||||
{
|
||||
};
|
||||
|
||||
template <bool Cond1, bool Cond2> struct AndExprCond : FalseType {};
|
||||
template <> struct AndExprCond<true, true> : TrueType {};
|
||||
template <bool Cond1, bool Cond2> struct OrExprCond : TrueType {};
|
||||
template <> struct OrExprCond<false, false> : FalseType {};
|
||||
|
||||
template <typename C> struct BoolExpr : SelectIf<C,TrueType,FalseType>::Type {};
|
||||
template <typename C> struct NotExpr : SelectIf<C,FalseType,TrueType>::Type {};
|
||||
template <typename C1, typename C2> struct AndExpr : AndExprCond<C1::Value, C2::Value>::Type {};
|
||||
template <typename C1, typename C2> struct OrExpr : OrExprCond<C1::Value, C2::Value>::Type {};
|
||||
template <bool Cond1, bool Cond2>
|
||||
struct AndExprCond : FalseType
|
||||
{
|
||||
};
|
||||
template <>
|
||||
struct AndExprCond<true, true> : TrueType
|
||||
{
|
||||
};
|
||||
template <bool Cond1, bool Cond2>
|
||||
struct OrExprCond : TrueType
|
||||
{
|
||||
};
|
||||
template <>
|
||||
struct OrExprCond<false, false> : FalseType
|
||||
{
|
||||
};
|
||||
|
||||
template <typename C>
|
||||
struct BoolExpr : SelectIf<C, TrueType, FalseType>::Type
|
||||
{
|
||||
};
|
||||
template <typename C>
|
||||
struct NotExpr : SelectIf<C, FalseType, TrueType>::Type
|
||||
{
|
||||
};
|
||||
template <typename C1, typename C2>
|
||||
struct AndExpr : AndExprCond<C1::Value, C2::Value>::Type
|
||||
{
|
||||
};
|
||||
template <typename C1, typename C2>
|
||||
struct OrExpr : OrExprCond<C1::Value, C2::Value>::Type
|
||||
{
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// AddConst, MaybeAddConst, RemoveConst
|
||||
template <typename T> struct AddConst { typedef const T Type; };
|
||||
template <bool Constify, typename T> struct MaybeAddConst : SelectIfCond<Constify, const T, T> {};
|
||||
template <typename T> struct RemoveConst { typedef T Type; };
|
||||
template <typename T> struct RemoveConst<const T> { typedef T Type; };
|
||||
|
||||
template <typename T>
|
||||
struct AddConst
|
||||
{
|
||||
typedef const T Type;
|
||||
};
|
||||
template <bool Constify, typename T>
|
||||
struct MaybeAddConst : SelectIfCond<Constify, const T, T>
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
struct RemoveConst
|
||||
{
|
||||
typedef T Type;
|
||||
};
|
||||
template <typename T>
|
||||
struct RemoveConst<const T>
|
||||
{
|
||||
typedef T Type;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// IsSame, IsConst, IsMoreConst, IsPointer
|
||||
//
|
||||
template <typename T, typename U> struct IsSame : FalseType {};
|
||||
template <typename T> struct IsSame<T, T> : TrueType {};
|
||||
template <typename T, typename U>
|
||||
struct IsSame : FalseType
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
struct IsSame<T, T> : TrueType
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T> struct IsConst : FalseType {};
|
||||
template <typename T> struct IsConst<const T> : TrueType {};
|
||||
template <typename T>
|
||||
struct IsConst : FalseType
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
struct IsConst<const T> : TrueType
|
||||
{
|
||||
};
|
||||
|
||||
template <typename CT, typename T>
|
||||
struct IsMoreConst
|
||||
: AndExpr<IsSame<typename RemoveConst<CT>::Type, typename RemoveConst<T>::Type>,
|
||||
BoolType<IsConst<CT>::Value >= IsConst<T>::Value> >::Type {};
|
||||
struct IsMoreConst : AndExpr<IsSame<typename RemoveConst<CT>::Type, typename RemoveConst<T>::Type>,
|
||||
BoolType<IsConst<CT>::Value >= IsConst<T>::Value>>::Type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T> struct IsPointer : FalseType {};
|
||||
template <typename T> struct IsPointer<T*> : TrueType {};
|
||||
template <typename T>
|
||||
struct IsPointer : FalseType
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
struct IsPointer<T*> : TrueType
|
||||
{
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// IsBaseOf
|
||||
//
|
||||
#if RAPIDJSON_HAS_CXX11_TYPETRAITS
|
||||
|
||||
template <typename B, typename D> struct IsBaseOf
|
||||
: BoolType< ::std::is_base_of<B,D>::value> {};
|
||||
template <typename B, typename D>
|
||||
struct IsBaseOf : BoolType<::std::is_base_of<B, D>::value>
|
||||
{
|
||||
};
|
||||
|
||||
#else // simplified version adopted from Boost
|
||||
|
||||
template<typename B, typename D> struct IsBaseOfImpl {
|
||||
template <typename B, typename D>
|
||||
struct IsBaseOfImpl
|
||||
{
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(B) != 0);
|
||||
RAPIDJSON_STATIC_ASSERT(sizeof(D) != 0);
|
||||
|
||||
typedef char (&Yes)[1];
|
||||
typedef char (&No) [2];
|
||||
typedef char (&No)[2];
|
||||
|
||||
template <typename T>
|
||||
static Yes Check(const D*, T);
|
||||
static No Check(const B*, int);
|
||||
static No Check(const B*, int);
|
||||
|
||||
struct Host {
|
||||
struct Host
|
||||
{
|
||||
operator const B*() const;
|
||||
operator const D*();
|
||||
};
|
||||
|
||||
enum { Value = (sizeof(Check(Host(), 0)) == sizeof(Yes)) };
|
||||
enum
|
||||
{
|
||||
Value = (sizeof(Check(Host(), 0)) == sizeof(Yes))
|
||||
};
|
||||
};
|
||||
|
||||
template <typename B, typename D> struct IsBaseOf
|
||||
: OrExpr<IsSame<B, D>, BoolExpr<IsBaseOfImpl<B, D> > >::Type {};
|
||||
template <typename B, typename D>
|
||||
struct IsBaseOf : OrExpr<IsSame<B, D>, BoolExpr<IsBaseOfImpl<B, D>>>::Type
|
||||
{
|
||||
};
|
||||
|
||||
#endif // RAPIDJSON_HAS_CXX11_TYPETRAITS
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// EnableIf / DisableIf
|
||||
//
|
||||
template <bool Condition, typename T = void> struct EnableIfCond { typedef T Type; };
|
||||
template <typename T> struct EnableIfCond<false, T> { /* empty */ };
|
||||
template <bool Condition, typename T = void>
|
||||
struct EnableIfCond
|
||||
{
|
||||
typedef T Type;
|
||||
};
|
||||
template <typename T>
|
||||
struct EnableIfCond<false, T>
|
||||
{ /* empty */
|
||||
};
|
||||
|
||||
template <bool Condition, typename T = void> struct DisableIfCond { typedef T Type; };
|
||||
template <typename T> struct DisableIfCond<true, T> { /* empty */ };
|
||||
template <bool Condition, typename T = void>
|
||||
struct DisableIfCond
|
||||
{
|
||||
typedef T Type;
|
||||
};
|
||||
template <typename T>
|
||||
struct DisableIfCond<true, T>
|
||||
{ /* empty */
|
||||
};
|
||||
|
||||
template <typename Condition, typename T = void>
|
||||
struct EnableIf : EnableIfCond<Condition::Value, T> {};
|
||||
struct EnableIf : EnableIfCond<Condition::Value, T>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Condition, typename T = void>
|
||||
struct DisableIf : DisableIfCond<Condition::Value, T> {};
|
||||
struct DisableIf : DisableIfCond<Condition::Value, T>
|
||||
{
|
||||
};
|
||||
|
||||
// SFINAE helpers
|
||||
struct SfinaeTag {};
|
||||
template <typename T> struct RemoveSfinaeTag;
|
||||
template <typename T> struct RemoveSfinaeTag<SfinaeTag&(*)(T)> { typedef T Type; };
|
||||
struct SfinaeTag
|
||||
{
|
||||
};
|
||||
template <typename T>
|
||||
struct RemoveSfinaeTag;
|
||||
template <typename T>
|
||||
struct RemoveSfinaeTag<SfinaeTag& (*)(T)>
|
||||
{
|
||||
typedef T Type;
|
||||
};
|
||||
|
||||
#define RAPIDJSON_REMOVEFPTR_(type) \
|
||||
typename ::RAPIDJSON_NAMESPACE::internal::RemoveSfinaeTag \
|
||||
< ::RAPIDJSON_NAMESPACE::internal::SfinaeTag&(*) type>::Type
|
||||
#define RAPIDJSON_REMOVEFPTR_(type) \
|
||||
typename ::RAPIDJSON_NAMESPACE::internal::RemoveSfinaeTag< \
|
||||
::RAPIDJSON_NAMESPACE::internal::SfinaeTag&(*)type>::Type
|
||||
|
||||
#define RAPIDJSON_ENABLEIF(cond) \
|
||||
typename ::RAPIDJSON_NAMESPACE::internal::EnableIf \
|
||||
<RAPIDJSON_REMOVEFPTR_(cond)>::Type * = NULL
|
||||
typename ::RAPIDJSON_NAMESPACE::internal::EnableIf<RAPIDJSON_REMOVEFPTR_(cond)>::Type* = NULL
|
||||
|
||||
#define RAPIDJSON_DISABLEIF(cond) \
|
||||
typename ::RAPIDJSON_NAMESPACE::internal::DisableIf \
|
||||
<RAPIDJSON_REMOVEFPTR_(cond)>::Type * = NULL
|
||||
typename ::RAPIDJSON_NAMESPACE::internal::DisableIf<RAPIDJSON_REMOVEFPTR_(cond)>::Type* = NULL
|
||||
|
||||
#define RAPIDJSON_ENABLEIF_RETURN(cond,returntype) \
|
||||
typename ::RAPIDJSON_NAMESPACE::internal::EnableIf \
|
||||
<RAPIDJSON_REMOVEFPTR_(cond), \
|
||||
RAPIDJSON_REMOVEFPTR_(returntype)>::Type
|
||||
#define RAPIDJSON_ENABLEIF_RETURN(cond, returntype) \
|
||||
typename ::RAPIDJSON_NAMESPACE::internal::EnableIf<RAPIDJSON_REMOVEFPTR_(cond), \
|
||||
RAPIDJSON_REMOVEFPTR_(returntype)>::Type
|
||||
|
||||
#define RAPIDJSON_DISABLEIF_RETURN(cond,returntype) \
|
||||
typename ::RAPIDJSON_NAMESPACE::internal::DisableIf \
|
||||
<RAPIDJSON_REMOVEFPTR_(cond), \
|
||||
RAPIDJSON_REMOVEFPTR_(returntype)>::Type
|
||||
#define RAPIDJSON_DISABLEIF_RETURN(cond, returntype) \
|
||||
typename ::RAPIDJSON_NAMESPACE::internal::DisableIf<RAPIDJSON_REMOVEFPTR_(cond), \
|
||||
RAPIDJSON_REMOVEFPTR_(returntype)>::Type
|
||||
|
||||
} // namespace internal
|
||||
RAPIDJSON_NAMESPACE_END
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_POW10_
|
||||
@@ -25,26 +25,39 @@ namespace internal {
|
||||
\param n non-negative exponent. Must <= 308.
|
||||
\return 10.0^n
|
||||
*/
|
||||
inline double Pow10(int n) {
|
||||
static const double e[] = { // 1e-0...1e308: 309 * 8 bytes = 2472 bytes
|
||||
1e+0,
|
||||
1e+1, 1e+2, 1e+3, 1e+4, 1e+5, 1e+6, 1e+7, 1e+8, 1e+9, 1e+10, 1e+11, 1e+12, 1e+13, 1e+14, 1e+15, 1e+16, 1e+17, 1e+18, 1e+19, 1e+20,
|
||||
1e+21, 1e+22, 1e+23, 1e+24, 1e+25, 1e+26, 1e+27, 1e+28, 1e+29, 1e+30, 1e+31, 1e+32, 1e+33, 1e+34, 1e+35, 1e+36, 1e+37, 1e+38, 1e+39, 1e+40,
|
||||
1e+41, 1e+42, 1e+43, 1e+44, 1e+45, 1e+46, 1e+47, 1e+48, 1e+49, 1e+50, 1e+51, 1e+52, 1e+53, 1e+54, 1e+55, 1e+56, 1e+57, 1e+58, 1e+59, 1e+60,
|
||||
1e+61, 1e+62, 1e+63, 1e+64, 1e+65, 1e+66, 1e+67, 1e+68, 1e+69, 1e+70, 1e+71, 1e+72, 1e+73, 1e+74, 1e+75, 1e+76, 1e+77, 1e+78, 1e+79, 1e+80,
|
||||
1e+81, 1e+82, 1e+83, 1e+84, 1e+85, 1e+86, 1e+87, 1e+88, 1e+89, 1e+90, 1e+91, 1e+92, 1e+93, 1e+94, 1e+95, 1e+96, 1e+97, 1e+98, 1e+99, 1e+100,
|
||||
1e+101,1e+102,1e+103,1e+104,1e+105,1e+106,1e+107,1e+108,1e+109,1e+110,1e+111,1e+112,1e+113,1e+114,1e+115,1e+116,1e+117,1e+118,1e+119,1e+120,
|
||||
1e+121,1e+122,1e+123,1e+124,1e+125,1e+126,1e+127,1e+128,1e+129,1e+130,1e+131,1e+132,1e+133,1e+134,1e+135,1e+136,1e+137,1e+138,1e+139,1e+140,
|
||||
1e+141,1e+142,1e+143,1e+144,1e+145,1e+146,1e+147,1e+148,1e+149,1e+150,1e+151,1e+152,1e+153,1e+154,1e+155,1e+156,1e+157,1e+158,1e+159,1e+160,
|
||||
1e+161,1e+162,1e+163,1e+164,1e+165,1e+166,1e+167,1e+168,1e+169,1e+170,1e+171,1e+172,1e+173,1e+174,1e+175,1e+176,1e+177,1e+178,1e+179,1e+180,
|
||||
1e+181,1e+182,1e+183,1e+184,1e+185,1e+186,1e+187,1e+188,1e+189,1e+190,1e+191,1e+192,1e+193,1e+194,1e+195,1e+196,1e+197,1e+198,1e+199,1e+200,
|
||||
1e+201,1e+202,1e+203,1e+204,1e+205,1e+206,1e+207,1e+208,1e+209,1e+210,1e+211,1e+212,1e+213,1e+214,1e+215,1e+216,1e+217,1e+218,1e+219,1e+220,
|
||||
1e+221,1e+222,1e+223,1e+224,1e+225,1e+226,1e+227,1e+228,1e+229,1e+230,1e+231,1e+232,1e+233,1e+234,1e+235,1e+236,1e+237,1e+238,1e+239,1e+240,
|
||||
1e+241,1e+242,1e+243,1e+244,1e+245,1e+246,1e+247,1e+248,1e+249,1e+250,1e+251,1e+252,1e+253,1e+254,1e+255,1e+256,1e+257,1e+258,1e+259,1e+260,
|
||||
1e+261,1e+262,1e+263,1e+264,1e+265,1e+266,1e+267,1e+268,1e+269,1e+270,1e+271,1e+272,1e+273,1e+274,1e+275,1e+276,1e+277,1e+278,1e+279,1e+280,
|
||||
1e+281,1e+282,1e+283,1e+284,1e+285,1e+286,1e+287,1e+288,1e+289,1e+290,1e+291,1e+292,1e+293,1e+294,1e+295,1e+296,1e+297,1e+298,1e+299,1e+300,
|
||||
1e+301,1e+302,1e+303,1e+304,1e+305,1e+306,1e+307,1e+308
|
||||
};
|
||||
inline double Pow10(int n)
|
||||
{
|
||||
static const double e[] = {
|
||||
// 1e-0...1e308: 309 * 8 bytes = 2472 bytes
|
||||
1e+0, 1e+1, 1e+2, 1e+3, 1e+4, 1e+5, 1e+6, 1e+7, 1e+8, 1e+9, 1e+10,
|
||||
1e+11, 1e+12, 1e+13, 1e+14, 1e+15, 1e+16, 1e+17, 1e+18, 1e+19, 1e+20, 1e+21,
|
||||
1e+22, 1e+23, 1e+24, 1e+25, 1e+26, 1e+27, 1e+28, 1e+29, 1e+30, 1e+31, 1e+32,
|
||||
1e+33, 1e+34, 1e+35, 1e+36, 1e+37, 1e+38, 1e+39, 1e+40, 1e+41, 1e+42, 1e+43,
|
||||
1e+44, 1e+45, 1e+46, 1e+47, 1e+48, 1e+49, 1e+50, 1e+51, 1e+52, 1e+53, 1e+54,
|
||||
1e+55, 1e+56, 1e+57, 1e+58, 1e+59, 1e+60, 1e+61, 1e+62, 1e+63, 1e+64, 1e+65,
|
||||
1e+66, 1e+67, 1e+68, 1e+69, 1e+70, 1e+71, 1e+72, 1e+73, 1e+74, 1e+75, 1e+76,
|
||||
1e+77, 1e+78, 1e+79, 1e+80, 1e+81, 1e+82, 1e+83, 1e+84, 1e+85, 1e+86, 1e+87,
|
||||
1e+88, 1e+89, 1e+90, 1e+91, 1e+92, 1e+93, 1e+94, 1e+95, 1e+96, 1e+97, 1e+98,
|
||||
1e+99, 1e+100, 1e+101, 1e+102, 1e+103, 1e+104, 1e+105, 1e+106, 1e+107, 1e+108, 1e+109,
|
||||
1e+110, 1e+111, 1e+112, 1e+113, 1e+114, 1e+115, 1e+116, 1e+117, 1e+118, 1e+119, 1e+120,
|
||||
1e+121, 1e+122, 1e+123, 1e+124, 1e+125, 1e+126, 1e+127, 1e+128, 1e+129, 1e+130, 1e+131,
|
||||
1e+132, 1e+133, 1e+134, 1e+135, 1e+136, 1e+137, 1e+138, 1e+139, 1e+140, 1e+141, 1e+142,
|
||||
1e+143, 1e+144, 1e+145, 1e+146, 1e+147, 1e+148, 1e+149, 1e+150, 1e+151, 1e+152, 1e+153,
|
||||
1e+154, 1e+155, 1e+156, 1e+157, 1e+158, 1e+159, 1e+160, 1e+161, 1e+162, 1e+163, 1e+164,
|
||||
1e+165, 1e+166, 1e+167, 1e+168, 1e+169, 1e+170, 1e+171, 1e+172, 1e+173, 1e+174, 1e+175,
|
||||
1e+176, 1e+177, 1e+178, 1e+179, 1e+180, 1e+181, 1e+182, 1e+183, 1e+184, 1e+185, 1e+186,
|
||||
1e+187, 1e+188, 1e+189, 1e+190, 1e+191, 1e+192, 1e+193, 1e+194, 1e+195, 1e+196, 1e+197,
|
||||
1e+198, 1e+199, 1e+200, 1e+201, 1e+202, 1e+203, 1e+204, 1e+205, 1e+206, 1e+207, 1e+208,
|
||||
1e+209, 1e+210, 1e+211, 1e+212, 1e+213, 1e+214, 1e+215, 1e+216, 1e+217, 1e+218, 1e+219,
|
||||
1e+220, 1e+221, 1e+222, 1e+223, 1e+224, 1e+225, 1e+226, 1e+227, 1e+228, 1e+229, 1e+230,
|
||||
1e+231, 1e+232, 1e+233, 1e+234, 1e+235, 1e+236, 1e+237, 1e+238, 1e+239, 1e+240, 1e+241,
|
||||
1e+242, 1e+243, 1e+244, 1e+245, 1e+246, 1e+247, 1e+248, 1e+249, 1e+250, 1e+251, 1e+252,
|
||||
1e+253, 1e+254, 1e+255, 1e+256, 1e+257, 1e+258, 1e+259, 1e+260, 1e+261, 1e+262, 1e+263,
|
||||
1e+264, 1e+265, 1e+266, 1e+267, 1e+268, 1e+269, 1e+270, 1e+271, 1e+272, 1e+273, 1e+274,
|
||||
1e+275, 1e+276, 1e+277, 1e+278, 1e+279, 1e+280, 1e+281, 1e+282, 1e+283, 1e+284, 1e+285,
|
||||
1e+286, 1e+287, 1e+288, 1e+289, 1e+290, 1e+291, 1e+292, 1e+293, 1e+294, 1e+295, 1e+296,
|
||||
1e+297, 1e+298, 1e+299, 1e+300, 1e+301, 1e+302, 1e+303, 1e+304, 1e+305, 1e+306, 1e+307,
|
||||
1e+308};
|
||||
RAPIDJSON_ASSERT(n >= 0 && n <= 308);
|
||||
return e[n];
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_INTERNAL_STACK_H_
|
||||
@@ -21,7 +21,7 @@
|
||||
|
||||
#if defined(__clang__)
|
||||
RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(c++98-compat)
|
||||
RAPIDJSON_DIAG_OFF(c++ 98 - compat)
|
||||
#endif
|
||||
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
@@ -32,13 +32,21 @@ namespace internal {
|
||||
|
||||
//! A type-unsafe stack for storing different types of data.
|
||||
/*! \tparam Allocator Allocator for allocating stack memory.
|
||||
*/
|
||||
*/
|
||||
template <typename Allocator>
|
||||
class Stack {
|
||||
public:
|
||||
class Stack
|
||||
{
|
||||
public:
|
||||
// Optimization note: Do not allocate memory for stack_ in constructor.
|
||||
// Do it lazily when first Push() -> Expand() -> Resize().
|
||||
Stack(Allocator* allocator, size_t stackCapacity) : allocator_(allocator), ownAllocator_(0), stack_(0), stackTop_(0), stackEnd_(0), initialCapacity_(stackCapacity) {
|
||||
Stack(Allocator* allocator, size_t stackCapacity)
|
||||
: allocator_(allocator),
|
||||
ownAllocator_(0),
|
||||
stack_(0),
|
||||
stackTop_(0),
|
||||
stackEnd_(0),
|
||||
initialCapacity_(stackCapacity)
|
||||
{
|
||||
}
|
||||
|
||||
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
|
||||
@@ -50,44 +58,44 @@ public:
|
||||
stackEnd_(rhs.stackEnd_),
|
||||
initialCapacity_(rhs.initialCapacity_)
|
||||
{
|
||||
rhs.allocator_ = 0;
|
||||
rhs.ownAllocator_ = 0;
|
||||
rhs.stack_ = 0;
|
||||
rhs.stackTop_ = 0;
|
||||
rhs.stackEnd_ = 0;
|
||||
rhs.allocator_ = 0;
|
||||
rhs.ownAllocator_ = 0;
|
||||
rhs.stack_ = 0;
|
||||
rhs.stackTop_ = 0;
|
||||
rhs.stackEnd_ = 0;
|
||||
rhs.initialCapacity_ = 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
~Stack() {
|
||||
Destroy();
|
||||
}
|
||||
~Stack() { Destroy(); }
|
||||
|
||||
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
|
||||
Stack& operator=(Stack&& rhs) {
|
||||
if (&rhs != this)
|
||||
Stack& operator=(Stack&& rhs)
|
||||
{
|
||||
if(&rhs != this)
|
||||
{
|
||||
Destroy();
|
||||
|
||||
allocator_ = rhs.allocator_;
|
||||
ownAllocator_ = rhs.ownAllocator_;
|
||||
stack_ = rhs.stack_;
|
||||
stackTop_ = rhs.stackTop_;
|
||||
stackEnd_ = rhs.stackEnd_;
|
||||
allocator_ = rhs.allocator_;
|
||||
ownAllocator_ = rhs.ownAllocator_;
|
||||
stack_ = rhs.stack_;
|
||||
stackTop_ = rhs.stackTop_;
|
||||
stackEnd_ = rhs.stackEnd_;
|
||||
initialCapacity_ = rhs.initialCapacity_;
|
||||
|
||||
rhs.allocator_ = 0;
|
||||
rhs.ownAllocator_ = 0;
|
||||
rhs.stack_ = 0;
|
||||
rhs.stackTop_ = 0;
|
||||
rhs.stackEnd_ = 0;
|
||||
rhs.allocator_ = 0;
|
||||
rhs.ownAllocator_ = 0;
|
||||
rhs.stack_ = 0;
|
||||
rhs.stackTop_ = 0;
|
||||
rhs.stackEnd_ = 0;
|
||||
rhs.initialCapacity_ = 0;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
#endif
|
||||
|
||||
void Swap(Stack& rhs) RAPIDJSON_NOEXCEPT {
|
||||
void Swap(Stack& rhs) RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
internal::Swap(allocator_, rhs.allocator_);
|
||||
internal::Swap(ownAllocator_, rhs.ownAllocator_);
|
||||
internal::Swap(stack_, rhs.stack_);
|
||||
@@ -98,11 +106,13 @@ public:
|
||||
|
||||
void Clear() { stackTop_ = stack_; }
|
||||
|
||||
void ShrinkToFit() {
|
||||
if (Empty()) {
|
||||
void ShrinkToFit()
|
||||
{
|
||||
if(Empty())
|
||||
{
|
||||
// If the stack is empty, completely deallocate the memory.
|
||||
Allocator::Free(stack_); // NOLINT (+clang-analyzer-unix.Malloc)
|
||||
stack_ = 0;
|
||||
stack_ = 0;
|
||||
stackTop_ = 0;
|
||||
stackEnd_ = 0;
|
||||
}
|
||||
@@ -112,21 +122,25 @@ public:
|
||||
|
||||
// Optimization note: try to minimize the size of this function for force inline.
|
||||
// Expansion is run very infrequently, so it is moved to another (probably non-inline) function.
|
||||
template<typename T>
|
||||
RAPIDJSON_FORCEINLINE void Reserve(size_t count = 1) {
|
||||
// Expand the stack if needed
|
||||
if (RAPIDJSON_UNLIKELY(static_cast<std::ptrdiff_t>(sizeof(T) * count) > (stackEnd_ - stackTop_)))
|
||||
template <typename T>
|
||||
RAPIDJSON_FORCEINLINE void Reserve(size_t count = 1)
|
||||
{
|
||||
// Expand the stack if needed
|
||||
if(RAPIDJSON_UNLIKELY(static_cast<std::ptrdiff_t>(sizeof(T) * count) >
|
||||
(stackEnd_ - stackTop_)))
|
||||
Expand<T>(count);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
RAPIDJSON_FORCEINLINE T* Push(size_t count = 1) {
|
||||
template <typename T>
|
||||
RAPIDJSON_FORCEINLINE T* Push(size_t count = 1)
|
||||
{
|
||||
Reserve<T>(count);
|
||||
return PushUnsafe<T>(count);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
RAPIDJSON_FORCEINLINE T* PushUnsafe(size_t count = 1) {
|
||||
template <typename T>
|
||||
RAPIDJSON_FORCEINLINE T* PushUnsafe(size_t count = 1)
|
||||
{
|
||||
RAPIDJSON_ASSERT(stackTop_);
|
||||
RAPIDJSON_ASSERT(static_cast<std::ptrdiff_t>(sizeof(T) * count) <= (stackEnd_ - stackTop_));
|
||||
T* ret = reinterpret_cast<T*>(stackTop_);
|
||||
@@ -134,42 +148,56 @@ public:
|
||||
return ret;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T* Pop(size_t count) {
|
||||
template <typename T>
|
||||
T* Pop(size_t count)
|
||||
{
|
||||
RAPIDJSON_ASSERT(GetSize() >= count * sizeof(T));
|
||||
stackTop_ -= count * sizeof(T);
|
||||
return reinterpret_cast<T*>(stackTop_);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T* Top() {
|
||||
template <typename T>
|
||||
T* Top()
|
||||
{
|
||||
RAPIDJSON_ASSERT(GetSize() >= sizeof(T));
|
||||
return reinterpret_cast<T*>(stackTop_ - sizeof(T));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
const T* Top() const {
|
||||
template <typename T>
|
||||
const T* Top() const
|
||||
{
|
||||
RAPIDJSON_ASSERT(GetSize() >= sizeof(T));
|
||||
return reinterpret_cast<T*>(stackTop_ - sizeof(T));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T* End() { return reinterpret_cast<T*>(stackTop_); }
|
||||
|
||||
template<typename T>
|
||||
const T* End() const { return reinterpret_cast<T*>(stackTop_); }
|
||||
|
||||
template<typename T>
|
||||
T* Bottom() { return reinterpret_cast<T*>(stack_); }
|
||||
|
||||
template<typename T>
|
||||
const T* Bottom() const { return reinterpret_cast<T*>(stack_); }
|
||||
|
||||
bool HasAllocator() const {
|
||||
return allocator_ != 0;
|
||||
template <typename T>
|
||||
T* End()
|
||||
{
|
||||
return reinterpret_cast<T*>(stackTop_);
|
||||
}
|
||||
|
||||
Allocator& GetAllocator() {
|
||||
template <typename T>
|
||||
const T* End() const
|
||||
{
|
||||
return reinterpret_cast<T*>(stackTop_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* Bottom()
|
||||
{
|
||||
return reinterpret_cast<T*>(stack_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* Bottom() const
|
||||
{
|
||||
return reinterpret_cast<T*>(stack_);
|
||||
}
|
||||
|
||||
bool HasAllocator() const { return allocator_ != 0; }
|
||||
|
||||
Allocator& GetAllocator()
|
||||
{
|
||||
RAPIDJSON_ASSERT(allocator_);
|
||||
return *allocator_;
|
||||
}
|
||||
@@ -178,34 +206,41 @@ public:
|
||||
size_t GetSize() const { return static_cast<size_t>(stackTop_ - stack_); }
|
||||
size_t GetCapacity() const { return static_cast<size_t>(stackEnd_ - stack_); }
|
||||
|
||||
private:
|
||||
template<typename T>
|
||||
void Expand(size_t count) {
|
||||
// Only expand the capacity if the current stack exists. Otherwise just create a stack with initial capacity.
|
||||
private:
|
||||
template <typename T>
|
||||
void Expand(size_t count)
|
||||
{
|
||||
// Only expand the capacity if the current stack exists. Otherwise just create a stack with
|
||||
// initial capacity.
|
||||
size_t newCapacity;
|
||||
if (stack_ == 0) {
|
||||
if (!allocator_)
|
||||
if(stack_ == 0)
|
||||
{
|
||||
if(!allocator_)
|
||||
ownAllocator_ = allocator_ = RAPIDJSON_NEW(Allocator)();
|
||||
newCapacity = initialCapacity_;
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
newCapacity = GetCapacity();
|
||||
newCapacity += (newCapacity + 1) / 2;
|
||||
}
|
||||
size_t newSize = GetSize() + sizeof(T) * count;
|
||||
if (newCapacity < newSize)
|
||||
if(newCapacity < newSize)
|
||||
newCapacity = newSize;
|
||||
|
||||
Resize(newCapacity);
|
||||
}
|
||||
|
||||
void Resize(size_t newCapacity) {
|
||||
const size_t size = GetSize(); // Backup the current size
|
||||
stack_ = static_cast<char*>(allocator_->Realloc(stack_, GetCapacity(), newCapacity));
|
||||
void Resize(size_t newCapacity)
|
||||
{
|
||||
const size_t size = GetSize(); // Backup the current size
|
||||
stack_ = static_cast<char*>(allocator_->Realloc(stack_, GetCapacity(), newCapacity));
|
||||
stackTop_ = stack_ + size;
|
||||
stackEnd_ = stack_ + newCapacity;
|
||||
}
|
||||
|
||||
void Destroy() {
|
||||
void Destroy()
|
||||
{
|
||||
Allocator::Free(stack_);
|
||||
RAPIDJSON_DELETE(ownAllocator_); // Only delete if it is owned by the stack
|
||||
}
|
||||
@@ -216,9 +251,9 @@ private:
|
||||
|
||||
Allocator* allocator_;
|
||||
Allocator* ownAllocator_;
|
||||
char *stack_;
|
||||
char *stackTop_;
|
||||
char *stackEnd_;
|
||||
char* stack_;
|
||||
char* stackTop_;
|
||||
char* stackEnd_;
|
||||
size_t initialCapacity_;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_INTERNAL_STRFUNC_H_
|
||||
@@ -24,24 +24,29 @@ namespace internal {
|
||||
//! Custom strlen() which works on different character types.
|
||||
/*! \tparam Ch Character type (e.g. char, wchar_t, short)
|
||||
\param s Null-terminated input string.
|
||||
\return Number of characters in the string.
|
||||
\note This has the same semantics as strlen(), the return value is not number of Unicode codepoints.
|
||||
\return Number of characters in the string.
|
||||
\note This has the same semantics as strlen(), the return value is not number of Unicode
|
||||
codepoints.
|
||||
*/
|
||||
template <typename Ch>
|
||||
inline SizeType StrLen(const Ch* s) {
|
||||
inline SizeType StrLen(const Ch* s)
|
||||
{
|
||||
RAPIDJSON_ASSERT(s != 0);
|
||||
const Ch* p = s;
|
||||
while (*p) ++p;
|
||||
while(*p)
|
||||
++p;
|
||||
return SizeType(p - s);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline SizeType StrLen(const char* s) {
|
||||
inline SizeType StrLen(const char* s)
|
||||
{
|
||||
return SizeType(std::strlen(s));
|
||||
}
|
||||
|
||||
template <>
|
||||
inline SizeType StrLen(const wchar_t* s) {
|
||||
inline SizeType StrLen(const wchar_t* s)
|
||||
{
|
||||
return SizeType(std::wcslen(s));
|
||||
}
|
||||
|
||||
@@ -51,25 +56,34 @@ inline SizeType StrLen(const wchar_t* s) {
|
||||
\param s2 Null-terminated input string.
|
||||
\return 0 if equal
|
||||
*/
|
||||
template<typename Ch>
|
||||
inline int StrCmp(const Ch* s1, const Ch* s2) {
|
||||
template <typename Ch>
|
||||
inline int StrCmp(const Ch* s1, const Ch* s2)
|
||||
{
|
||||
RAPIDJSON_ASSERT(s1 != 0);
|
||||
RAPIDJSON_ASSERT(s2 != 0);
|
||||
while(*s1 && (*s1 == *s2)) { s1++; s2++; }
|
||||
return static_cast<unsigned>(*s1) < static_cast<unsigned>(*s2) ? -1 : static_cast<unsigned>(*s1) > static_cast<unsigned>(*s2);
|
||||
while(*s1 && (*s1 == *s2))
|
||||
{
|
||||
s1++;
|
||||
s2++;
|
||||
}
|
||||
return static_cast<unsigned>(*s1) < static_cast<unsigned>(*s2)
|
||||
? -1
|
||||
: static_cast<unsigned>(*s1) > static_cast<unsigned>(*s2);
|
||||
}
|
||||
|
||||
//! Returns number of code points in a encoded string.
|
||||
template<typename Encoding>
|
||||
bool CountStringCodePoint(const typename Encoding::Ch* s, SizeType length, SizeType* outCount) {
|
||||
template <typename Encoding>
|
||||
bool CountStringCodePoint(const typename Encoding::Ch* s, SizeType length, SizeType* outCount)
|
||||
{
|
||||
RAPIDJSON_ASSERT(s != 0);
|
||||
RAPIDJSON_ASSERT(outCount != 0);
|
||||
GenericStringStream<Encoding> is(s);
|
||||
const typename Encoding::Ch* end = s + length;
|
||||
SizeType count = 0;
|
||||
while (is.src_ < end) {
|
||||
SizeType count = 0;
|
||||
while(is.src_ < end)
|
||||
{
|
||||
unsigned codepoint;
|
||||
if (!Encoding::Decode(is, &codepoint))
|
||||
if(!Encoding::Decode(is, &codepoint))
|
||||
return false;
|
||||
count++;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_STRTOD_
|
||||
@@ -25,17 +25,20 @@
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
namespace internal {
|
||||
|
||||
inline double FastPath(double significand, int exp) {
|
||||
if (exp < -308)
|
||||
inline double FastPath(double significand, int exp)
|
||||
{
|
||||
if(exp < -308)
|
||||
return 0.0;
|
||||
else if (exp >= 0)
|
||||
else if(exp >= 0)
|
||||
return significand * internal::Pow10(exp);
|
||||
else
|
||||
return significand / internal::Pow10(-exp);
|
||||
}
|
||||
|
||||
inline double StrtodNormalPrecision(double d, int p) {
|
||||
if (p < -308) {
|
||||
inline double StrtodNormalPrecision(double d, int p)
|
||||
{
|
||||
if(p < -308)
|
||||
{
|
||||
// Prevent expSum < -308, making Pow10(p) = 0
|
||||
d = FastPath(d, -308);
|
||||
d = FastPath(d, p + 308);
|
||||
@@ -46,27 +49,33 @@ inline double StrtodNormalPrecision(double d, int p) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T Min3(T a, T b, T c) {
|
||||
inline T Min3(T a, T b, T c)
|
||||
{
|
||||
T m = a;
|
||||
if (m > b) m = b;
|
||||
if (m > c) m = c;
|
||||
if(m > b)
|
||||
m = b;
|
||||
if(m > c)
|
||||
m = c;
|
||||
return m;
|
||||
}
|
||||
|
||||
inline int CheckWithinHalfULP(double b, const BigInteger& d, int dExp) {
|
||||
inline int CheckWithinHalfULP(double b, const BigInteger& d, int dExp)
|
||||
{
|
||||
const Double db(b);
|
||||
const uint64_t bInt = db.IntegerSignificand();
|
||||
const int bExp = db.IntegerExponent();
|
||||
const int hExp = bExp - 1;
|
||||
const int bExp = db.IntegerExponent();
|
||||
const int hExp = bExp - 1;
|
||||
|
||||
int dS_Exp2 = 0, dS_Exp5 = 0, bS_Exp2 = 0, bS_Exp5 = 0, hS_Exp2 = 0, hS_Exp5 = 0;
|
||||
|
||||
// Adjust for decimal exponent
|
||||
if (dExp >= 0) {
|
||||
if(dExp >= 0)
|
||||
{
|
||||
dS_Exp2 += dExp;
|
||||
dS_Exp5 += dExp;
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
bS_Exp2 -= dExp;
|
||||
bS_Exp5 -= dExp;
|
||||
hS_Exp2 -= dExp;
|
||||
@@ -74,17 +83,19 @@ inline int CheckWithinHalfULP(double b, const BigInteger& d, int dExp) {
|
||||
}
|
||||
|
||||
// Adjust for binary exponent
|
||||
if (bExp >= 0)
|
||||
if(bExp >= 0)
|
||||
bS_Exp2 += bExp;
|
||||
else {
|
||||
else
|
||||
{
|
||||
dS_Exp2 -= bExp;
|
||||
hS_Exp2 -= bExp;
|
||||
}
|
||||
|
||||
// Adjust for half ulp exponent
|
||||
if (hExp >= 0)
|
||||
if(hExp >= 0)
|
||||
hS_Exp2 += hExp;
|
||||
else {
|
||||
else
|
||||
{
|
||||
dS_Exp2 -= hExp;
|
||||
bS_Exp2 -= hExp;
|
||||
}
|
||||
@@ -110,16 +121,19 @@ inline int CheckWithinHalfULP(double b, const BigInteger& d, int dExp) {
|
||||
return delta.Compare(hS);
|
||||
}
|
||||
|
||||
inline bool StrtodFast(double d, int p, double* result) {
|
||||
inline bool StrtodFast(double d, int p, double* result)
|
||||
{
|
||||
// Use fast path for string-to-double conversion if possible
|
||||
// see http://www.exploringbinary.com/fast-path-decimal-to-floating-point-conversion/
|
||||
if (p > 22 && p < 22 + 16) {
|
||||
if(p > 22 && p < 22 + 16)
|
||||
{
|
||||
// Fast Path Cases In Disguise
|
||||
d *= internal::Pow10(p - 22);
|
||||
p = 22;
|
||||
}
|
||||
|
||||
if (p >= -22 && p <= 22 && d <= 9007199254740991.0) { // 2^53 - 1
|
||||
if(p >= -22 && p <= 22 && d <= 9007199254740991.0)
|
||||
{ // 2^53 - 1
|
||||
*result = FastPath(d, p);
|
||||
return true;
|
||||
}
|
||||
@@ -128,24 +142,26 @@ inline bool StrtodFast(double d, int p, double* result) {
|
||||
}
|
||||
|
||||
// Compute an approximation and see if it is within 1/2 ULP
|
||||
template<typename Ch>
|
||||
inline bool StrtodDiyFp(const Ch* decimals, int dLen, int dExp, double* result) {
|
||||
template <typename Ch>
|
||||
inline bool StrtodDiyFp(const Ch* decimals, int dLen, int dExp, double* result)
|
||||
{
|
||||
uint64_t significand = 0;
|
||||
int i = 0; // 2^64 - 1 = 18446744073709551615, 1844674407370955161 = 0x1999999999999999
|
||||
for (; i < dLen; i++) {
|
||||
if (significand > RAPIDJSON_UINT64_C2(0x19999999, 0x99999999) ||
|
||||
(significand == RAPIDJSON_UINT64_C2(0x19999999, 0x99999999) && decimals[i] >= Ch('5')))
|
||||
int i = 0; // 2^64 - 1 = 18446744073709551615, 1844674407370955161 = 0x1999999999999999
|
||||
for(; i < dLen; i++)
|
||||
{
|
||||
if(significand > RAPIDJSON_UINT64_C2(0x19999999, 0x99999999) ||
|
||||
(significand == RAPIDJSON_UINT64_C2(0x19999999, 0x99999999) && decimals[i] >= Ch('5')))
|
||||
break;
|
||||
significand = significand * 10u + static_cast<unsigned>(decimals[i] - Ch('0'));
|
||||
}
|
||||
|
||||
if (i < dLen && decimals[i] >= Ch('5')) // Rounding
|
||||
|
||||
if(i < dLen && decimals[i] >= Ch('5')) // Rounding
|
||||
significand++;
|
||||
|
||||
int remaining = dLen - i;
|
||||
int remaining = dLen - i;
|
||||
const int kUlpShift = 3;
|
||||
const int kUlp = 1 << kUlpShift;
|
||||
int64_t error = (remaining == 0) ? 0 : kUlp / 2;
|
||||
const int kUlp = 1 << kUlpShift;
|
||||
int64_t error = (remaining == 0) ? 0 : kUlp / 2;
|
||||
|
||||
DiyFp v(significand, 0);
|
||||
v = v.Normalize();
|
||||
@@ -155,20 +171,21 @@ inline bool StrtodDiyFp(const Ch* decimals, int dLen, int dExp, double* result)
|
||||
|
||||
int actualExp;
|
||||
DiyFp cachedPower = GetCachedPower10(dExp, &actualExp);
|
||||
if (actualExp != dExp) {
|
||||
if(actualExp != dExp)
|
||||
{
|
||||
static const DiyFp kPow10[] = {
|
||||
DiyFp(RAPIDJSON_UINT64_C2(0xa0000000, 0x00000000), -60), // 10^1
|
||||
DiyFp(RAPIDJSON_UINT64_C2(0xc8000000, 0x00000000), -57), // 10^2
|
||||
DiyFp(RAPIDJSON_UINT64_C2(0xfa000000, 0x00000000), -54), // 10^3
|
||||
DiyFp(RAPIDJSON_UINT64_C2(0x9c400000, 0x00000000), -50), // 10^4
|
||||
DiyFp(RAPIDJSON_UINT64_C2(0xc3500000, 0x00000000), -47), // 10^5
|
||||
DiyFp(RAPIDJSON_UINT64_C2(0xf4240000, 0x00000000), -44), // 10^6
|
||||
DiyFp(RAPIDJSON_UINT64_C2(0x98968000, 0x00000000), -40) // 10^7
|
||||
DiyFp(RAPIDJSON_UINT64_C2(0xa0000000, 0x00000000), -60), // 10^1
|
||||
DiyFp(RAPIDJSON_UINT64_C2(0xc8000000, 0x00000000), -57), // 10^2
|
||||
DiyFp(RAPIDJSON_UINT64_C2(0xfa000000, 0x00000000), -54), // 10^3
|
||||
DiyFp(RAPIDJSON_UINT64_C2(0x9c400000, 0x00000000), -50), // 10^4
|
||||
DiyFp(RAPIDJSON_UINT64_C2(0xc3500000, 0x00000000), -47), // 10^5
|
||||
DiyFp(RAPIDJSON_UINT64_C2(0xf4240000, 0x00000000), -44), // 10^6
|
||||
DiyFp(RAPIDJSON_UINT64_C2(0x98968000, 0x00000000), -40) // 10^7
|
||||
};
|
||||
int adjustment = dExp - actualExp;
|
||||
RAPIDJSON_ASSERT(adjustment >= 1 && adjustment < 8);
|
||||
v = v * kPow10[adjustment - 1];
|
||||
if (dLen + adjustment > 19) // has more digits than decimal digits in 64-bit
|
||||
if(dLen + adjustment > 19) // has more digits than decimal digits in 64-bit
|
||||
error += kUlp / 2;
|
||||
}
|
||||
|
||||
@@ -177,25 +194,28 @@ inline bool StrtodDiyFp(const Ch* decimals, int dLen, int dExp, double* result)
|
||||
error += kUlp + (error == 0 ? 0 : 1);
|
||||
|
||||
const int oldExp = v.e;
|
||||
v = v.Normalize();
|
||||
v = v.Normalize();
|
||||
error <<= oldExp - v.e;
|
||||
|
||||
const int effectiveSignificandSize = Double::EffectiveSignificandSize(64 + v.e);
|
||||
int precisionSize = 64 - effectiveSignificandSize;
|
||||
if (precisionSize + kUlpShift >= 64) {
|
||||
int precisionSize = 64 - effectiveSignificandSize;
|
||||
if(precisionSize + kUlpShift >= 64)
|
||||
{
|
||||
int scaleExp = (precisionSize + kUlpShift) - 63;
|
||||
v.f >>= scaleExp;
|
||||
v.e += scaleExp;
|
||||
v.e += scaleExp;
|
||||
error = (error >> scaleExp) + 1 + kUlp;
|
||||
precisionSize -= scaleExp;
|
||||
}
|
||||
|
||||
DiyFp rounded(v.f >> precisionSize, v.e + precisionSize);
|
||||
const uint64_t precisionBits = (v.f & ((uint64_t(1) << precisionSize) - 1)) * kUlp;
|
||||
const uint64_t halfWay = (uint64_t(1) << (precisionSize - 1)) * kUlp;
|
||||
if (precisionBits >= halfWay + static_cast<unsigned>(error)) {
|
||||
const uint64_t halfWay = (uint64_t(1) << (precisionSize - 1)) * kUlp;
|
||||
if(precisionBits >= halfWay + static_cast<unsigned>(error))
|
||||
{
|
||||
rounded.f++;
|
||||
if (rounded.f & (DiyFp::kDpHiddenBit << 1)) { // rounding overflows mantissa (issue #340)
|
||||
if(rounded.f & (DiyFp::kDpHiddenBit << 1))
|
||||
{ // rounding overflows mantissa (issue #340)
|
||||
rounded.f >>= 1;
|
||||
rounded.e++;
|
||||
}
|
||||
@@ -203,20 +223,23 @@ inline bool StrtodDiyFp(const Ch* decimals, int dLen, int dExp, double* result)
|
||||
|
||||
*result = rounded.ToDouble();
|
||||
|
||||
return halfWay - static_cast<unsigned>(error) >= precisionBits || precisionBits >= halfWay + static_cast<unsigned>(error);
|
||||
return halfWay - static_cast<unsigned>(error) >= precisionBits ||
|
||||
precisionBits >= halfWay + static_cast<unsigned>(error);
|
||||
}
|
||||
|
||||
template<typename Ch>
|
||||
inline double StrtodBigInteger(double approx, const Ch* decimals, int dLen, int dExp) {
|
||||
template <typename Ch>
|
||||
inline double StrtodBigInteger(double approx, const Ch* decimals, int dLen, int dExp)
|
||||
{
|
||||
RAPIDJSON_ASSERT(dLen >= 0);
|
||||
const BigInteger dInt(decimals, static_cast<unsigned>(dLen));
|
||||
Double a(approx);
|
||||
int cmp = CheckWithinHalfULP(a.Value(), dInt, dExp);
|
||||
if (cmp < 0)
|
||||
return a.Value(); // within half ULP
|
||||
else if (cmp == 0) {
|
||||
if(cmp < 0)
|
||||
return a.Value(); // within half ULP
|
||||
else if(cmp == 0)
|
||||
{
|
||||
// Round towards even
|
||||
if (a.Significand() & 1)
|
||||
if(a.Significand() & 1)
|
||||
return a.NextPositiveDouble();
|
||||
else
|
||||
return a.Value();
|
||||
@@ -225,13 +248,15 @@ inline double StrtodBigInteger(double approx, const Ch* decimals, int dLen, int
|
||||
return a.NextPositiveDouble();
|
||||
}
|
||||
|
||||
template<typename Ch>
|
||||
inline double StrtodFullPrecision(double d, int p, const Ch* decimals, size_t length, size_t decimalPosition, int exp) {
|
||||
template <typename Ch>
|
||||
inline double StrtodFullPrecision(
|
||||
double d, int p, const Ch* decimals, size_t length, size_t decimalPosition, int exp)
|
||||
{
|
||||
RAPIDJSON_ASSERT(d >= 0.0);
|
||||
RAPIDJSON_ASSERT(length >= 1);
|
||||
|
||||
double result = 0.0;
|
||||
if (StrtodFast(d, p, &result))
|
||||
if(StrtodFast(d, p, &result))
|
||||
return result;
|
||||
|
||||
RAPIDJSON_ASSERT(length <= INT_MAX);
|
||||
@@ -248,39 +273,43 @@ inline double StrtodFullPrecision(double d, int p, const Ch* decimals, size_t le
|
||||
RAPIDJSON_ASSERT(dExp <= INT_MAX - dLen);
|
||||
|
||||
// Trim leading zeros
|
||||
while (dLen > 0 && *decimals == '0') {
|
||||
while(dLen > 0 && *decimals == '0')
|
||||
{
|
||||
dLen--;
|
||||
decimals++;
|
||||
}
|
||||
|
||||
// Trim trailing zeros
|
||||
while (dLen > 0 && decimals[dLen - 1] == '0') {
|
||||
while(dLen > 0 && decimals[dLen - 1] == '0')
|
||||
{
|
||||
dLen--;
|
||||
dExp++;
|
||||
}
|
||||
|
||||
if (dLen == 0) { // Buffer only contains zeros.
|
||||
if(dLen == 0)
|
||||
{ // Buffer only contains zeros.
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Trim right-most digits
|
||||
const int kMaxDecimalDigit = 767 + 1;
|
||||
if (dLen > kMaxDecimalDigit) {
|
||||
if(dLen > kMaxDecimalDigit)
|
||||
{
|
||||
dExp += dLen - kMaxDecimalDigit;
|
||||
dLen = kMaxDecimalDigit;
|
||||
}
|
||||
|
||||
// If too small, underflow to zero.
|
||||
// Any x <= 10^-324 is interpreted as zero.
|
||||
if (dLen + dExp <= -324)
|
||||
if(dLen + dExp <= -324)
|
||||
return 0.0;
|
||||
|
||||
// If too large, overflow to infinity.
|
||||
// Any x >= 10^309 is interpreted as +infinity.
|
||||
if (dLen + dExp > 309)
|
||||
if(dLen + dExp > 309)
|
||||
return std::numeric_limits<double>::infinity();
|
||||
|
||||
if (StrtodDiyFp(decimals, dLen, dExp, &result))
|
||||
if(StrtodDiyFp(decimals, dLen, dExp, &result))
|
||||
return result;
|
||||
|
||||
// Use approximation from StrtodDiyFp and make adjustment with BigInteger comparison
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
|
||||
#if defined(__clang__)
|
||||
RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(c++98-compat)
|
||||
RAPIDJSON_DIAG_OFF(c++ 98 - compat)
|
||||
#endif
|
||||
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
@@ -30,10 +30,11 @@ namespace internal {
|
||||
\note This has the same semantics as std::swap().
|
||||
*/
|
||||
template <typename T>
|
||||
inline void Swap(T& a, T& b) RAPIDJSON_NOEXCEPT {
|
||||
inline void Swap(T& a, T& b) RAPIDJSON_NOEXCEPT
|
||||
{
|
||||
T tmp = a;
|
||||
a = b;
|
||||
b = tmp;
|
||||
a = b;
|
||||
b = tmp;
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_ISTREAMWRAPPER_H_
|
||||
@@ -44,17 +44,27 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
|
||||
\tparam StreamType Class derived from \c std::basic_istream.
|
||||
*/
|
||||
|
||||
|
||||
template <typename StreamType>
|
||||
class BasicIStreamWrapper {
|
||||
public:
|
||||
class BasicIStreamWrapper
|
||||
{
|
||||
public:
|
||||
typedef typename StreamType::char_type Ch;
|
||||
|
||||
//! Constructor.
|
||||
/*!
|
||||
\param stream stream opened for read.
|
||||
*/
|
||||
BasicIStreamWrapper(StreamType &stream) : stream_(stream), buffer_(peekBuffer_), bufferSize_(4), bufferLast_(0), current_(buffer_), readCount_(0), count_(0), eof_(false) {
|
||||
BasicIStreamWrapper(StreamType& stream)
|
||||
: stream_(stream),
|
||||
buffer_(peekBuffer_),
|
||||
bufferSize_(4),
|
||||
bufferLast_(0),
|
||||
current_(buffer_),
|
||||
readCount_(0),
|
||||
count_(0),
|
||||
eof_(false)
|
||||
{
|
||||
Read();
|
||||
}
|
||||
|
||||
@@ -64,55 +74,78 @@ public:
|
||||
\param buffer user-supplied buffer.
|
||||
\param bufferSize size of buffer in bytes. Must >=4 bytes.
|
||||
*/
|
||||
BasicIStreamWrapper(StreamType &stream, char* buffer, size_t bufferSize) : stream_(stream), buffer_(buffer), bufferSize_(bufferSize), bufferLast_(0), current_(buffer_), readCount_(0), count_(0), eof_(false) {
|
||||
BasicIStreamWrapper(StreamType& stream, char* buffer, size_t bufferSize)
|
||||
: stream_(stream),
|
||||
buffer_(buffer),
|
||||
bufferSize_(bufferSize),
|
||||
bufferLast_(0),
|
||||
current_(buffer_),
|
||||
readCount_(0),
|
||||
count_(0),
|
||||
eof_(false)
|
||||
{
|
||||
RAPIDJSON_ASSERT(bufferSize >= 4);
|
||||
Read();
|
||||
}
|
||||
|
||||
Ch Peek() const { return *current_; }
|
||||
Ch Take() { Ch c = *current_; Read(); return c; }
|
||||
Ch Take()
|
||||
{
|
||||
Ch c = *current_;
|
||||
Read();
|
||||
return c;
|
||||
}
|
||||
size_t Tell() const { return count_ + static_cast<size_t>(current_ - buffer_); }
|
||||
|
||||
// Not implemented
|
||||
void Put(Ch) { RAPIDJSON_ASSERT(false); }
|
||||
void Flush() { RAPIDJSON_ASSERT(false); }
|
||||
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
|
||||
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
|
||||
|
||||
// For encoding detection only.
|
||||
const Ch* Peek4() const {
|
||||
return (current_ + 4 - !eof_ <= bufferLast_) ? current_ : 0;
|
||||
void Flush() { RAPIDJSON_ASSERT(false); }
|
||||
Ch* PutBegin()
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
size_t PutEnd(Ch*)
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
// For encoding detection only.
|
||||
const Ch* Peek4() const { return (current_ + 4 - !eof_ <= bufferLast_) ? current_ : 0; }
|
||||
|
||||
private:
|
||||
BasicIStreamWrapper();
|
||||
BasicIStreamWrapper(const BasicIStreamWrapper&);
|
||||
BasicIStreamWrapper& operator=(const BasicIStreamWrapper&);
|
||||
|
||||
void Read() {
|
||||
if (current_ < bufferLast_)
|
||||
void Read()
|
||||
{
|
||||
if(current_ < bufferLast_)
|
||||
++current_;
|
||||
else if (!eof_) {
|
||||
else if(!eof_)
|
||||
{
|
||||
count_ += readCount_;
|
||||
readCount_ = bufferSize_;
|
||||
readCount_ = bufferSize_;
|
||||
bufferLast_ = buffer_ + readCount_ - 1;
|
||||
current_ = buffer_;
|
||||
current_ = buffer_;
|
||||
|
||||
if (!stream_.read(buffer_, static_cast<std::streamsize>(bufferSize_))) {
|
||||
readCount_ = static_cast<size_t>(stream_.gcount());
|
||||
if(!stream_.read(buffer_, static_cast<std::streamsize>(bufferSize_)))
|
||||
{
|
||||
readCount_ = static_cast<size_t>(stream_.gcount());
|
||||
*(bufferLast_ = buffer_ + readCount_) = '\0';
|
||||
eof_ = true;
|
||||
eof_ = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StreamType &stream_;
|
||||
StreamType& stream_;
|
||||
Ch peekBuffer_[4], *buffer_;
|
||||
size_t bufferSize_;
|
||||
Ch *bufferLast_;
|
||||
Ch *current_;
|
||||
Ch* bufferLast_;
|
||||
Ch* current_;
|
||||
size_t readCount_;
|
||||
size_t count_; //!< Number of characters read
|
||||
size_t count_; //!< Number of characters read
|
||||
bool eof_;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_MEMORYBUFFER_H_
|
||||
@@ -27,17 +27,22 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
It is similar to FileWriteBuffer but the destination is an in-memory buffer instead of a file.
|
||||
|
||||
Differences between MemoryBuffer and StringBuffer:
|
||||
1. StringBuffer has Encoding but MemoryBuffer is only a byte buffer.
|
||||
2. StringBuffer::GetString() returns a null-terminated string. MemoryBuffer::GetBuffer() returns a buffer without terminator.
|
||||
1. StringBuffer has Encoding but MemoryBuffer is only a byte buffer.
|
||||
2. StringBuffer::GetString() returns a null-terminated string. MemoryBuffer::GetBuffer() returns
|
||||
a buffer without terminator.
|
||||
|
||||
\tparam Allocator type for allocating memory buffer.
|
||||
\note implements Stream concept
|
||||
*/
|
||||
template <typename Allocator = CrtAllocator>
|
||||
struct GenericMemoryBuffer {
|
||||
struct GenericMemoryBuffer
|
||||
{
|
||||
typedef char Ch; // byte
|
||||
|
||||
GenericMemoryBuffer(Allocator* allocator = 0, size_t capacity = kDefaultCapacity) : stack_(allocator, capacity) {}
|
||||
GenericMemoryBuffer(Allocator* allocator = 0, size_t capacity = kDefaultCapacity)
|
||||
: stack_(allocator, capacity)
|
||||
{
|
||||
}
|
||||
|
||||
void Put(Ch c) { *stack_.template Push<Ch>() = c; }
|
||||
void Flush() {}
|
||||
@@ -47,9 +52,7 @@ struct GenericMemoryBuffer {
|
||||
Ch* Push(size_t count) { return stack_.template Push<Ch>(count); }
|
||||
void Pop(size_t count) { stack_.template Pop<Ch>(count); }
|
||||
|
||||
const Ch* GetBuffer() const {
|
||||
return stack_.template Bottom<Ch>();
|
||||
}
|
||||
const Ch* GetBuffer() const { return stack_.template Bottom<Ch>(); }
|
||||
|
||||
size_t GetSize() const { return stack_.GetSize(); }
|
||||
|
||||
@@ -60,8 +63,9 @@ struct GenericMemoryBuffer {
|
||||
typedef GenericMemoryBuffer<> MemoryBuffer;
|
||||
|
||||
//! Implement specialized version of PutN() with memset() for better performance.
|
||||
template<>
|
||||
inline void PutN(MemoryBuffer& memoryBuffer, char c, size_t n) {
|
||||
template <>
|
||||
inline void PutN(MemoryBuffer& memoryBuffer, char c, size_t n)
|
||||
{
|
||||
std::memset(memoryBuffer.stack_.Push<char>(n), c, n * sizeof(c));
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_MEMORYSTREAM_H_
|
||||
@@ -19,8 +19,8 @@
|
||||
|
||||
#ifdef __clang__
|
||||
RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(unreachable-code)
|
||||
RAPIDJSON_DIAG_OFF(missing-noreturn)
|
||||
RAPIDJSON_DIAG_OFF(unreachable - code)
|
||||
RAPIDJSON_DIAG_OFF(missing - noreturn)
|
||||
#endif
|
||||
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
@@ -33,33 +33,43 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
|
||||
Differences between MemoryStream and StringStream:
|
||||
1. StringStream has encoding but MemoryStream is a byte stream.
|
||||
2. MemoryStream needs size of the source buffer and the buffer don't need to be null terminated. StringStream assume null-terminated string as source.
|
||||
3. MemoryStream supports Peek4() for encoding detection. StringStream is specified with an encoding so it should not have Peek4().
|
||||
\note implements Stream concept
|
||||
2. MemoryStream needs size of the source buffer and the buffer don't need to be null terminated.
|
||||
StringStream assume null-terminated string as source.
|
||||
3. MemoryStream supports Peek4() for encoding detection. StringStream is specified with an
|
||||
encoding so it should not have Peek4(). \note implements Stream concept
|
||||
*/
|
||||
struct MemoryStream {
|
||||
struct MemoryStream
|
||||
{
|
||||
typedef char Ch; // byte
|
||||
|
||||
MemoryStream(const Ch *src, size_t size) : src_(src), begin_(src), end_(src + size), size_(size) {}
|
||||
MemoryStream(const Ch* src, size_t size) : src_(src), begin_(src), end_(src + size), size_(size)
|
||||
{
|
||||
}
|
||||
|
||||
Ch Peek() const { return RAPIDJSON_UNLIKELY(src_ == end_) ? '\0' : *src_; }
|
||||
Ch Take() { return RAPIDJSON_UNLIKELY(src_ == end_) ? '\0' : *src_++; }
|
||||
size_t Tell() const { return static_cast<size_t>(src_ - begin_); }
|
||||
|
||||
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
|
||||
Ch* PutBegin()
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
void Put(Ch) { RAPIDJSON_ASSERT(false); }
|
||||
void Flush() { RAPIDJSON_ASSERT(false); }
|
||||
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
|
||||
|
||||
// For encoding detection only.
|
||||
const Ch* Peek4() const {
|
||||
return Tell() + 4 <= size_ ? src_ : 0;
|
||||
size_t PutEnd(Ch*)
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
const Ch* src_; //!< Current read position.
|
||||
const Ch* begin_; //!< Original head of the string.
|
||||
const Ch* end_; //!< End of stream.
|
||||
size_t size_; //!< Size of the stream.
|
||||
// For encoding detection only.
|
||||
const Ch* Peek4() const { return Tell() + 4 <= size_ ? src_ : 0; }
|
||||
|
||||
const Ch* src_; //!< Current read position.
|
||||
const Ch* begin_; //!< Original head of the string.
|
||||
const Ch* end_; //!< End of stream.
|
||||
size_t size_; //!< Size of the stream.
|
||||
};
|
||||
|
||||
RAPIDJSON_NAMESPACE_END
|
||||
|
||||
@@ -1,37 +1,37 @@
|
||||
// ISO C9x compliant inttypes.h for Microsoft Visual Studio
|
||||
// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124
|
||||
//
|
||||
// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124
|
||||
//
|
||||
// Copyright (c) 2006-2013 Alexander Chemeris
|
||||
//
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are met:
|
||||
//
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
//
|
||||
//
|
||||
// 3. Neither the name of the product nor the names of its contributors may
|
||||
// be used to endorse or promote products derived from this software
|
||||
// without specific prior written permission.
|
||||
//
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
|
||||
// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
|
||||
// EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
|
||||
// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
|
||||
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
|
||||
// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
|
||||
// ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
//
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The above software in this distribution may have been modified by
|
||||
// THL A29 Limited ("Tencent Modifications").
|
||||
// The above software in this distribution may have been modified by
|
||||
// THL A29 Limited ("Tencent Modifications").
|
||||
// All Tencent Modifications are Copyright (C) 2015 THL A29 Limited.
|
||||
|
||||
#ifndef _MSC_VER // [
|
||||
@@ -54,9 +54,10 @@
|
||||
|
||||
// 7.8 Format conversion of integer types
|
||||
|
||||
typedef struct {
|
||||
intmax_t quot;
|
||||
intmax_t rem;
|
||||
typedef struct
|
||||
{
|
||||
intmax_t quot;
|
||||
intmax_t rem;
|
||||
} imaxdiv_t;
|
||||
|
||||
// 7.8.1 Macros for format specifiers
|
||||
@@ -64,212 +65,212 @@ typedef struct {
|
||||
#if !defined(__cplusplus) || defined(__STDC_FORMAT_MACROS) // [ See footnote 185 at page 198
|
||||
|
||||
// The fprintf macros for signed integers are:
|
||||
#define PRId8 "d"
|
||||
#define PRIi8 "i"
|
||||
#define PRIdLEAST8 "d"
|
||||
#define PRIiLEAST8 "i"
|
||||
#define PRIdFAST8 "d"
|
||||
#define PRIiFAST8 "i"
|
||||
#define PRId8 "d"
|
||||
#define PRIi8 "i"
|
||||
#define PRIdLEAST8 "d"
|
||||
#define PRIiLEAST8 "i"
|
||||
#define PRIdFAST8 "d"
|
||||
#define PRIiFAST8 "i"
|
||||
|
||||
#define PRId16 "hd"
|
||||
#define PRIi16 "hi"
|
||||
#define PRIdLEAST16 "hd"
|
||||
#define PRIiLEAST16 "hi"
|
||||
#define PRIdFAST16 "hd"
|
||||
#define PRIiFAST16 "hi"
|
||||
#define PRId16 "hd"
|
||||
#define PRIi16 "hi"
|
||||
#define PRIdLEAST16 "hd"
|
||||
#define PRIiLEAST16 "hi"
|
||||
#define PRIdFAST16 "hd"
|
||||
#define PRIiFAST16 "hi"
|
||||
|
||||
#define PRId32 "I32d"
|
||||
#define PRIi32 "I32i"
|
||||
#define PRIdLEAST32 "I32d"
|
||||
#define PRIiLEAST32 "I32i"
|
||||
#define PRIdFAST32 "I32d"
|
||||
#define PRIiFAST32 "I32i"
|
||||
#define PRId32 "I32d"
|
||||
#define PRIi32 "I32i"
|
||||
#define PRIdLEAST32 "I32d"
|
||||
#define PRIiLEAST32 "I32i"
|
||||
#define PRIdFAST32 "I32d"
|
||||
#define PRIiFAST32 "I32i"
|
||||
|
||||
#define PRId64 "I64d"
|
||||
#define PRIi64 "I64i"
|
||||
#define PRIdLEAST64 "I64d"
|
||||
#define PRIiLEAST64 "I64i"
|
||||
#define PRIdFAST64 "I64d"
|
||||
#define PRIiFAST64 "I64i"
|
||||
#define PRId64 "I64d"
|
||||
#define PRIi64 "I64i"
|
||||
#define PRIdLEAST64 "I64d"
|
||||
#define PRIiLEAST64 "I64i"
|
||||
#define PRIdFAST64 "I64d"
|
||||
#define PRIiFAST64 "I64i"
|
||||
|
||||
#define PRIdMAX "I64d"
|
||||
#define PRIiMAX "I64i"
|
||||
#define PRIdMAX "I64d"
|
||||
#define PRIiMAX "I64i"
|
||||
|
||||
#define PRIdPTR "Id"
|
||||
#define PRIiPTR "Ii"
|
||||
#define PRIdPTR "Id"
|
||||
#define PRIiPTR "Ii"
|
||||
|
||||
// The fprintf macros for unsigned integers are:
|
||||
#define PRIo8 "o"
|
||||
#define PRIu8 "u"
|
||||
#define PRIx8 "x"
|
||||
#define PRIX8 "X"
|
||||
#define PRIoLEAST8 "o"
|
||||
#define PRIuLEAST8 "u"
|
||||
#define PRIxLEAST8 "x"
|
||||
#define PRIXLEAST8 "X"
|
||||
#define PRIoFAST8 "o"
|
||||
#define PRIuFAST8 "u"
|
||||
#define PRIxFAST8 "x"
|
||||
#define PRIXFAST8 "X"
|
||||
#define PRIo8 "o"
|
||||
#define PRIu8 "u"
|
||||
#define PRIx8 "x"
|
||||
#define PRIX8 "X"
|
||||
#define PRIoLEAST8 "o"
|
||||
#define PRIuLEAST8 "u"
|
||||
#define PRIxLEAST8 "x"
|
||||
#define PRIXLEAST8 "X"
|
||||
#define PRIoFAST8 "o"
|
||||
#define PRIuFAST8 "u"
|
||||
#define PRIxFAST8 "x"
|
||||
#define PRIXFAST8 "X"
|
||||
|
||||
#define PRIo16 "ho"
|
||||
#define PRIu16 "hu"
|
||||
#define PRIx16 "hx"
|
||||
#define PRIX16 "hX"
|
||||
#define PRIoLEAST16 "ho"
|
||||
#define PRIuLEAST16 "hu"
|
||||
#define PRIxLEAST16 "hx"
|
||||
#define PRIXLEAST16 "hX"
|
||||
#define PRIoFAST16 "ho"
|
||||
#define PRIuFAST16 "hu"
|
||||
#define PRIxFAST16 "hx"
|
||||
#define PRIXFAST16 "hX"
|
||||
#define PRIo16 "ho"
|
||||
#define PRIu16 "hu"
|
||||
#define PRIx16 "hx"
|
||||
#define PRIX16 "hX"
|
||||
#define PRIoLEAST16 "ho"
|
||||
#define PRIuLEAST16 "hu"
|
||||
#define PRIxLEAST16 "hx"
|
||||
#define PRIXLEAST16 "hX"
|
||||
#define PRIoFAST16 "ho"
|
||||
#define PRIuFAST16 "hu"
|
||||
#define PRIxFAST16 "hx"
|
||||
#define PRIXFAST16 "hX"
|
||||
|
||||
#define PRIo32 "I32o"
|
||||
#define PRIu32 "I32u"
|
||||
#define PRIx32 "I32x"
|
||||
#define PRIX32 "I32X"
|
||||
#define PRIoLEAST32 "I32o"
|
||||
#define PRIuLEAST32 "I32u"
|
||||
#define PRIxLEAST32 "I32x"
|
||||
#define PRIXLEAST32 "I32X"
|
||||
#define PRIoFAST32 "I32o"
|
||||
#define PRIuFAST32 "I32u"
|
||||
#define PRIxFAST32 "I32x"
|
||||
#define PRIXFAST32 "I32X"
|
||||
#define PRIo32 "I32o"
|
||||
#define PRIu32 "I32u"
|
||||
#define PRIx32 "I32x"
|
||||
#define PRIX32 "I32X"
|
||||
#define PRIoLEAST32 "I32o"
|
||||
#define PRIuLEAST32 "I32u"
|
||||
#define PRIxLEAST32 "I32x"
|
||||
#define PRIXLEAST32 "I32X"
|
||||
#define PRIoFAST32 "I32o"
|
||||
#define PRIuFAST32 "I32u"
|
||||
#define PRIxFAST32 "I32x"
|
||||
#define PRIXFAST32 "I32X"
|
||||
|
||||
#define PRIo64 "I64o"
|
||||
#define PRIu64 "I64u"
|
||||
#define PRIx64 "I64x"
|
||||
#define PRIX64 "I64X"
|
||||
#define PRIoLEAST64 "I64o"
|
||||
#define PRIuLEAST64 "I64u"
|
||||
#define PRIxLEAST64 "I64x"
|
||||
#define PRIXLEAST64 "I64X"
|
||||
#define PRIoFAST64 "I64o"
|
||||
#define PRIuFAST64 "I64u"
|
||||
#define PRIxFAST64 "I64x"
|
||||
#define PRIXFAST64 "I64X"
|
||||
#define PRIo64 "I64o"
|
||||
#define PRIu64 "I64u"
|
||||
#define PRIx64 "I64x"
|
||||
#define PRIX64 "I64X"
|
||||
#define PRIoLEAST64 "I64o"
|
||||
#define PRIuLEAST64 "I64u"
|
||||
#define PRIxLEAST64 "I64x"
|
||||
#define PRIXLEAST64 "I64X"
|
||||
#define PRIoFAST64 "I64o"
|
||||
#define PRIuFAST64 "I64u"
|
||||
#define PRIxFAST64 "I64x"
|
||||
#define PRIXFAST64 "I64X"
|
||||
|
||||
#define PRIoMAX "I64o"
|
||||
#define PRIuMAX "I64u"
|
||||
#define PRIxMAX "I64x"
|
||||
#define PRIXMAX "I64X"
|
||||
#define PRIoMAX "I64o"
|
||||
#define PRIuMAX "I64u"
|
||||
#define PRIxMAX "I64x"
|
||||
#define PRIXMAX "I64X"
|
||||
|
||||
#define PRIoPTR "Io"
|
||||
#define PRIuPTR "Iu"
|
||||
#define PRIxPTR "Ix"
|
||||
#define PRIXPTR "IX"
|
||||
#define PRIoPTR "Io"
|
||||
#define PRIuPTR "Iu"
|
||||
#define PRIxPTR "Ix"
|
||||
#define PRIXPTR "IX"
|
||||
|
||||
// The fscanf macros for signed integers are:
|
||||
#define SCNd8 "d"
|
||||
#define SCNi8 "i"
|
||||
#define SCNdLEAST8 "d"
|
||||
#define SCNiLEAST8 "i"
|
||||
#define SCNdFAST8 "d"
|
||||
#define SCNiFAST8 "i"
|
||||
#define SCNd8 "d"
|
||||
#define SCNi8 "i"
|
||||
#define SCNdLEAST8 "d"
|
||||
#define SCNiLEAST8 "i"
|
||||
#define SCNdFAST8 "d"
|
||||
#define SCNiFAST8 "i"
|
||||
|
||||
#define SCNd16 "hd"
|
||||
#define SCNi16 "hi"
|
||||
#define SCNdLEAST16 "hd"
|
||||
#define SCNiLEAST16 "hi"
|
||||
#define SCNdFAST16 "hd"
|
||||
#define SCNiFAST16 "hi"
|
||||
#define SCNd16 "hd"
|
||||
#define SCNi16 "hi"
|
||||
#define SCNdLEAST16 "hd"
|
||||
#define SCNiLEAST16 "hi"
|
||||
#define SCNdFAST16 "hd"
|
||||
#define SCNiFAST16 "hi"
|
||||
|
||||
#define SCNd32 "ld"
|
||||
#define SCNi32 "li"
|
||||
#define SCNdLEAST32 "ld"
|
||||
#define SCNiLEAST32 "li"
|
||||
#define SCNdFAST32 "ld"
|
||||
#define SCNiFAST32 "li"
|
||||
#define SCNd32 "ld"
|
||||
#define SCNi32 "li"
|
||||
#define SCNdLEAST32 "ld"
|
||||
#define SCNiLEAST32 "li"
|
||||
#define SCNdFAST32 "ld"
|
||||
#define SCNiFAST32 "li"
|
||||
|
||||
#define SCNd64 "I64d"
|
||||
#define SCNi64 "I64i"
|
||||
#define SCNdLEAST64 "I64d"
|
||||
#define SCNiLEAST64 "I64i"
|
||||
#define SCNdFAST64 "I64d"
|
||||
#define SCNiFAST64 "I64i"
|
||||
#define SCNd64 "I64d"
|
||||
#define SCNi64 "I64i"
|
||||
#define SCNdLEAST64 "I64d"
|
||||
#define SCNiLEAST64 "I64i"
|
||||
#define SCNdFAST64 "I64d"
|
||||
#define SCNiFAST64 "I64i"
|
||||
|
||||
#define SCNdMAX "I64d"
|
||||
#define SCNiMAX "I64i"
|
||||
#define SCNdMAX "I64d"
|
||||
#define SCNiMAX "I64i"
|
||||
|
||||
#ifdef _WIN64 // [
|
||||
# define SCNdPTR "I64d"
|
||||
# define SCNiPTR "I64i"
|
||||
#else // _WIN64 ][
|
||||
# define SCNdPTR "ld"
|
||||
# define SCNiPTR "li"
|
||||
#endif // _WIN64 ]
|
||||
#define SCNdPTR "I64d"
|
||||
#define SCNiPTR "I64i"
|
||||
#else // _WIN64 ][
|
||||
#define SCNdPTR "ld"
|
||||
#define SCNiPTR "li"
|
||||
#endif // _WIN64 ]
|
||||
|
||||
// The fscanf macros for unsigned integers are:
|
||||
#define SCNo8 "o"
|
||||
#define SCNu8 "u"
|
||||
#define SCNx8 "x"
|
||||
#define SCNX8 "X"
|
||||
#define SCNoLEAST8 "o"
|
||||
#define SCNuLEAST8 "u"
|
||||
#define SCNxLEAST8 "x"
|
||||
#define SCNXLEAST8 "X"
|
||||
#define SCNoFAST8 "o"
|
||||
#define SCNuFAST8 "u"
|
||||
#define SCNxFAST8 "x"
|
||||
#define SCNXFAST8 "X"
|
||||
#define SCNo8 "o"
|
||||
#define SCNu8 "u"
|
||||
#define SCNx8 "x"
|
||||
#define SCNX8 "X"
|
||||
#define SCNoLEAST8 "o"
|
||||
#define SCNuLEAST8 "u"
|
||||
#define SCNxLEAST8 "x"
|
||||
#define SCNXLEAST8 "X"
|
||||
#define SCNoFAST8 "o"
|
||||
#define SCNuFAST8 "u"
|
||||
#define SCNxFAST8 "x"
|
||||
#define SCNXFAST8 "X"
|
||||
|
||||
#define SCNo16 "ho"
|
||||
#define SCNu16 "hu"
|
||||
#define SCNx16 "hx"
|
||||
#define SCNX16 "hX"
|
||||
#define SCNoLEAST16 "ho"
|
||||
#define SCNuLEAST16 "hu"
|
||||
#define SCNxLEAST16 "hx"
|
||||
#define SCNXLEAST16 "hX"
|
||||
#define SCNoFAST16 "ho"
|
||||
#define SCNuFAST16 "hu"
|
||||
#define SCNxFAST16 "hx"
|
||||
#define SCNXFAST16 "hX"
|
||||
#define SCNo16 "ho"
|
||||
#define SCNu16 "hu"
|
||||
#define SCNx16 "hx"
|
||||
#define SCNX16 "hX"
|
||||
#define SCNoLEAST16 "ho"
|
||||
#define SCNuLEAST16 "hu"
|
||||
#define SCNxLEAST16 "hx"
|
||||
#define SCNXLEAST16 "hX"
|
||||
#define SCNoFAST16 "ho"
|
||||
#define SCNuFAST16 "hu"
|
||||
#define SCNxFAST16 "hx"
|
||||
#define SCNXFAST16 "hX"
|
||||
|
||||
#define SCNo32 "lo"
|
||||
#define SCNu32 "lu"
|
||||
#define SCNx32 "lx"
|
||||
#define SCNX32 "lX"
|
||||
#define SCNoLEAST32 "lo"
|
||||
#define SCNuLEAST32 "lu"
|
||||
#define SCNxLEAST32 "lx"
|
||||
#define SCNXLEAST32 "lX"
|
||||
#define SCNoFAST32 "lo"
|
||||
#define SCNuFAST32 "lu"
|
||||
#define SCNxFAST32 "lx"
|
||||
#define SCNXFAST32 "lX"
|
||||
#define SCNo32 "lo"
|
||||
#define SCNu32 "lu"
|
||||
#define SCNx32 "lx"
|
||||
#define SCNX32 "lX"
|
||||
#define SCNoLEAST32 "lo"
|
||||
#define SCNuLEAST32 "lu"
|
||||
#define SCNxLEAST32 "lx"
|
||||
#define SCNXLEAST32 "lX"
|
||||
#define SCNoFAST32 "lo"
|
||||
#define SCNuFAST32 "lu"
|
||||
#define SCNxFAST32 "lx"
|
||||
#define SCNXFAST32 "lX"
|
||||
|
||||
#define SCNo64 "I64o"
|
||||
#define SCNu64 "I64u"
|
||||
#define SCNx64 "I64x"
|
||||
#define SCNX64 "I64X"
|
||||
#define SCNoLEAST64 "I64o"
|
||||
#define SCNuLEAST64 "I64u"
|
||||
#define SCNxLEAST64 "I64x"
|
||||
#define SCNXLEAST64 "I64X"
|
||||
#define SCNoFAST64 "I64o"
|
||||
#define SCNuFAST64 "I64u"
|
||||
#define SCNxFAST64 "I64x"
|
||||
#define SCNXFAST64 "I64X"
|
||||
#define SCNo64 "I64o"
|
||||
#define SCNu64 "I64u"
|
||||
#define SCNx64 "I64x"
|
||||
#define SCNX64 "I64X"
|
||||
#define SCNoLEAST64 "I64o"
|
||||
#define SCNuLEAST64 "I64u"
|
||||
#define SCNxLEAST64 "I64x"
|
||||
#define SCNXLEAST64 "I64X"
|
||||
#define SCNoFAST64 "I64o"
|
||||
#define SCNuFAST64 "I64u"
|
||||
#define SCNxFAST64 "I64x"
|
||||
#define SCNXFAST64 "I64X"
|
||||
|
||||
#define SCNoMAX "I64o"
|
||||
#define SCNuMAX "I64u"
|
||||
#define SCNxMAX "I64x"
|
||||
#define SCNXMAX "I64X"
|
||||
#define SCNoMAX "I64o"
|
||||
#define SCNuMAX "I64u"
|
||||
#define SCNxMAX "I64x"
|
||||
#define SCNXMAX "I64X"
|
||||
|
||||
#ifdef _WIN64 // [
|
||||
# define SCNoPTR "I64o"
|
||||
# define SCNuPTR "I64u"
|
||||
# define SCNxPTR "I64x"
|
||||
# define SCNXPTR "I64X"
|
||||
#else // _WIN64 ][
|
||||
# define SCNoPTR "lo"
|
||||
# define SCNuPTR "lu"
|
||||
# define SCNxPTR "lx"
|
||||
# define SCNXPTR "lX"
|
||||
#endif // _WIN64 ]
|
||||
#define SCNoPTR "I64o"
|
||||
#define SCNuPTR "I64u"
|
||||
#define SCNxPTR "I64x"
|
||||
#define SCNXPTR "I64X"
|
||||
#else // _WIN64 ][
|
||||
#define SCNoPTR "lo"
|
||||
#define SCNuPTR "lu"
|
||||
#define SCNxPTR "lx"
|
||||
#define SCNXPTR "lX"
|
||||
#endif // _WIN64 ]
|
||||
|
||||
#endif // __STDC_FORMAT_MACROS ]
|
||||
|
||||
@@ -284,23 +285,24 @@ typedef struct {
|
||||
// in %MSVC.NET%\crt\src\div.c
|
||||
#ifdef STATIC_IMAXDIV // [
|
||||
static
|
||||
#else // STATIC_IMAXDIV ][
|
||||
#else // STATIC_IMAXDIV ][
|
||||
_inline
|
||||
#endif // STATIC_IMAXDIV ]
|
||||
imaxdiv_t __cdecl imaxdiv(intmax_t numer, intmax_t denom)
|
||||
#endif // STATIC_IMAXDIV ]
|
||||
imaxdiv_t __cdecl imaxdiv(intmax_t numer, intmax_t denom)
|
||||
{
|
||||
imaxdiv_t result;
|
||||
imaxdiv_t result;
|
||||
|
||||
result.quot = numer / denom;
|
||||
result.rem = numer % denom;
|
||||
result.quot = numer / denom;
|
||||
result.rem = numer % denom;
|
||||
|
||||
if (numer < 0 && result.rem > 0) {
|
||||
// did division wrong; must fix up
|
||||
++result.quot;
|
||||
result.rem -= denom;
|
||||
}
|
||||
if(numer < 0 && result.rem > 0)
|
||||
{
|
||||
// did division wrong; must fix up
|
||||
++result.quot;
|
||||
result.rem -= denom;
|
||||
}
|
||||
|
||||
return result;
|
||||
return result;
|
||||
}
|
||||
|
||||
// 7.8.2.3 The strtoimax and strtoumax functions
|
||||
|
||||
@@ -1,37 +1,37 @@
|
||||
// ISO C9x compliant stdint.h for Microsoft Visual Studio
|
||||
// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124
|
||||
//
|
||||
// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124
|
||||
//
|
||||
// Copyright (c) 2006-2013 Alexander Chemeris
|
||||
//
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are met:
|
||||
//
|
||||
//
|
||||
// 1. Redistributions of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
//
|
||||
// 2. Redistributions in binary form must reproduce the above copyright
|
||||
// notice, this list of conditions and the following disclaimer in the
|
||||
// documentation and/or other materials provided with the distribution.
|
||||
//
|
||||
//
|
||||
// 3. Neither the name of the product nor the names of its contributors may
|
||||
// be used to endorse or promote products derived from this software
|
||||
// without specific prior written permission.
|
||||
//
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
|
||||
// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
|
||||
// EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||||
// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
|
||||
// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
|
||||
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
|
||||
// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
|
||||
// ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
//
|
||||
//
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The above software in this distribution may have been modified by
|
||||
// THL A29 Limited ("Tencent Modifications").
|
||||
// The above software in this distribution may have been modified by
|
||||
// THL A29 Limited ("Tencent Modifications").
|
||||
// All Tencent Modifications are Copyright (C) 2015 THL A29 Limited.
|
||||
|
||||
#ifndef _MSC_VER // [
|
||||
@@ -45,7 +45,8 @@
|
||||
#pragma once
|
||||
#endif
|
||||
|
||||
// miloyip: Originally Visual Studio 2010 uses its own stdint.h. However it generates warning with INT64_C(), so change to use this file for vs2010.
|
||||
// miloyip: Originally Visual Studio 2010 uses its own stdint.h. However it generates warning with
|
||||
// INT64_C(), so change to use this file for vs2010.
|
||||
#if _MSC_VER >= 1600 // [
|
||||
#include <stdint.h>
|
||||
|
||||
@@ -62,12 +63,12 @@
|
||||
|
||||
// 7.18.4.1 Macros for minimum-width integer constants
|
||||
|
||||
#define INT8_C(val) val##i8
|
||||
#define INT8_C(val) val##i8
|
||||
#define INT16_C(val) val##i16
|
||||
#define INT32_C(val) val##i32
|
||||
#define INT64_C(val) val##i64
|
||||
|
||||
#define UINT8_C(val) val##ui8
|
||||
#define UINT8_C(val) val##ui8
|
||||
#define UINT16_C(val) val##ui16
|
||||
#define UINT32_C(val) val##ui32
|
||||
#define UINT64_C(val) val##ui64
|
||||
@@ -76,10 +77,10 @@
|
||||
// These #ifndef's are needed to prevent collisions with <boost/cstdint.hpp>.
|
||||
// Check out Issue 9 for the details.
|
||||
#ifndef INTMAX_C // [
|
||||
# define INTMAX_C INT64_C
|
||||
#endif // INTMAX_C ]
|
||||
#define INTMAX_C INT64_C
|
||||
#endif // INTMAX_C ]
|
||||
#ifndef UINTMAX_C // [
|
||||
# define UINTMAX_C UINT64_C
|
||||
#define UINTMAX_C UINT64_C
|
||||
#endif // UINTMAX_C ]
|
||||
|
||||
#endif // __STDC_CONSTANT_MACROS ]
|
||||
@@ -95,20 +96,19 @@
|
||||
#if defined(__cplusplus) && !defined(_M_ARM)
|
||||
extern "C" {
|
||||
#endif
|
||||
# include <wchar.h>
|
||||
#include <wchar.h>
|
||||
#if defined(__cplusplus) && !defined(_M_ARM)
|
||||
}
|
||||
#endif
|
||||
|
||||
// Define _W64 macros to mark types changing their size, like intptr_t.
|
||||
#ifndef _W64
|
||||
# if !defined(__midl) && (defined(_X86_) || defined(_M_IX86)) && _MSC_VER >= 1300
|
||||
# define _W64 __w64
|
||||
# else
|
||||
# define _W64
|
||||
# endif
|
||||
#if !defined(__midl) && (defined(_X86_) || defined(_M_IX86)) && _MSC_VER >= 1300
|
||||
#define _W64 __w64
|
||||
#else
|
||||
#define _W64
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
// 7.18.1 Integer types
|
||||
|
||||
@@ -117,168 +117,166 @@ extern "C" {
|
||||
// Visual Studio 6 and Embedded Visual C++ 4 doesn't
|
||||
// realize that, e.g. char has the same size as __int8
|
||||
// so we give up on __intX for them.
|
||||
#if (_MSC_VER < 1300)
|
||||
typedef signed char int8_t;
|
||||
typedef signed short int16_t;
|
||||
typedef signed int int32_t;
|
||||
typedef unsigned char uint8_t;
|
||||
typedef unsigned short uint16_t;
|
||||
typedef unsigned int uint32_t;
|
||||
#if(_MSC_VER < 1300)
|
||||
typedef signed char int8_t;
|
||||
typedef signed short int16_t;
|
||||
typedef signed int int32_t;
|
||||
typedef unsigned char uint8_t;
|
||||
typedef unsigned short uint16_t;
|
||||
typedef unsigned int uint32_t;
|
||||
#else
|
||||
typedef signed __int8 int8_t;
|
||||
typedef signed __int16 int16_t;
|
||||
typedef signed __int32 int32_t;
|
||||
typedef unsigned __int8 uint8_t;
|
||||
typedef unsigned __int16 uint16_t;
|
||||
typedef unsigned __int32 uint32_t;
|
||||
typedef signed __int8 int8_t;
|
||||
typedef signed __int16 int16_t;
|
||||
typedef signed __int32 int32_t;
|
||||
typedef unsigned __int8 uint8_t;
|
||||
typedef unsigned __int16 uint16_t;
|
||||
typedef unsigned __int32 uint32_t;
|
||||
#endif
|
||||
typedef signed __int64 int64_t;
|
||||
typedef unsigned __int64 uint64_t;
|
||||
|
||||
typedef signed __int64 int64_t;
|
||||
typedef unsigned __int64 uint64_t;
|
||||
|
||||
// 7.18.1.2 Minimum-width integer types
|
||||
typedef int8_t int_least8_t;
|
||||
typedef int16_t int_least16_t;
|
||||
typedef int32_t int_least32_t;
|
||||
typedef int64_t int_least64_t;
|
||||
typedef uint8_t uint_least8_t;
|
||||
typedef uint16_t uint_least16_t;
|
||||
typedef uint32_t uint_least32_t;
|
||||
typedef uint64_t uint_least64_t;
|
||||
typedef int8_t int_least8_t;
|
||||
typedef int16_t int_least16_t;
|
||||
typedef int32_t int_least32_t;
|
||||
typedef int64_t int_least64_t;
|
||||
typedef uint8_t uint_least8_t;
|
||||
typedef uint16_t uint_least16_t;
|
||||
typedef uint32_t uint_least32_t;
|
||||
typedef uint64_t uint_least64_t;
|
||||
|
||||
// 7.18.1.3 Fastest minimum-width integer types
|
||||
typedef int8_t int_fast8_t;
|
||||
typedef int16_t int_fast16_t;
|
||||
typedef int32_t int_fast32_t;
|
||||
typedef int64_t int_fast64_t;
|
||||
typedef uint8_t uint_fast8_t;
|
||||
typedef uint16_t uint_fast16_t;
|
||||
typedef uint32_t uint_fast32_t;
|
||||
typedef uint64_t uint_fast64_t;
|
||||
typedef int8_t int_fast8_t;
|
||||
typedef int16_t int_fast16_t;
|
||||
typedef int32_t int_fast32_t;
|
||||
typedef int64_t int_fast64_t;
|
||||
typedef uint8_t uint_fast8_t;
|
||||
typedef uint16_t uint_fast16_t;
|
||||
typedef uint32_t uint_fast32_t;
|
||||
typedef uint64_t uint_fast64_t;
|
||||
|
||||
// 7.18.1.4 Integer types capable of holding object pointers
|
||||
#ifdef _WIN64 // [
|
||||
typedef signed __int64 intptr_t;
|
||||
typedef unsigned __int64 uintptr_t;
|
||||
#else // _WIN64 ][
|
||||
typedef _W64 signed int intptr_t;
|
||||
typedef _W64 unsigned int uintptr_t;
|
||||
#endif // _WIN64 ]
|
||||
typedef signed __int64 intptr_t;
|
||||
typedef unsigned __int64 uintptr_t;
|
||||
#else // _WIN64 ][
|
||||
typedef _W64 signed int intptr_t;
|
||||
typedef _W64 unsigned int uintptr_t;
|
||||
#endif // _WIN64 ]
|
||||
|
||||
// 7.18.1.5 Greatest-width integer types
|
||||
typedef int64_t intmax_t;
|
||||
typedef uint64_t uintmax_t;
|
||||
|
||||
typedef int64_t intmax_t;
|
||||
typedef uint64_t uintmax_t;
|
||||
|
||||
// 7.18.2 Limits of specified-width integer types
|
||||
|
||||
#if !defined(__cplusplus) || defined(__STDC_LIMIT_MACROS) // [ See footnote 220 at page 257 and footnote 221 at page 259
|
||||
#if !defined(__cplusplus) || \
|
||||
defined(__STDC_LIMIT_MACROS) // [ See footnote 220 at page 257 and footnote 221 at page 259
|
||||
|
||||
// 7.18.2.1 Limits of exact-width integer types
|
||||
#define INT8_MIN ((int8_t)_I8_MIN)
|
||||
#define INT8_MAX _I8_MAX
|
||||
#define INT16_MIN ((int16_t)_I16_MIN)
|
||||
#define INT16_MAX _I16_MAX
|
||||
#define INT32_MIN ((int32_t)_I32_MIN)
|
||||
#define INT32_MAX _I32_MAX
|
||||
#define INT64_MIN ((int64_t)_I64_MIN)
|
||||
#define INT64_MAX _I64_MAX
|
||||
#define UINT8_MAX _UI8_MAX
|
||||
#define UINT16_MAX _UI16_MAX
|
||||
#define UINT32_MAX _UI32_MAX
|
||||
#define UINT64_MAX _UI64_MAX
|
||||
#define INT8_MIN ((int8_t)_I8_MIN)
|
||||
#define INT8_MAX _I8_MAX
|
||||
#define INT16_MIN ((int16_t)_I16_MIN)
|
||||
#define INT16_MAX _I16_MAX
|
||||
#define INT32_MIN ((int32_t)_I32_MIN)
|
||||
#define INT32_MAX _I32_MAX
|
||||
#define INT64_MIN ((int64_t)_I64_MIN)
|
||||
#define INT64_MAX _I64_MAX
|
||||
#define UINT8_MAX _UI8_MAX
|
||||
#define UINT16_MAX _UI16_MAX
|
||||
#define UINT32_MAX _UI32_MAX
|
||||
#define UINT64_MAX _UI64_MAX
|
||||
|
||||
// 7.18.2.2 Limits of minimum-width integer types
|
||||
#define INT_LEAST8_MIN INT8_MIN
|
||||
#define INT_LEAST8_MAX INT8_MAX
|
||||
#define INT_LEAST16_MIN INT16_MIN
|
||||
#define INT_LEAST16_MAX INT16_MAX
|
||||
#define INT_LEAST32_MIN INT32_MIN
|
||||
#define INT_LEAST32_MAX INT32_MAX
|
||||
#define INT_LEAST64_MIN INT64_MIN
|
||||
#define INT_LEAST64_MAX INT64_MAX
|
||||
#define UINT_LEAST8_MAX UINT8_MAX
|
||||
#define UINT_LEAST16_MAX UINT16_MAX
|
||||
#define UINT_LEAST32_MAX UINT32_MAX
|
||||
#define UINT_LEAST64_MAX UINT64_MAX
|
||||
#define INT_LEAST8_MIN INT8_MIN
|
||||
#define INT_LEAST8_MAX INT8_MAX
|
||||
#define INT_LEAST16_MIN INT16_MIN
|
||||
#define INT_LEAST16_MAX INT16_MAX
|
||||
#define INT_LEAST32_MIN INT32_MIN
|
||||
#define INT_LEAST32_MAX INT32_MAX
|
||||
#define INT_LEAST64_MIN INT64_MIN
|
||||
#define INT_LEAST64_MAX INT64_MAX
|
||||
#define UINT_LEAST8_MAX UINT8_MAX
|
||||
#define UINT_LEAST16_MAX UINT16_MAX
|
||||
#define UINT_LEAST32_MAX UINT32_MAX
|
||||
#define UINT_LEAST64_MAX UINT64_MAX
|
||||
|
||||
// 7.18.2.3 Limits of fastest minimum-width integer types
|
||||
#define INT_FAST8_MIN INT8_MIN
|
||||
#define INT_FAST8_MAX INT8_MAX
|
||||
#define INT_FAST16_MIN INT16_MIN
|
||||
#define INT_FAST16_MAX INT16_MAX
|
||||
#define INT_FAST32_MIN INT32_MIN
|
||||
#define INT_FAST32_MAX INT32_MAX
|
||||
#define INT_FAST64_MIN INT64_MIN
|
||||
#define INT_FAST64_MAX INT64_MAX
|
||||
#define UINT_FAST8_MAX UINT8_MAX
|
||||
#define UINT_FAST16_MAX UINT16_MAX
|
||||
#define UINT_FAST32_MAX UINT32_MAX
|
||||
#define UINT_FAST64_MAX UINT64_MAX
|
||||
#define INT_FAST8_MIN INT8_MIN
|
||||
#define INT_FAST8_MAX INT8_MAX
|
||||
#define INT_FAST16_MIN INT16_MIN
|
||||
#define INT_FAST16_MAX INT16_MAX
|
||||
#define INT_FAST32_MIN INT32_MIN
|
||||
#define INT_FAST32_MAX INT32_MAX
|
||||
#define INT_FAST64_MIN INT64_MIN
|
||||
#define INT_FAST64_MAX INT64_MAX
|
||||
#define UINT_FAST8_MAX UINT8_MAX
|
||||
#define UINT_FAST16_MAX UINT16_MAX
|
||||
#define UINT_FAST32_MAX UINT32_MAX
|
||||
#define UINT_FAST64_MAX UINT64_MAX
|
||||
|
||||
// 7.18.2.4 Limits of integer types capable of holding object pointers
|
||||
#ifdef _WIN64 // [
|
||||
# define INTPTR_MIN INT64_MIN
|
||||
# define INTPTR_MAX INT64_MAX
|
||||
# define UINTPTR_MAX UINT64_MAX
|
||||
#define INTPTR_MIN INT64_MIN
|
||||
#define INTPTR_MAX INT64_MAX
|
||||
#define UINTPTR_MAX UINT64_MAX
|
||||
#else // _WIN64 ][
|
||||
# define INTPTR_MIN INT32_MIN
|
||||
# define INTPTR_MAX INT32_MAX
|
||||
# define UINTPTR_MAX UINT32_MAX
|
||||
#define INTPTR_MIN INT32_MIN
|
||||
#define INTPTR_MAX INT32_MAX
|
||||
#define UINTPTR_MAX UINT32_MAX
|
||||
#endif // _WIN64 ]
|
||||
|
||||
// 7.18.2.5 Limits of greatest-width integer types
|
||||
#define INTMAX_MIN INT64_MIN
|
||||
#define INTMAX_MAX INT64_MAX
|
||||
#define UINTMAX_MAX UINT64_MAX
|
||||
#define INTMAX_MIN INT64_MIN
|
||||
#define INTMAX_MAX INT64_MAX
|
||||
#define UINTMAX_MAX UINT64_MAX
|
||||
|
||||
// 7.18.3 Limits of other integer types
|
||||
|
||||
#ifdef _WIN64 // [
|
||||
# define PTRDIFF_MIN _I64_MIN
|
||||
# define PTRDIFF_MAX _I64_MAX
|
||||
#else // _WIN64 ][
|
||||
# define PTRDIFF_MIN _I32_MIN
|
||||
# define PTRDIFF_MAX _I32_MAX
|
||||
#endif // _WIN64 ]
|
||||
#define PTRDIFF_MIN _I64_MIN
|
||||
#define PTRDIFF_MAX _I64_MAX
|
||||
#else // _WIN64 ][
|
||||
#define PTRDIFF_MIN _I32_MIN
|
||||
#define PTRDIFF_MAX _I32_MAX
|
||||
#endif // _WIN64 ]
|
||||
|
||||
#define SIG_ATOMIC_MIN INT_MIN
|
||||
#define SIG_ATOMIC_MAX INT_MAX
|
||||
#define SIG_ATOMIC_MIN INT_MIN
|
||||
#define SIG_ATOMIC_MAX INT_MAX
|
||||
|
||||
#ifndef SIZE_MAX // [
|
||||
# ifdef _WIN64 // [
|
||||
# define SIZE_MAX _UI64_MAX
|
||||
# else // _WIN64 ][
|
||||
# define SIZE_MAX _UI32_MAX
|
||||
# endif // _WIN64 ]
|
||||
#endif // SIZE_MAX ]
|
||||
#ifdef _WIN64 // [
|
||||
#define SIZE_MAX _UI64_MAX
|
||||
#else // _WIN64 ][
|
||||
#define SIZE_MAX _UI32_MAX
|
||||
#endif // _WIN64 ]
|
||||
#endif // SIZE_MAX ]
|
||||
|
||||
// WCHAR_MIN and WCHAR_MAX are also defined in <wchar.h>
|
||||
#ifndef WCHAR_MIN // [
|
||||
# define WCHAR_MIN 0
|
||||
#endif // WCHAR_MIN ]
|
||||
#define WCHAR_MIN 0
|
||||
#endif // WCHAR_MIN ]
|
||||
#ifndef WCHAR_MAX // [
|
||||
# define WCHAR_MAX _UI16_MAX
|
||||
#endif // WCHAR_MAX ]
|
||||
#define WCHAR_MAX _UI16_MAX
|
||||
#endif // WCHAR_MAX ]
|
||||
|
||||
#define WINT_MIN 0
|
||||
#define WINT_MAX _UI16_MAX
|
||||
#define WINT_MIN 0
|
||||
#define WINT_MAX _UI16_MAX
|
||||
|
||||
#endif // __STDC_LIMIT_MACROS ]
|
||||
|
||||
|
||||
// 7.18.4 Limits of other integer types
|
||||
|
||||
#if !defined(__cplusplus) || defined(__STDC_CONSTANT_MACROS) // [ See footnote 224 at page 260
|
||||
|
||||
// 7.18.4.1 Macros for minimum-width integer constants
|
||||
|
||||
#define INT8_C(val) val##i8
|
||||
#define INT8_C(val) val##i8
|
||||
#define INT16_C(val) val##i16
|
||||
#define INT32_C(val) val##i32
|
||||
#define INT64_C(val) val##i64
|
||||
|
||||
#define UINT8_C(val) val##ui8
|
||||
#define UINT8_C(val) val##ui8
|
||||
#define UINT16_C(val) val##ui16
|
||||
#define UINT32_C(val) val##ui32
|
||||
#define UINT64_C(val) val##ui64
|
||||
@@ -287,10 +285,10 @@ typedef uint64_t uintmax_t;
|
||||
// These #ifndef's are needed to prevent collisions with <boost/cstdint.hpp>.
|
||||
// Check out Issue 9 for the details.
|
||||
#ifndef INTMAX_C // [
|
||||
# define INTMAX_C INT64_C
|
||||
#endif // INTMAX_C ]
|
||||
#define INTMAX_C INT64_C
|
||||
#endif // INTMAX_C ]
|
||||
#ifndef UINTMAX_C // [
|
||||
# define UINTMAX_C UINT64_C
|
||||
#define UINTMAX_C UINT64_C
|
||||
#endif // UINTMAX_C ]
|
||||
|
||||
#endif // __STDC_CONSTANT_MACROS ]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_OSTREAMWRAPPER_H_
|
||||
@@ -40,29 +40,46 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
|
||||
\tparam StreamType Class derived from \c std::basic_ostream.
|
||||
*/
|
||||
|
||||
|
||||
template <typename StreamType>
|
||||
class BasicOStreamWrapper {
|
||||
public:
|
||||
class BasicOStreamWrapper
|
||||
{
|
||||
public:
|
||||
typedef typename StreamType::char_type Ch;
|
||||
BasicOStreamWrapper(StreamType& stream) : stream_(stream) {}
|
||||
|
||||
void Put(Ch c) {
|
||||
stream_.put(c);
|
||||
}
|
||||
void Put(Ch c) { stream_.put(c); }
|
||||
|
||||
void Flush() {
|
||||
stream_.flush();
|
||||
}
|
||||
void Flush() { stream_.flush(); }
|
||||
|
||||
// Not implemented
|
||||
char Peek() const { RAPIDJSON_ASSERT(false); return 0; }
|
||||
char Take() { RAPIDJSON_ASSERT(false); return 0; }
|
||||
size_t Tell() const { RAPIDJSON_ASSERT(false); return 0; }
|
||||
char* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
|
||||
size_t PutEnd(char*) { RAPIDJSON_ASSERT(false); return 0; }
|
||||
char Peek() const
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
char Take()
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
size_t Tell() const
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
char* PutBegin()
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
size_t PutEnd(char*)
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
BasicOStreamWrapper(const BasicOStreamWrapper&);
|
||||
BasicOStreamWrapper& operator=(const BasicOStreamWrapper&);
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_PRETTYWRITER_H_
|
||||
@@ -24,7 +24,7 @@ RAPIDJSON_DIAG_OFF(effc++)
|
||||
|
||||
#if defined(__clang__)
|
||||
RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(c++98-compat)
|
||||
RAPIDJSON_DIAG_OFF(c++ 98 - compat)
|
||||
#endif
|
||||
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
@@ -32,8 +32,9 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
//! Combination of PrettyWriter format flags.
|
||||
/*! \see PrettyWriter::SetFormatOptions
|
||||
*/
|
||||
enum PrettyFormatOptions {
|
||||
kFormatDefault = 0, //!< Default pretty formatting.
|
||||
enum PrettyFormatOptions
|
||||
{
|
||||
kFormatDefault = 0, //!< Default pretty formatting.
|
||||
kFormatSingleLineArray = 1 //!< Format arrays on a single line.
|
||||
};
|
||||
|
||||
@@ -44,9 +45,15 @@ enum PrettyFormatOptions {
|
||||
\tparam TargetEncoding Encoding of output stream.
|
||||
\tparam StackAllocator Type of allocator for allocating memory of stack.
|
||||
*/
|
||||
template<typename OutputStream, typename SourceEncoding = UTF8<>, typename TargetEncoding = UTF8<>, typename StackAllocator = CrtAllocator, unsigned writeFlags = kWriteDefaultFlags>
|
||||
class PrettyWriter : public Writer<OutputStream, SourceEncoding, TargetEncoding, StackAllocator, writeFlags> {
|
||||
public:
|
||||
template <typename OutputStream,
|
||||
typename SourceEncoding = UTF8<>,
|
||||
typename TargetEncoding = UTF8<>,
|
||||
typename StackAllocator = CrtAllocator,
|
||||
unsigned writeFlags = kWriteDefaultFlags>
|
||||
class PrettyWriter
|
||||
: public Writer<OutputStream, SourceEncoding, TargetEncoding, StackAllocator, writeFlags>
|
||||
{
|
||||
public:
|
||||
typedef Writer<OutputStream, SourceEncoding, TargetEncoding, StackAllocator, writeFlags> Base;
|
||||
typedef typename Base::Ch Ch;
|
||||
|
||||
@@ -55,34 +62,54 @@ public:
|
||||
\param allocator User supplied allocator. If it is null, it will create a private one.
|
||||
\param levelDepth Initial capacity of stack.
|
||||
*/
|
||||
explicit PrettyWriter(OutputStream& os, StackAllocator* allocator = 0, size_t levelDepth = Base::kDefaultLevelDepth) :
|
||||
Base(os, allocator, levelDepth), indentChar_(' '), indentCharCount_(4), formatOptions_(kFormatDefault) {}
|
||||
explicit PrettyWriter(OutputStream& os,
|
||||
StackAllocator* allocator = 0,
|
||||
size_t levelDepth = Base::kDefaultLevelDepth)
|
||||
: Base(os, allocator, levelDepth),
|
||||
indentChar_(' '),
|
||||
indentCharCount_(4),
|
||||
formatOptions_(kFormatDefault)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
explicit PrettyWriter(StackAllocator* allocator = 0, size_t levelDepth = Base::kDefaultLevelDepth) :
|
||||
Base(allocator, levelDepth), indentChar_(' '), indentCharCount_(4), formatOptions_(kFormatDefault) {}
|
||||
explicit PrettyWriter(StackAllocator* allocator = 0,
|
||||
size_t levelDepth = Base::kDefaultLevelDepth)
|
||||
: Base(allocator, levelDepth),
|
||||
indentChar_(' '),
|
||||
indentCharCount_(4),
|
||||
formatOptions_(kFormatDefault)
|
||||
{
|
||||
}
|
||||
|
||||
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
|
||||
PrettyWriter(PrettyWriter&& rhs) :
|
||||
Base(std::forward<PrettyWriter>(rhs)), indentChar_(rhs.indentChar_), indentCharCount_(rhs.indentCharCount_), formatOptions_(rhs.formatOptions_) {}
|
||||
PrettyWriter(PrettyWriter&& rhs)
|
||||
: Base(std::forward<PrettyWriter>(rhs)),
|
||||
indentChar_(rhs.indentChar_),
|
||||
indentCharCount_(rhs.indentCharCount_),
|
||||
formatOptions_(rhs.formatOptions_)
|
||||
{
|
||||
}
|
||||
#endif
|
||||
|
||||
//! Set custom indentation.
|
||||
/*! \param indentChar Character for indentation. Must be whitespace character (' ', '\\t', '\\n', '\\r').
|
||||
\param indentCharCount Number of indent characters for each indentation level.
|
||||
\note The default indentation is 4 spaces.
|
||||
/*! \param indentChar Character for indentation. Must be whitespace character (' ', '\\t',
|
||||
'\\n', '\\r'). \param indentCharCount Number of indent characters for each indentation
|
||||
level. \note The default indentation is 4 spaces.
|
||||
*/
|
||||
PrettyWriter& SetIndent(Ch indentChar, unsigned indentCharCount) {
|
||||
RAPIDJSON_ASSERT(indentChar == ' ' || indentChar == '\t' || indentChar == '\n' || indentChar == '\r');
|
||||
indentChar_ = indentChar;
|
||||
PrettyWriter& SetIndent(Ch indentChar, unsigned indentCharCount)
|
||||
{
|
||||
RAPIDJSON_ASSERT(indentChar == ' ' || indentChar == '\t' || indentChar == '\n' ||
|
||||
indentChar == '\r');
|
||||
indentChar_ = indentChar;
|
||||
indentCharCount_ = indentCharCount;
|
||||
return *this;
|
||||
}
|
||||
|
||||
//! Set pretty writer formatting options.
|
||||
/*! \param options Formatting options.
|
||||
*/
|
||||
PrettyWriter& SetFormatOptions(PrettyFormatOptions options) {
|
||||
*/
|
||||
PrettyWriter& SetFormatOptions(PrettyFormatOptions options)
|
||||
{
|
||||
formatOptions_ = options;
|
||||
return *this;
|
||||
}
|
||||
@@ -92,22 +119,52 @@ public:
|
||||
*/
|
||||
//@{
|
||||
|
||||
bool Null() { PrettyPrefix(kNullType); return Base::EndValue(Base::WriteNull()); }
|
||||
bool Bool(bool b) { PrettyPrefix(b ? kTrueType : kFalseType); return Base::EndValue(Base::WriteBool(b)); }
|
||||
bool Int(int i) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteInt(i)); }
|
||||
bool Uint(unsigned u) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteUint(u)); }
|
||||
bool Int64(int64_t i64) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteInt64(i64)); }
|
||||
bool Uint64(uint64_t u64) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteUint64(u64)); }
|
||||
bool Double(double d) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteDouble(d)); }
|
||||
bool Null()
|
||||
{
|
||||
PrettyPrefix(kNullType);
|
||||
return Base::EndValue(Base::WriteNull());
|
||||
}
|
||||
bool Bool(bool b)
|
||||
{
|
||||
PrettyPrefix(b ? kTrueType : kFalseType);
|
||||
return Base::EndValue(Base::WriteBool(b));
|
||||
}
|
||||
bool Int(int i)
|
||||
{
|
||||
PrettyPrefix(kNumberType);
|
||||
return Base::EndValue(Base::WriteInt(i));
|
||||
}
|
||||
bool Uint(unsigned u)
|
||||
{
|
||||
PrettyPrefix(kNumberType);
|
||||
return Base::EndValue(Base::WriteUint(u));
|
||||
}
|
||||
bool Int64(int64_t i64)
|
||||
{
|
||||
PrettyPrefix(kNumberType);
|
||||
return Base::EndValue(Base::WriteInt64(i64));
|
||||
}
|
||||
bool Uint64(uint64_t u64)
|
||||
{
|
||||
PrettyPrefix(kNumberType);
|
||||
return Base::EndValue(Base::WriteUint64(u64));
|
||||
}
|
||||
bool Double(double d)
|
||||
{
|
||||
PrettyPrefix(kNumberType);
|
||||
return Base::EndValue(Base::WriteDouble(d));
|
||||
}
|
||||
|
||||
bool RawNumber(const Ch* str, SizeType length, bool copy = false) {
|
||||
bool RawNumber(const Ch* str, SizeType length, bool copy = false)
|
||||
{
|
||||
RAPIDJSON_ASSERT(str != 0);
|
||||
(void)copy;
|
||||
PrettyPrefix(kNumberType);
|
||||
return Base::EndValue(Base::WriteString(str, length));
|
||||
}
|
||||
|
||||
bool String(const Ch* str, SizeType length, bool copy = false) {
|
||||
bool String(const Ch* str, SizeType length, bool copy = false)
|
||||
{
|
||||
RAPIDJSON_ASSERT(str != 0);
|
||||
(void)copy;
|
||||
PrettyPrefix(kStringType);
|
||||
@@ -115,65 +172,76 @@ public:
|
||||
}
|
||||
|
||||
#if RAPIDJSON_HAS_STDSTRING
|
||||
bool String(const std::basic_string<Ch>& str) {
|
||||
bool String(const std::basic_string<Ch>& str)
|
||||
{
|
||||
return String(str.data(), SizeType(str.size()));
|
||||
}
|
||||
#endif
|
||||
|
||||
bool StartObject() {
|
||||
bool StartObject()
|
||||
{
|
||||
PrettyPrefix(kObjectType);
|
||||
new (Base::level_stack_.template Push<typename Base::Level>()) typename Base::Level(false);
|
||||
new(Base::level_stack_.template Push<typename Base::Level>()) typename Base::Level(false);
|
||||
return Base::WriteStartObject();
|
||||
}
|
||||
|
||||
bool Key(const Ch* str, SizeType length, bool copy = false) { return String(str, length, copy); }
|
||||
bool Key(const Ch* str, SizeType length, bool copy = false)
|
||||
{
|
||||
return String(str, length, copy);
|
||||
}
|
||||
|
||||
#if RAPIDJSON_HAS_STDSTRING
|
||||
bool Key(const std::basic_string<Ch>& str) {
|
||||
return Key(str.data(), SizeType(str.size()));
|
||||
}
|
||||
bool Key(const std::basic_string<Ch>& str) { return Key(str.data(), SizeType(str.size())); }
|
||||
#endif
|
||||
|
||||
bool EndObject(SizeType memberCount = 0) {
|
||||
|
||||
bool EndObject(SizeType memberCount = 0)
|
||||
{
|
||||
(void)memberCount;
|
||||
RAPIDJSON_ASSERT(Base::level_stack_.GetSize() >= sizeof(typename Base::Level)); // not inside an Object
|
||||
RAPIDJSON_ASSERT(!Base::level_stack_.template Top<typename Base::Level>()->inArray); // currently inside an Array, not Object
|
||||
RAPIDJSON_ASSERT(0 == Base::level_stack_.template Top<typename Base::Level>()->valueCount % 2); // Object has a Key without a Value
|
||||
|
||||
RAPIDJSON_ASSERT(Base::level_stack_.GetSize() >=
|
||||
sizeof(typename Base::Level)); // not inside an Object
|
||||
RAPIDJSON_ASSERT(!Base::level_stack_.template Top<typename Base::Level>()
|
||||
->inArray); // currently inside an Array, not Object
|
||||
RAPIDJSON_ASSERT(0 == Base::level_stack_.template Top<typename Base::Level>()->valueCount %
|
||||
2); // Object has a Key without a Value
|
||||
|
||||
bool empty = Base::level_stack_.template Pop<typename Base::Level>(1)->valueCount == 0;
|
||||
|
||||
if (!empty) {
|
||||
if(!empty)
|
||||
{
|
||||
Base::os_->Put('\n');
|
||||
WriteIndent();
|
||||
}
|
||||
bool ret = Base::EndValue(Base::WriteEndObject());
|
||||
(void)ret;
|
||||
RAPIDJSON_ASSERT(ret == true);
|
||||
if (Base::level_stack_.Empty()) // end of json text
|
||||
if(Base::level_stack_.Empty()) // end of json text
|
||||
Base::Flush();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StartArray() {
|
||||
bool StartArray()
|
||||
{
|
||||
PrettyPrefix(kArrayType);
|
||||
new (Base::level_stack_.template Push<typename Base::Level>()) typename Base::Level(true);
|
||||
new(Base::level_stack_.template Push<typename Base::Level>()) typename Base::Level(true);
|
||||
return Base::WriteStartArray();
|
||||
}
|
||||
|
||||
bool EndArray(SizeType memberCount = 0) {
|
||||
bool EndArray(SizeType memberCount = 0)
|
||||
{
|
||||
(void)memberCount;
|
||||
RAPIDJSON_ASSERT(Base::level_stack_.GetSize() >= sizeof(typename Base::Level));
|
||||
RAPIDJSON_ASSERT(Base::level_stack_.template Top<typename Base::Level>()->inArray);
|
||||
bool empty = Base::level_stack_.template Pop<typename Base::Level>(1)->valueCount == 0;
|
||||
|
||||
if (!empty && !(formatOptions_ & kFormatSingleLineArray)) {
|
||||
if(!empty && !(formatOptions_ & kFormatSingleLineArray))
|
||||
{
|
||||
Base::os_->Put('\n');
|
||||
WriteIndent();
|
||||
}
|
||||
bool ret = Base::EndValue(Base::WriteEndArray());
|
||||
(void)ret;
|
||||
RAPIDJSON_ASSERT(ret == true);
|
||||
if (Base::level_stack_.Empty()) // end of json text
|
||||
if(Base::level_stack_.Empty()) // end of json text
|
||||
Base::Flush();
|
||||
return true;
|
||||
}
|
||||
@@ -193,42 +261,51 @@ public:
|
||||
/*!
|
||||
For user to write a stringified JSON as a value.
|
||||
|
||||
\param json A well-formed JSON value. It should not contain null character within [0, length - 1] range.
|
||||
\param length Length of the json.
|
||||
\param type Type of the root of json.
|
||||
\note When using PrettyWriter::RawValue(), the result json may not be indented correctly.
|
||||
\param json A well-formed JSON value. It should not contain null character within [0, length
|
||||
- 1] range. \param length Length of the json. \param type Type of the root of json. \note
|
||||
When using PrettyWriter::RawValue(), the result json may not be indented correctly.
|
||||
*/
|
||||
bool RawValue(const Ch* json, size_t length, Type type) {
|
||||
bool RawValue(const Ch* json, size_t length, Type type)
|
||||
{
|
||||
RAPIDJSON_ASSERT(json != 0);
|
||||
PrettyPrefix(type);
|
||||
return Base::EndValue(Base::WriteRawValue(json, length));
|
||||
}
|
||||
|
||||
protected:
|
||||
void PrettyPrefix(Type type) {
|
||||
protected:
|
||||
void PrettyPrefix(Type type)
|
||||
{
|
||||
(void)type;
|
||||
if (Base::level_stack_.GetSize() != 0) { // this value is not at root
|
||||
if(Base::level_stack_.GetSize() != 0)
|
||||
{ // this value is not at root
|
||||
typename Base::Level* level = Base::level_stack_.template Top<typename Base::Level>();
|
||||
|
||||
if (level->inArray) {
|
||||
if (level->valueCount > 0) {
|
||||
if(level->inArray)
|
||||
{
|
||||
if(level->valueCount > 0)
|
||||
{
|
||||
Base::os_->Put(','); // add comma if it is not the first element in array
|
||||
if (formatOptions_ & kFormatSingleLineArray)
|
||||
if(formatOptions_ & kFormatSingleLineArray)
|
||||
Base::os_->Put(' ');
|
||||
}
|
||||
|
||||
if (!(formatOptions_ & kFormatSingleLineArray)) {
|
||||
if(!(formatOptions_ & kFormatSingleLineArray))
|
||||
{
|
||||
Base::os_->Put('\n');
|
||||
WriteIndent();
|
||||
}
|
||||
}
|
||||
else { // in object
|
||||
if (level->valueCount > 0) {
|
||||
if (level->valueCount % 2 == 0) {
|
||||
else
|
||||
{ // in object
|
||||
if(level->valueCount > 0)
|
||||
{
|
||||
if(level->valueCount % 2 == 0)
|
||||
{
|
||||
Base::os_->Put(',');
|
||||
Base::os_->Put('\n');
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
Base::os_->Put(':');
|
||||
Base::os_->Put(' ');
|
||||
}
|
||||
@@ -236,21 +313,25 @@ protected:
|
||||
else
|
||||
Base::os_->Put('\n');
|
||||
|
||||
if (level->valueCount % 2 == 0)
|
||||
if(level->valueCount % 2 == 0)
|
||||
WriteIndent();
|
||||
}
|
||||
if (!level->inArray && level->valueCount % 2 == 0)
|
||||
RAPIDJSON_ASSERT(type == kStringType); // if it's in object, then even number should be a name
|
||||
if(!level->inArray && level->valueCount % 2 == 0)
|
||||
RAPIDJSON_ASSERT(
|
||||
type == kStringType); // if it's in object, then even number should be a name
|
||||
level->valueCount++;
|
||||
}
|
||||
else {
|
||||
RAPIDJSON_ASSERT(!Base::hasRoot_); // Should only has one and only one root.
|
||||
else
|
||||
{
|
||||
RAPIDJSON_ASSERT(!Base::hasRoot_); // Should only has one and only one root.
|
||||
Base::hasRoot_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void WriteIndent() {
|
||||
size_t count = (Base::level_stack_.GetSize() / sizeof(typename Base::Level)) * indentCharCount_;
|
||||
void WriteIndent()
|
||||
{
|
||||
size_t count =
|
||||
(Base::level_stack_.GetSize() / sizeof(typename Base::Level)) * indentCharCount_;
|
||||
PutN(*Base::os_, static_cast<typename OutputStream::Ch>(indentChar_), count);
|
||||
}
|
||||
|
||||
@@ -258,7 +339,7 @@ protected:
|
||||
unsigned indentCharCount_;
|
||||
PrettyFormatOptions formatOptions_;
|
||||
|
||||
private:
|
||||
private:
|
||||
// Prohibit copy constructor & assignment operator.
|
||||
PrettyWriter(const PrettyWriter&);
|
||||
PrettyWriter& operator=(const PrettyWriter&);
|
||||
|
||||
@@ -36,8 +36,8 @@
|
||||
different translation units of a single application.
|
||||
*/
|
||||
|
||||
#include <cstdlib> // malloc(), realloc(), free(), size_t
|
||||
#include <cstring> // memset(), memcpy(), memmove(), memcmp()
|
||||
#include <cstdlib> // malloc(), realloc(), free(), size_t
|
||||
#include <cstring> // memset(), memcpy(), memmove(), memcmp()
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// RAPIDJSON_VERSION_STRING
|
||||
@@ -226,8 +226,8 @@
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// RAPIDJSON_ENDIAN
|
||||
#define RAPIDJSON_LITTLEENDIAN 0 //!< Little endian machine
|
||||
#define RAPIDJSON_BIGENDIAN 1 //!< Big endian machine
|
||||
#define RAPIDJSON_LITTLEENDIAN 0 //!< Little endian machine
|
||||
#define RAPIDJSON_BIGENDIAN 1 //!< Big endian machine
|
||||
|
||||
//! Endianness of the machine.
|
||||
/*!
|
||||
@@ -244,41 +244,46 @@
|
||||
*/
|
||||
#ifndef RAPIDJSON_ENDIAN
|
||||
// Detect with GCC 4.6's macro
|
||||
# ifdef __BYTE_ORDER__
|
||||
# if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
|
||||
# elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
||||
# define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
|
||||
# else
|
||||
# error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN.
|
||||
# endif // __BYTE_ORDER__
|
||||
#ifdef __BYTE_ORDER__
|
||||
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
#define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
|
||||
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
||||
#define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
|
||||
#else
|
||||
#error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN.
|
||||
#endif // __BYTE_ORDER__
|
||||
// Detect with GLIBC's endian.h
|
||||
# elif defined(__GLIBC__)
|
||||
# include <endian.h>
|
||||
# if (__BYTE_ORDER == __LITTLE_ENDIAN)
|
||||
# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
|
||||
# elif (__BYTE_ORDER == __BIG_ENDIAN)
|
||||
# define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
|
||||
# else
|
||||
# error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN.
|
||||
# endif // __GLIBC__
|
||||
#elif defined(__GLIBC__)
|
||||
#include <endian.h>
|
||||
#if(__BYTE_ORDER == __LITTLE_ENDIAN)
|
||||
#define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
|
||||
#elif(__BYTE_ORDER == __BIG_ENDIAN)
|
||||
#define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
|
||||
#else
|
||||
#error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN.
|
||||
#endif // __GLIBC__
|
||||
// Detect with _LITTLE_ENDIAN and _BIG_ENDIAN macro
|
||||
# elif defined(_LITTLE_ENDIAN) && !defined(_BIG_ENDIAN)
|
||||
# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
|
||||
# elif defined(_BIG_ENDIAN) && !defined(_LITTLE_ENDIAN)
|
||||
# define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
|
||||
#elif defined(_LITTLE_ENDIAN) && !defined(_BIG_ENDIAN)
|
||||
#define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
|
||||
#elif defined(_BIG_ENDIAN) && !defined(_LITTLE_ENDIAN)
|
||||
#define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
|
||||
// Detect with architecture macros
|
||||
# elif defined(__sparc) || defined(__sparc__) || defined(_POWER) || defined(__powerpc__) || defined(__ppc__) || defined(__ppc64__) || defined(__hpux) || defined(__hppa) || defined(_MIPSEB) || defined(_POWER) || defined(__s390__)
|
||||
# define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
|
||||
# elif defined(__i386__) || defined(__alpha__) || defined(__ia64) || defined(__ia64__) || defined(_M_IX86) || defined(_M_IA64) || defined(_M_ALPHA) || defined(__amd64) || defined(__amd64__) || defined(_M_AMD64) || defined(__x86_64) || defined(__x86_64__) || defined(_M_X64) || defined(__bfin__)
|
||||
# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
|
||||
# elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64))
|
||||
# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
|
||||
# elif defined(RAPIDJSON_DOXYGEN_RUNNING)
|
||||
# define RAPIDJSON_ENDIAN
|
||||
# else
|
||||
# error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN.
|
||||
# endif
|
||||
#elif defined(__sparc) || defined(__sparc__) || defined(_POWER) || defined(__powerpc__) || \
|
||||
defined(__ppc__) || defined(__ppc64__) || defined(__hpux) || defined(__hppa) || \
|
||||
defined(_MIPSEB) || defined(_POWER) || defined(__s390__)
|
||||
#define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
|
||||
#elif defined(__i386__) || defined(__alpha__) || defined(__ia64) || defined(__ia64__) || \
|
||||
defined(_M_IX86) || defined(_M_IA64) || defined(_M_ALPHA) || defined(__amd64) || \
|
||||
defined(__amd64__) || defined(_M_AMD64) || defined(__x86_64) || defined(__x86_64__) || \
|
||||
defined(_M_X64) || defined(__bfin__)
|
||||
#define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
|
||||
#elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64))
|
||||
#define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
|
||||
#elif defined(RAPIDJSON_DOXYGEN_RUNNING)
|
||||
#define RAPIDJSON_ENDIAN
|
||||
#else
|
||||
#error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN.
|
||||
#endif
|
||||
#endif // RAPIDJSON_ENDIAN
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -286,7 +291,8 @@
|
||||
|
||||
//! Whether using 64-bit architecture
|
||||
#ifndef RAPIDJSON_64BIT
|
||||
#if defined(__LP64__) || (defined(__x86_64__) && defined(__ILP32__)) || defined(_WIN64) || defined(__EMSCRIPTEN__)
|
||||
#if defined(__LP64__) || (defined(__x86_64__) && defined(__ILP32__)) || defined(_WIN64) || \
|
||||
defined(__EMSCRIPTEN__)
|
||||
#define RAPIDJSON_64BIT 1
|
||||
#else
|
||||
#define RAPIDJSON_64BIT 0
|
||||
@@ -317,7 +323,8 @@
|
||||
Use this macro to define 64-bit constants by a pair of 32-bit integer.
|
||||
*/
|
||||
#ifndef RAPIDJSON_UINT64_C2
|
||||
#define RAPIDJSON_UINT64_C2(high32, low32) ((static_cast<uint64_t>(high32) << 32) | static_cast<uint64_t>(low32))
|
||||
#define RAPIDJSON_UINT64_C2(high32, low32) \
|
||||
((static_cast<uint64_t>(high32) << 32) | static_cast<uint64_t>(low32))
|
||||
#endif
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -327,12 +334,13 @@
|
||||
/*!
|
||||
\ingroup RAPIDJSON_CONFIG
|
||||
|
||||
This optimization uses the fact that current X86-64 architecture only implement lower 48-bit virtual address.
|
||||
The higher 16-bit can be used for storing other data.
|
||||
\c GenericValue uses this optimization to reduce its size form 24 bytes to 16 bytes in 64-bit architecture.
|
||||
This optimization uses the fact that current X86-64 architecture only implement lower 48-bit
|
||||
virtual address. The higher 16-bit can be used for storing other data. \c GenericValue uses this
|
||||
optimization to reduce its size form 24 bytes to 16 bytes in 64-bit architecture.
|
||||
*/
|
||||
#ifndef RAPIDJSON_48BITPOINTER_OPTIMIZATION
|
||||
#if defined(__amd64__) || defined(__amd64) || defined(__x86_64__) || defined(__x86_64) || defined(_M_X64) || defined(_M_AMD64)
|
||||
#if defined(__amd64__) || defined(__amd64) || defined(__x86_64__) || defined(__x86_64) || \
|
||||
defined(_M_X64) || defined(_M_AMD64)
|
||||
#define RAPIDJSON_48BITPOINTER_OPTIMIZATION 1
|
||||
#else
|
||||
#define RAPIDJSON_48BITPOINTER_OPTIMIZATION 0
|
||||
@@ -343,8 +351,14 @@
|
||||
#if RAPIDJSON_64BIT != 1
|
||||
#error RAPIDJSON_48BITPOINTER_OPTIMIZATION can only be set to 1 when RAPIDJSON_64BIT=1
|
||||
#endif
|
||||
#define RAPIDJSON_SETPOINTER(type, p, x) (p = reinterpret_cast<type *>((reinterpret_cast<uintptr_t>(p) & static_cast<uintptr_t>(RAPIDJSON_UINT64_C2(0xFFFF0000, 0x00000000))) | reinterpret_cast<uintptr_t>(reinterpret_cast<const void*>(x))))
|
||||
#define RAPIDJSON_GETPOINTER(type, p) (reinterpret_cast<type *>(reinterpret_cast<uintptr_t>(p) & static_cast<uintptr_t>(RAPIDJSON_UINT64_C2(0x0000FFFF, 0xFFFFFFFF))))
|
||||
#define RAPIDJSON_SETPOINTER(type, p, x) \
|
||||
(p = reinterpret_cast<type*>( \
|
||||
(reinterpret_cast<uintptr_t>(p) & \
|
||||
static_cast<uintptr_t>(RAPIDJSON_UINT64_C2(0xFFFF0000, 0x00000000))) | \
|
||||
reinterpret_cast<uintptr_t>(reinterpret_cast<const void*>(x))))
|
||||
#define RAPIDJSON_GETPOINTER(type, p) \
|
||||
(reinterpret_cast<type*>(reinterpret_cast<uintptr_t>(p) & \
|
||||
static_cast<uintptr_t>(RAPIDJSON_UINT64_C2(0x0000FFFF, 0xFFFFFFFF))))
|
||||
#else
|
||||
#define RAPIDJSON_SETPOINTER(type, p, x) (p = (x))
|
||||
#define RAPIDJSON_GETPOINTER(type, p) (p)
|
||||
@@ -379,8 +393,8 @@
|
||||
If any of these symbols is defined, RapidJSON defines the macro
|
||||
\c RAPIDJSON_SIMD to indicate the availability of the optimized code.
|
||||
*/
|
||||
#if defined(RAPIDJSON_SSE2) || defined(RAPIDJSON_SSE42) \
|
||||
|| defined(RAPIDJSON_NEON) || defined(RAPIDJSON_DOXYGEN_RUNNING)
|
||||
#if defined(RAPIDJSON_SSE2) || defined(RAPIDJSON_SSE42) || defined(RAPIDJSON_NEON) || \
|
||||
defined(RAPIDJSON_DOXYGEN_RUNNING)
|
||||
#define RAPIDJSON_SIMD
|
||||
#endif
|
||||
|
||||
@@ -442,9 +456,8 @@ RAPIDJSON_NAMESPACE_END
|
||||
|
||||
// Prefer C++11 static_assert, if available
|
||||
#ifndef RAPIDJSON_STATIC_ASSERT
|
||||
#if RAPIDJSON_CPLUSPLUS >= 201103L || ( defined(_MSC_VER) && _MSC_VER >= 1800 )
|
||||
#define RAPIDJSON_STATIC_ASSERT(x) \
|
||||
static_assert(x, RAPIDJSON_STRINGIFY(x))
|
||||
#if RAPIDJSON_CPLUSPLUS >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1800)
|
||||
#define RAPIDJSON_STATIC_ASSERT(x) static_assert(x, RAPIDJSON_STRINGIFY(x))
|
||||
#endif // C++11
|
||||
#endif // RAPIDJSON_STATIC_ASSERT
|
||||
|
||||
@@ -454,15 +467,26 @@ RAPIDJSON_NAMESPACE_END
|
||||
//!@cond RAPIDJSON_HIDDEN_FROM_DOXYGEN
|
||||
#endif
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
template <bool x> struct STATIC_ASSERTION_FAILURE;
|
||||
template <> struct STATIC_ASSERTION_FAILURE<true> { enum { value = 1 }; };
|
||||
template <size_t x> struct StaticAssertTest {};
|
||||
template <bool x>
|
||||
struct STATIC_ASSERTION_FAILURE;
|
||||
template <>
|
||||
struct STATIC_ASSERTION_FAILURE<true>
|
||||
{
|
||||
enum
|
||||
{
|
||||
value = 1
|
||||
};
|
||||
};
|
||||
template <size_t x>
|
||||
struct StaticAssertTest
|
||||
{
|
||||
};
|
||||
RAPIDJSON_NAMESPACE_END
|
||||
|
||||
#if defined(__GNUC__) || defined(__clang__)
|
||||
#define RAPIDJSON_STATIC_ASSERT_UNUSED_ATTRIBUTE __attribute__((unused))
|
||||
#else
|
||||
#define RAPIDJSON_STATIC_ASSERT_UNUSED_ATTRIBUTE
|
||||
#define RAPIDJSON_STATIC_ASSERT_UNUSED_ATTRIBUTE
|
||||
#endif
|
||||
#ifndef __clang__
|
||||
//!@endcond
|
||||
@@ -473,9 +497,9 @@ RAPIDJSON_NAMESPACE_END
|
||||
\param x compile-time condition
|
||||
\hideinitializer
|
||||
*/
|
||||
#define RAPIDJSON_STATIC_ASSERT(x) \
|
||||
typedef ::RAPIDJSON_NAMESPACE::StaticAssertTest< \
|
||||
sizeof(::RAPIDJSON_NAMESPACE::STATIC_ASSERTION_FAILURE<bool(x) >)> \
|
||||
#define RAPIDJSON_STATIC_ASSERT(x) \
|
||||
typedef ::RAPIDJSON_NAMESPACE::StaticAssertTest<sizeof( \
|
||||
::RAPIDJSON_NAMESPACE::STATIC_ASSERTION_FAILURE<bool(x)>)> \
|
||||
RAPIDJSON_JOIN(StaticAssertTypedef, __LINE__) RAPIDJSON_STATIC_ASSERT_UNUSED_ATTRIBUTE
|
||||
#endif // RAPIDJSON_STATIC_ASSERT
|
||||
|
||||
@@ -513,13 +537,15 @@ RAPIDJSON_NAMESPACE_END
|
||||
|
||||
//!@cond RAPIDJSON_HIDDEN_FROM_DOXYGEN
|
||||
|
||||
#define RAPIDJSON_MULTILINEMACRO_BEGIN do {
|
||||
#define RAPIDJSON_MULTILINEMACRO_BEGIN \
|
||||
do \
|
||||
{
|
||||
#define RAPIDJSON_MULTILINEMACRO_END \
|
||||
} while((void)0, 0)
|
||||
} \
|
||||
while((void)0, 0)
|
||||
|
||||
// adopted from Boost
|
||||
#define RAPIDJSON_VERSION_CODE(x,y,z) \
|
||||
(((x)*100000) + ((y)*100) + (z))
|
||||
#define RAPIDJSON_VERSION_CODE(x, y, z) (((x) * 100000) + ((y) * 100) + (z))
|
||||
|
||||
#if defined(__has_builtin)
|
||||
#define RAPIDJSON_HAS_BUILTIN(x) __has_builtin(x)
|
||||
@@ -531,24 +557,25 @@ RAPIDJSON_NAMESPACE_END
|
||||
// RAPIDJSON_DIAG_PUSH/POP, RAPIDJSON_DIAG_OFF
|
||||
|
||||
#if defined(__GNUC__)
|
||||
#define RAPIDJSON_GNUC \
|
||||
RAPIDJSON_VERSION_CODE(__GNUC__,__GNUC_MINOR__,__GNUC_PATCHLEVEL__)
|
||||
#define RAPIDJSON_GNUC RAPIDJSON_VERSION_CODE(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__)
|
||||
#endif
|
||||
|
||||
#if defined(__clang__) || (defined(RAPIDJSON_GNUC) && RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,2,0))
|
||||
#if defined(__clang__) || \
|
||||
(defined(RAPIDJSON_GNUC) && RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4, 2, 0))
|
||||
|
||||
#define RAPIDJSON_PRAGMA(x) _Pragma(RAPIDJSON_STRINGIFY(x))
|
||||
#define RAPIDJSON_DIAG_PRAGMA(x) RAPIDJSON_PRAGMA(GCC diagnostic x)
|
||||
#define RAPIDJSON_DIAG_OFF(x) \
|
||||
RAPIDJSON_DIAG_PRAGMA(ignored RAPIDJSON_STRINGIFY(RAPIDJSON_JOIN(-W,x)))
|
||||
RAPIDJSON_DIAG_PRAGMA(ignored RAPIDJSON_STRINGIFY(RAPIDJSON_JOIN(-W, x)))
|
||||
|
||||
// push/pop support in Clang and GCC>=4.6
|
||||
#if defined(__clang__) || (defined(RAPIDJSON_GNUC) && RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,6,0))
|
||||
#if defined(__clang__) || \
|
||||
(defined(RAPIDJSON_GNUC) && RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4, 6, 0))
|
||||
#define RAPIDJSON_DIAG_PUSH RAPIDJSON_DIAG_PRAGMA(push)
|
||||
#define RAPIDJSON_DIAG_POP RAPIDJSON_DIAG_PRAGMA(pop)
|
||||
#else // GCC >= 4.2, < 4.6
|
||||
#define RAPIDJSON_DIAG_POP RAPIDJSON_DIAG_PRAGMA(pop)
|
||||
#else // GCC >= 4.2, < 4.6
|
||||
#define RAPIDJSON_DIAG_PUSH /* ignored */
|
||||
#define RAPIDJSON_DIAG_POP /* ignored */
|
||||
#define RAPIDJSON_DIAG_POP /* ignored */
|
||||
#endif
|
||||
|
||||
#elif defined(_MSC_VER)
|
||||
@@ -557,9 +584,9 @@ RAPIDJSON_NAMESPACE_END
|
||||
#define RAPIDJSON_PRAGMA(x) __pragma(x)
|
||||
#define RAPIDJSON_DIAG_PRAGMA(x) RAPIDJSON_PRAGMA(warning(x))
|
||||
|
||||
#define RAPIDJSON_DIAG_OFF(x) RAPIDJSON_DIAG_PRAGMA(disable: x)
|
||||
#define RAPIDJSON_DIAG_OFF(x) RAPIDJSON_DIAG_PRAGMA(disable : x)
|
||||
#define RAPIDJSON_DIAG_PUSH RAPIDJSON_DIAG_PRAGMA(push)
|
||||
#define RAPIDJSON_DIAG_POP RAPIDJSON_DIAG_PRAGMA(pop)
|
||||
#define RAPIDJSON_DIAG_POP RAPIDJSON_DIAG_PRAGMA(pop)
|
||||
|
||||
#else
|
||||
|
||||
@@ -580,15 +607,16 @@ RAPIDJSON_NAMESPACE_END
|
||||
#if RAPIDJSON_HAS_CXX11
|
||||
#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 1
|
||||
#elif defined(__clang__)
|
||||
#if __has_feature(cxx_rvalue_references) && \
|
||||
(defined(_MSC_VER) || defined(_LIBCPP_VERSION) || defined(__GLIBCXX__) && __GLIBCXX__ >= 20080306)
|
||||
#if __has_feature(cxx_rvalue_references) && (defined(_MSC_VER) || defined(_LIBCPP_VERSION) || \
|
||||
defined(__GLIBCXX__) && __GLIBCXX__ >= 20080306)
|
||||
#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 1
|
||||
#else
|
||||
#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 0
|
||||
#endif
|
||||
#elif (defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,3,0)) && defined(__GXX_EXPERIMENTAL_CXX0X__)) || \
|
||||
(defined(_MSC_VER) && _MSC_VER >= 1600) || \
|
||||
(defined(__SUNPRO_CC) && __SUNPRO_CC >= 0x5140 && defined(__GXX_EXPERIMENTAL_CXX0X__))
|
||||
#elif(defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4, 3, 0)) && \
|
||||
defined(__GXX_EXPERIMENTAL_CXX0X__)) || \
|
||||
(defined(_MSC_VER) && _MSC_VER >= 1600) || \
|
||||
(defined(__SUNPRO_CC) && __SUNPRO_CC >= 0x5140 && defined(__GXX_EXPERIMENTAL_CXX0X__))
|
||||
|
||||
#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 1
|
||||
#else
|
||||
@@ -605,8 +633,9 @@ RAPIDJSON_NAMESPACE_END
|
||||
#define RAPIDJSON_HAS_CXX11_NOEXCEPT 1
|
||||
#elif defined(__clang__)
|
||||
#define RAPIDJSON_HAS_CXX11_NOEXCEPT __has_feature(cxx_noexcept)
|
||||
#elif (defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,6,0)) && defined(__GXX_EXPERIMENTAL_CXX0X__)) || \
|
||||
(defined(_MSC_VER) && _MSC_VER >= 1900) || \
|
||||
#elif(defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4, 6, 0)) && \
|
||||
defined(__GXX_EXPERIMENTAL_CXX0X__)) || \
|
||||
(defined(_MSC_VER) && _MSC_VER >= 1900) || \
|
||||
(defined(__SUNPRO_CC) && __SUNPRO_CC >= 0x5140 && defined(__GXX_EXPERIMENTAL_CXX0X__))
|
||||
#define RAPIDJSON_HAS_CXX11_NOEXCEPT 1
|
||||
#else
|
||||
@@ -623,7 +652,7 @@ RAPIDJSON_NAMESPACE_END
|
||||
|
||||
// no automatic detection, yet
|
||||
#ifndef RAPIDJSON_HAS_CXX11_TYPETRAITS
|
||||
#if (defined(_MSC_VER) && _MSC_VER >= 1700)
|
||||
#if(defined(_MSC_VER) && _MSC_VER >= 1700)
|
||||
#define RAPIDJSON_HAS_CXX11_TYPETRAITS 1
|
||||
#else
|
||||
#define RAPIDJSON_HAS_CXX11_TYPETRAITS 0
|
||||
@@ -633,9 +662,10 @@ RAPIDJSON_NAMESPACE_END
|
||||
#ifndef RAPIDJSON_HAS_CXX11_RANGE_FOR
|
||||
#if defined(__clang__)
|
||||
#define RAPIDJSON_HAS_CXX11_RANGE_FOR __has_feature(cxx_range_for)
|
||||
#elif (defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,6,0)) && defined(__GXX_EXPERIMENTAL_CXX0X__)) || \
|
||||
(defined(_MSC_VER) && _MSC_VER >= 1700) || \
|
||||
(defined(__SUNPRO_CC) && __SUNPRO_CC >= 0x5140 && defined(__GXX_EXPERIMENTAL_CXX0X__))
|
||||
#elif(defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4, 6, 0)) && \
|
||||
defined(__GXX_EXPERIMENTAL_CXX0X__)) || \
|
||||
(defined(_MSC_VER) && _MSC_VER >= 1700) || \
|
||||
(defined(__SUNPRO_CC) && __SUNPRO_CC >= 0x5140 && defined(__GXX_EXPERIMENTAL_CXX0X__))
|
||||
#define RAPIDJSON_HAS_CXX11_RANGE_FOR 1
|
||||
#else
|
||||
#define RAPIDJSON_HAS_CXX11_RANGE_FOR 0
|
||||
@@ -650,31 +680,31 @@ RAPIDJSON_NAMESPACE_END
|
||||
#endif
|
||||
|
||||
#if RAPIDJSON_HAS_CXX17
|
||||
# define RAPIDJSON_DELIBERATE_FALLTHROUGH [[fallthrough]]
|
||||
#define RAPIDJSON_DELIBERATE_FALLTHROUGH [[fallthrough]]
|
||||
#elif defined(__has_cpp_attribute)
|
||||
# if __has_cpp_attribute(clang::fallthrough)
|
||||
# define RAPIDJSON_DELIBERATE_FALLTHROUGH [[clang::fallthrough]]
|
||||
# elif __has_cpp_attribute(fallthrough)
|
||||
# define RAPIDJSON_DELIBERATE_FALLTHROUGH __attribute__((fallthrough))
|
||||
# else
|
||||
# define RAPIDJSON_DELIBERATE_FALLTHROUGH
|
||||
# endif
|
||||
#if __has_cpp_attribute(clang::fallthrough)
|
||||
#define RAPIDJSON_DELIBERATE_FALLTHROUGH [[clang::fallthrough]]
|
||||
#elif __has_cpp_attribute(fallthrough)
|
||||
#define RAPIDJSON_DELIBERATE_FALLTHROUGH __attribute__((fallthrough))
|
||||
#else
|
||||
# define RAPIDJSON_DELIBERATE_FALLTHROUGH
|
||||
#define RAPIDJSON_DELIBERATE_FALLTHROUGH
|
||||
#endif
|
||||
#else
|
||||
#define RAPIDJSON_DELIBERATE_FALLTHROUGH
|
||||
#endif
|
||||
|
||||
//!@endcond
|
||||
|
||||
//! Assertion (in non-throwing contexts).
|
||||
/*! \ingroup RAPIDJSON_CONFIG
|
||||
Some functions provide a \c noexcept guarantee, if the compiler supports it.
|
||||
In these cases, the \ref RAPIDJSON_ASSERT macro cannot be overridden to
|
||||
throw an exception. This macro adds a separate customization point for
|
||||
such cases.
|
||||
/*! \ingroup RAPIDJSON_CONFIG
|
||||
Some functions provide a \c noexcept guarantee, if the compiler supports it.
|
||||
In these cases, the \ref RAPIDJSON_ASSERT macro cannot be overridden to
|
||||
throw an exception. This macro adds a separate customization point for
|
||||
such cases.
|
||||
|
||||
Defaults to C \c assert() (as \ref RAPIDJSON_ASSERT), if \c noexcept is
|
||||
supported, and to \ref RAPIDJSON_ASSERT otherwise.
|
||||
*/
|
||||
Defaults to C \c assert() (as \ref RAPIDJSON_ASSERT), if \c noexcept is
|
||||
supported, and to \ref RAPIDJSON_ASSERT otherwise.
|
||||
*/
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// RAPIDJSON_NOEXCEPT_ASSERT
|
||||
@@ -726,14 +756,15 @@ RAPIDJSON_NAMESPACE_END
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
|
||||
//! Type of JSON value
|
||||
enum Type {
|
||||
kNullType = 0, //!< null
|
||||
kFalseType = 1, //!< false
|
||||
kTrueType = 2, //!< true
|
||||
kObjectType = 3, //!< object
|
||||
kArrayType = 4, //!< array
|
||||
kStringType = 5, //!< string
|
||||
kNumberType = 6 //!< number
|
||||
enum Type
|
||||
{
|
||||
kNullType = 0, //!< null
|
||||
kFalseType = 1, //!< false
|
||||
kTrueType = 2, //!< true
|
||||
kObjectType = 3, //!< object
|
||||
kArrayType = 4, //!< array
|
||||
kStringType = 5, //!< string
|
||||
kNumberType = 6 //!< number
|
||||
};
|
||||
|
||||
RAPIDJSON_NAMESPACE_END
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -69,34 +69,41 @@ concept Stream {
|
||||
For custom stream, this type can be specialized for other configuration.
|
||||
See TEST(Reader, CustomStringStream) in readertest.cpp for example.
|
||||
*/
|
||||
template<typename Stream>
|
||||
struct StreamTraits {
|
||||
template <typename Stream>
|
||||
struct StreamTraits
|
||||
{
|
||||
//! Whether to make local copy of stream for optimization during parsing.
|
||||
/*!
|
||||
By default, for safety, streams do not use local copy optimization.
|
||||
Stream that can be copied fast should specialize this, like StreamTraits<StringStream>.
|
||||
*/
|
||||
enum { copyOptimization = 0 };
|
||||
enum
|
||||
{
|
||||
copyOptimization = 0
|
||||
};
|
||||
};
|
||||
|
||||
//! Reserve n characters for writing to a stream.
|
||||
template<typename Stream>
|
||||
inline void PutReserve(Stream& stream, size_t count) {
|
||||
template <typename Stream>
|
||||
inline void PutReserve(Stream& stream, size_t count)
|
||||
{
|
||||
(void)stream;
|
||||
(void)count;
|
||||
}
|
||||
|
||||
//! Write character to a stream, presuming buffer is reserved.
|
||||
template<typename Stream>
|
||||
inline void PutUnsafe(Stream& stream, typename Stream::Ch c) {
|
||||
template <typename Stream>
|
||||
inline void PutUnsafe(Stream& stream, typename Stream::Ch c)
|
||||
{
|
||||
stream.Put(c);
|
||||
}
|
||||
|
||||
//! Put N copies of a character to a stream.
|
||||
template<typename Stream, typename Ch>
|
||||
inline void PutN(Stream& stream, Ch c, size_t n) {
|
||||
template <typename Stream, typename Ch>
|
||||
inline void PutN(Stream& stream, Ch c, size_t n)
|
||||
{
|
||||
PutReserve(stream, n);
|
||||
for (size_t i = 0; i < n; i++)
|
||||
for(size_t i = 0; i < n; i++)
|
||||
PutUnsafe(stream, c);
|
||||
}
|
||||
|
||||
@@ -111,15 +118,16 @@ inline void PutN(Stream& stream, Ch c, size_t n) {
|
||||
|
||||
#if defined(_MSC_VER) && _MSC_VER <= 1800
|
||||
RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(4702) // unreachable code
|
||||
RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated
|
||||
RAPIDJSON_DIAG_OFF(4702) // unreachable code
|
||||
RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated
|
||||
#endif
|
||||
|
||||
template <typename InputStream, typename Encoding = UTF8<> >
|
||||
class GenericStreamWrapper {
|
||||
public:
|
||||
template <typename InputStream, typename Encoding = UTF8<>>
|
||||
class GenericStreamWrapper
|
||||
{
|
||||
public:
|
||||
typedef typename Encoding::Ch Ch;
|
||||
GenericStreamWrapper(InputStream& is): is_(is) {}
|
||||
GenericStreamWrapper(InputStream& is) : is_(is) {}
|
||||
|
||||
Ch Peek() const { return is_.Peek(); }
|
||||
Ch Take() { return is_.Take(); }
|
||||
@@ -136,7 +144,7 @@ public:
|
||||
UTFType GetType() const { return is_.GetType(); }
|
||||
bool HasBOM() const { return is_.HasBOM(); }
|
||||
|
||||
protected:
|
||||
protected:
|
||||
InputStream& is_;
|
||||
};
|
||||
|
||||
@@ -149,33 +157,46 @@ RAPIDJSON_DIAG_POP
|
||||
|
||||
//! Read-only string stream.
|
||||
/*! \note implements Stream concept
|
||||
*/
|
||||
*/
|
||||
template <typename Encoding>
|
||||
struct GenericStringStream {
|
||||
struct GenericStringStream
|
||||
{
|
||||
typedef typename Encoding::Ch Ch;
|
||||
|
||||
GenericStringStream(const Ch *src) : src_(src), head_(src) {}
|
||||
GenericStringStream(const Ch* src) : src_(src), head_(src) {}
|
||||
|
||||
Ch Peek() const { return *src_; }
|
||||
Ch Take() { return *src_++; }
|
||||
size_t Tell() const { return static_cast<size_t>(src_ - head_); }
|
||||
|
||||
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
|
||||
Ch* PutBegin()
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
void Put(Ch) { RAPIDJSON_ASSERT(false); }
|
||||
void Flush() { RAPIDJSON_ASSERT(false); }
|
||||
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
|
||||
size_t PutEnd(Ch*)
|
||||
{
|
||||
RAPIDJSON_ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
const Ch* src_; //!< Current read position.
|
||||
const Ch* head_; //!< Original head of the string.
|
||||
const Ch* src_; //!< Current read position.
|
||||
const Ch* head_; //!< Original head of the string.
|
||||
};
|
||||
|
||||
template <typename Encoding>
|
||||
struct StreamTraits<GenericStringStream<Encoding> > {
|
||||
enum { copyOptimization = 1 };
|
||||
struct StreamTraits<GenericStringStream<Encoding>>
|
||||
{
|
||||
enum
|
||||
{
|
||||
copyOptimization = 1
|
||||
};
|
||||
};
|
||||
|
||||
//! String stream with UTF8 encoding.
|
||||
typedef GenericStringStream<UTF8<> > StringStream;
|
||||
typedef GenericStringStream<UTF8<>> StringStream;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// InsituStringStream
|
||||
@@ -185,10 +206,11 @@ typedef GenericStringStream<UTF8<> > StringStream;
|
||||
\note implements Stream concept
|
||||
*/
|
||||
template <typename Encoding>
|
||||
struct GenericInsituStringStream {
|
||||
struct GenericInsituStringStream
|
||||
{
|
||||
typedef typename Encoding::Ch Ch;
|
||||
|
||||
GenericInsituStringStream(Ch *src) : src_(src), dst_(0), head_(src) {}
|
||||
GenericInsituStringStream(Ch* src) : src_(src), dst_(0), head_(src) {}
|
||||
|
||||
// Read
|
||||
Ch Peek() { return *src_; }
|
||||
@@ -196,13 +218,22 @@ struct GenericInsituStringStream {
|
||||
size_t Tell() { return static_cast<size_t>(src_ - head_); }
|
||||
|
||||
// Write
|
||||
void Put(Ch c) { RAPIDJSON_ASSERT(dst_ != 0); *dst_++ = c; }
|
||||
void Put(Ch c)
|
||||
{
|
||||
RAPIDJSON_ASSERT(dst_ != 0);
|
||||
*dst_++ = c;
|
||||
}
|
||||
|
||||
Ch* PutBegin() { return dst_ = src_; }
|
||||
size_t PutEnd(Ch* begin) { return static_cast<size_t>(dst_ - begin); }
|
||||
void Flush() {}
|
||||
|
||||
Ch* Push(size_t count) { Ch* begin = dst_; dst_ += count; return begin; }
|
||||
Ch* Push(size_t count)
|
||||
{
|
||||
Ch* begin = dst_;
|
||||
dst_ += count;
|
||||
return begin;
|
||||
}
|
||||
void Pop(size_t count) { dst_ -= count; }
|
||||
|
||||
Ch* src_;
|
||||
@@ -211,12 +242,16 @@ struct GenericInsituStringStream {
|
||||
};
|
||||
|
||||
template <typename Encoding>
|
||||
struct StreamTraits<GenericInsituStringStream<Encoding> > {
|
||||
enum { copyOptimization = 1 };
|
||||
struct StreamTraits<GenericInsituStringStream<Encoding>>
|
||||
{
|
||||
enum
|
||||
{
|
||||
copyOptimization = 1
|
||||
};
|
||||
};
|
||||
|
||||
//! Insitu string stream with UTF8 encoding.
|
||||
typedef GenericInsituStringStream<UTF8<> > InsituStringStream;
|
||||
typedef GenericInsituStringStream<UTF8<>> InsituStringStream;
|
||||
|
||||
RAPIDJSON_NAMESPACE_END
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Tencent is pleased to support the open source community by making RapidJSON available.
|
||||
//
|
||||
//
|
||||
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
|
||||
//
|
||||
// Licensed under the MIT License (the "License"); you may not use this file except
|
||||
@@ -7,9 +7,9 @@
|
||||
//
|
||||
// http://opensource.org/licenses/MIT
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// Unless required by applicable law or agreed to in writing, software distributed
|
||||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
|
||||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations under the License.
|
||||
|
||||
#ifndef RAPIDJSON_STRINGBUFFER_H_
|
||||
@@ -26,7 +26,7 @@
|
||||
|
||||
#if defined(__clang__)
|
||||
RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(c++98-compat)
|
||||
RAPIDJSON_DIAG_OFF(c++ 98 - compat)
|
||||
#endif
|
||||
|
||||
RAPIDJSON_NAMESPACE_BEGIN
|
||||
@@ -38,16 +38,21 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
\note implements Stream concept
|
||||
*/
|
||||
template <typename Encoding, typename Allocator = CrtAllocator>
|
||||
class GenericStringBuffer {
|
||||
public:
|
||||
class GenericStringBuffer
|
||||
{
|
||||
public:
|
||||
typedef typename Encoding::Ch Ch;
|
||||
|
||||
GenericStringBuffer(Allocator* allocator = 0, size_t capacity = kDefaultCapacity) : stack_(allocator, capacity) {}
|
||||
GenericStringBuffer(Allocator* allocator = 0, size_t capacity = kDefaultCapacity)
|
||||
: stack_(allocator, capacity)
|
||||
{
|
||||
}
|
||||
|
||||
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
|
||||
GenericStringBuffer(GenericStringBuffer&& rhs) : stack_(std::move(rhs.stack_)) {}
|
||||
GenericStringBuffer& operator=(GenericStringBuffer&& rhs) {
|
||||
if (&rhs != this)
|
||||
GenericStringBuffer& operator=(GenericStringBuffer&& rhs)
|
||||
{
|
||||
if(&rhs != this)
|
||||
stack_ = std::move(rhs.stack_);
|
||||
return *this;
|
||||
}
|
||||
@@ -58,7 +63,8 @@ public:
|
||||
void Flush() {}
|
||||
|
||||
void Clear() { stack_.Clear(); }
|
||||
void ShrinkToFit() {
|
||||
void ShrinkToFit()
|
||||
{
|
||||
// Push and pop a null terminator. This is safe.
|
||||
*stack_.template Push<Ch>() = '\0';
|
||||
stack_.ShrinkToFit();
|
||||
@@ -70,7 +76,8 @@ public:
|
||||
Ch* PushUnsafe(size_t count) { return stack_.template PushUnsafe<Ch>(count); }
|
||||
void Pop(size_t count) { stack_.template Pop<Ch>(count); }
|
||||
|
||||
const Ch* GetString() const {
|
||||
const Ch* GetString() const
|
||||
{
|
||||
// Push and pop a null terminator. This is safe.
|
||||
*stack_.template Push<Ch>() = '\0';
|
||||
stack_.template Pop<Ch>(1);
|
||||
@@ -87,28 +94,31 @@ public:
|
||||
static const size_t kDefaultCapacity = 256;
|
||||
mutable internal::Stack<Allocator> stack_;
|
||||
|
||||
private:
|
||||
private:
|
||||
// Prohibit copy constructor & assignment operator.
|
||||
GenericStringBuffer(const GenericStringBuffer&);
|
||||
GenericStringBuffer& operator=(const GenericStringBuffer&);
|
||||
};
|
||||
|
||||
//! String buffer with UTF8 encoding
|
||||
typedef GenericStringBuffer<UTF8<> > StringBuffer;
|
||||
typedef GenericStringBuffer<UTF8<>> StringBuffer;
|
||||
|
||||
template<typename Encoding, typename Allocator>
|
||||
inline void PutReserve(GenericStringBuffer<Encoding, Allocator>& stream, size_t count) {
|
||||
template <typename Encoding, typename Allocator>
|
||||
inline void PutReserve(GenericStringBuffer<Encoding, Allocator>& stream, size_t count)
|
||||
{
|
||||
stream.Reserve(count);
|
||||
}
|
||||
|
||||
template<typename Encoding, typename Allocator>
|
||||
inline void PutUnsafe(GenericStringBuffer<Encoding, Allocator>& stream, typename Encoding::Ch c) {
|
||||
template <typename Encoding, typename Allocator>
|
||||
inline void PutUnsafe(GenericStringBuffer<Encoding, Allocator>& stream, typename Encoding::Ch c)
|
||||
{
|
||||
stream.PutUnsafe(c);
|
||||
}
|
||||
|
||||
//! Implement specialized version of PutN() with memset() for better performance.
|
||||
template<>
|
||||
inline void PutN(GenericStringBuffer<UTF8<> >& stream, char c, size_t n) {
|
||||
template <>
|
||||
inline void PutN(GenericStringBuffer<UTF8<>>& stream, char c, size_t n)
|
||||
{
|
||||
std::memset(stream.stack_.Push<char>(n), c, n * sizeof(c));
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
|
||||
#if defined(__clang__)
|
||||
RAPIDJSON_DIAG_PUSH
|
||||
RAPIDJSON_DIAG_OFF(c++98-compat)
|
||||
RAPIDJSON_DIAG_OFF(c++ 98 - compat)
|
||||
#elif defined(_MSC_VER)
|
||||
RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated
|
||||
#endif
|
||||
@@ -29,66 +29,141 @@ RAPIDJSON_NAMESPACE_BEGIN
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// GenericUri
|
||||
|
||||
template <typename ValueType, typename Allocator=CrtAllocator>
|
||||
class GenericUri {
|
||||
public:
|
||||
template <typename ValueType, typename Allocator = CrtAllocator>
|
||||
class GenericUri
|
||||
{
|
||||
public:
|
||||
typedef typename ValueType::Ch Ch;
|
||||
#if RAPIDJSON_HAS_STDSTRING
|
||||
typedef std::basic_string<Ch> String;
|
||||
#endif
|
||||
|
||||
//! Constructors
|
||||
GenericUri(Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() {
|
||||
GenericUri(Allocator* allocator = 0)
|
||||
: uri_(),
|
||||
base_(),
|
||||
scheme_(),
|
||||
auth_(),
|
||||
path_(),
|
||||
query_(),
|
||||
frag_(),
|
||||
allocator_(allocator),
|
||||
ownAllocator_()
|
||||
{
|
||||
}
|
||||
|
||||
GenericUri(const Ch* uri, SizeType len, Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() {
|
||||
GenericUri(const Ch* uri, SizeType len, Allocator* allocator = 0)
|
||||
: uri_(),
|
||||
base_(),
|
||||
scheme_(),
|
||||
auth_(),
|
||||
path_(),
|
||||
query_(),
|
||||
frag_(),
|
||||
allocator_(allocator),
|
||||
ownAllocator_()
|
||||
{
|
||||
Parse(uri, len);
|
||||
}
|
||||
|
||||
GenericUri(const Ch* uri, Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() {
|
||||
GenericUri(const Ch* uri, Allocator* allocator = 0)
|
||||
: uri_(),
|
||||
base_(),
|
||||
scheme_(),
|
||||
auth_(),
|
||||
path_(),
|
||||
query_(),
|
||||
frag_(),
|
||||
allocator_(allocator),
|
||||
ownAllocator_()
|
||||
{
|
||||
Parse(uri, internal::StrLen<Ch>(uri));
|
||||
}
|
||||
|
||||
// Use with specializations of GenericValue
|
||||
template<typename T> GenericUri(const T& uri, Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() {
|
||||
template <typename T>
|
||||
GenericUri(const T& uri, Allocator* allocator = 0)
|
||||
: uri_(),
|
||||
base_(),
|
||||
scheme_(),
|
||||
auth_(),
|
||||
path_(),
|
||||
query_(),
|
||||
frag_(),
|
||||
allocator_(allocator),
|
||||
ownAllocator_()
|
||||
{
|
||||
const Ch* u = uri.template Get<const Ch*>(); // TypeHelper from document.h
|
||||
Parse(u, internal::StrLen<Ch>(u));
|
||||
}
|
||||
|
||||
#if RAPIDJSON_HAS_STDSTRING
|
||||
GenericUri(const String& uri, Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() {
|
||||
GenericUri(const String& uri, Allocator* allocator = 0)
|
||||
: uri_(),
|
||||
base_(),
|
||||
scheme_(),
|
||||
auth_(),
|
||||
path_(),
|
||||
query_(),
|
||||
frag_(),
|
||||
allocator_(allocator),
|
||||
ownAllocator_()
|
||||
{
|
||||
Parse(uri.c_str(), internal::StrLen<Ch>(uri.c_str()));
|
||||
}
|
||||
#endif
|
||||
|
||||
//! Copy constructor
|
||||
GenericUri(const GenericUri& rhs) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(), ownAllocator_() {
|
||||
GenericUri(const GenericUri& rhs)
|
||||
: uri_(),
|
||||
base_(),
|
||||
scheme_(),
|
||||
auth_(),
|
||||
path_(),
|
||||
query_(),
|
||||
frag_(),
|
||||
allocator_(),
|
||||
ownAllocator_()
|
||||
{
|
||||
*this = rhs;
|
||||
}
|
||||
|
||||
//! Copy constructor
|
||||
GenericUri(const GenericUri& rhs, Allocator* allocator) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() {
|
||||
GenericUri(const GenericUri& rhs, Allocator* allocator)
|
||||
: uri_(),
|
||||
base_(),
|
||||
scheme_(),
|
||||
auth_(),
|
||||
path_(),
|
||||
query_(),
|
||||
frag_(),
|
||||
allocator_(allocator),
|
||||
ownAllocator_()
|
||||
{
|
||||
*this = rhs;
|
||||
}
|
||||
|
||||
//! Destructor.
|
||||
~GenericUri() {
|
||||
~GenericUri()
|
||||
{
|
||||
Free();
|
||||
RAPIDJSON_DELETE(ownAllocator_);
|
||||
}
|
||||
|
||||
//! Assignment operator
|
||||
GenericUri& operator=(const GenericUri& rhs) {
|
||||
if (this != &rhs) {
|
||||
GenericUri& operator=(const GenericUri& rhs)
|
||||
{
|
||||
if(this != &rhs)
|
||||
{
|
||||
// Do not delete ownAllocator
|
||||
Free();
|
||||
Allocate(rhs.GetStringLength());
|
||||
auth_ = CopyPart(scheme_, rhs.scheme_, rhs.GetSchemeStringLength());
|
||||
path_ = CopyPart(auth_, rhs.auth_, rhs.GetAuthStringLength());
|
||||
auth_ = CopyPart(scheme_, rhs.scheme_, rhs.GetSchemeStringLength());
|
||||
path_ = CopyPart(auth_, rhs.auth_, rhs.GetAuthStringLength());
|
||||
query_ = CopyPart(path_, rhs.path_, rhs.GetPathStringLength());
|
||||
frag_ = CopyPart(query_, rhs.query_, rhs.GetQueryStringLength());
|
||||
base_ = CopyPart(frag_, rhs.frag_, rhs.GetFragStringLength());
|
||||
uri_ = CopyPart(base_, rhs.base_, rhs.GetBaseStringLength());
|
||||
frag_ = CopyPart(query_, rhs.query_, rhs.GetQueryStringLength());
|
||||
base_ = CopyPart(frag_, rhs.frag_, rhs.GetFragStringLength());
|
||||
uri_ = CopyPart(base_, rhs.base_, rhs.GetBaseStringLength());
|
||||
CopyPart(uri_, rhs.uri_, rhs.GetStringLength());
|
||||
}
|
||||
return *this;
|
||||
@@ -96,7 +171,9 @@ public:
|
||||
|
||||
//! Getters
|
||||
// Use with specializations of GenericValue
|
||||
template<typename T> void Get(T& uri, Allocator& allocator) {
|
||||
template <typename T>
|
||||
void Get(T& uri, Allocator& allocator)
|
||||
{
|
||||
uri.template Set<const Ch*>(this->GetString(), allocator); // TypeHelper from document.h
|
||||
}
|
||||
|
||||
@@ -105,7 +182,10 @@ public:
|
||||
const Ch* GetBaseString() const { return base_; }
|
||||
SizeType GetBaseStringLength() const { return base_ == 0 ? 0 : internal::StrLen<Ch>(base_); }
|
||||
const Ch* GetSchemeString() const { return scheme_; }
|
||||
SizeType GetSchemeStringLength() const { return scheme_ == 0 ? 0 : internal::StrLen<Ch>(scheme_); }
|
||||
SizeType GetSchemeStringLength() const
|
||||
{
|
||||
return scheme_ == 0 ? 0 : internal::StrLen<Ch>(scheme_);
|
||||
}
|
||||
const Ch* GetAuthString() const { return auth_; }
|
||||
SizeType GetAuthStringLength() const { return auth_ == 0 ? 0 : internal::StrLen<Ch>(auth_); }
|
||||
const Ch* GetPathString() const { return path_; }
|
||||
@@ -116,36 +196,59 @@ public:
|
||||
SizeType GetFragStringLength() const { return frag_ == 0 ? 0 : internal::StrLen<Ch>(frag_); }
|
||||
|
||||
#if RAPIDJSON_HAS_STDSTRING
|
||||
static String Get(const GenericUri& uri) { return String(uri.GetString(), uri.GetStringLength()); }
|
||||
static String GetBase(const GenericUri& uri) { return String(uri.GetBaseString(), uri.GetBaseStringLength()); }
|
||||
static String GetScheme(const GenericUri& uri) { return String(uri.GetSchemeString(), uri.GetSchemeStringLength()); }
|
||||
static String GetAuth(const GenericUri& uri) { return String(uri.GetAuthString(), uri.GetAuthStringLength()); }
|
||||
static String GetPath(const GenericUri& uri) { return String(uri.GetPathString(), uri.GetPathStringLength()); }
|
||||
static String GetQuery(const GenericUri& uri) { return String(uri.GetQueryString(), uri.GetQueryStringLength()); }
|
||||
static String GetFrag(const GenericUri& uri) { return String(uri.GetFragString(), uri.GetFragStringLength()); }
|
||||
static String Get(const GenericUri& uri)
|
||||
{
|
||||
return String(uri.GetString(), uri.GetStringLength());
|
||||
}
|
||||
static String GetBase(const GenericUri& uri)
|
||||
{
|
||||
return String(uri.GetBaseString(), uri.GetBaseStringLength());
|
||||
}
|
||||
static String GetScheme(const GenericUri& uri)
|
||||
{
|
||||
return String(uri.GetSchemeString(), uri.GetSchemeStringLength());
|
||||
}
|
||||
static String GetAuth(const GenericUri& uri)
|
||||
{
|
||||
return String(uri.GetAuthString(), uri.GetAuthStringLength());
|
||||
}
|
||||
static String GetPath(const GenericUri& uri)
|
||||
{
|
||||
return String(uri.GetPathString(), uri.GetPathStringLength());
|
||||
}
|
||||
static String GetQuery(const GenericUri& uri)
|
||||
{
|
||||
return String(uri.GetQueryString(), uri.GetQueryStringLength());
|
||||
}
|
||||
static String GetFrag(const GenericUri& uri)
|
||||
{
|
||||
return String(uri.GetFragString(), uri.GetFragStringLength());
|
||||
}
|
||||
#endif
|
||||
|
||||
//! Equality operators
|
||||
bool operator==(const GenericUri& rhs) const {
|
||||
return Match(rhs, true);
|
||||
}
|
||||
bool operator==(const GenericUri& rhs) const { return Match(rhs, true); }
|
||||
|
||||
bool operator!=(const GenericUri& rhs) const {
|
||||
return !Match(rhs, true);
|
||||
}
|
||||
bool operator!=(const GenericUri& rhs) const { return !Match(rhs, true); }
|
||||
|
||||
bool Match(const GenericUri& uri, bool full = true) const {
|
||||
bool Match(const GenericUri& uri, bool full = true) const
|
||||
{
|
||||
Ch* s1;
|
||||
Ch* s2;
|
||||
if (full) {
|
||||
if(full)
|
||||
{
|
||||
s1 = uri_;
|
||||
s2 = uri.uri_;
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
s1 = base_;
|
||||
s2 = uri.base_;
|
||||
}
|
||||
if (s1 == s2) return true;
|
||||
if (s1 == 0 || s2 == 0) return false;
|
||||
if(s1 == s2)
|
||||
return true;
|
||||
if(s1 == 0 || s2 == 0)
|
||||
return false;
|
||||
return internal::StrCmp<Ch>(s1, s2) == 0;
|
||||
}
|
||||
|
||||
@@ -153,56 +256,80 @@ public:
|
||||
// See https://tools.ietf.org/html/rfc3986
|
||||
// Use for resolving an id or $ref with an in-scope id.
|
||||
// Returns a new GenericUri for the resolved URI.
|
||||
GenericUri Resolve(const GenericUri& baseuri, Allocator* allocator = 0) {
|
||||
GenericUri Resolve(const GenericUri& baseuri, Allocator* allocator = 0)
|
||||
{
|
||||
GenericUri resuri;
|
||||
resuri.allocator_ = allocator;
|
||||
// Ensure enough space for combining paths
|
||||
resuri.Allocate(GetStringLength() + baseuri.GetStringLength() + 1); // + 1 for joining slash
|
||||
|
||||
if (!(GetSchemeStringLength() == 0)) {
|
||||
if(!(GetSchemeStringLength() == 0))
|
||||
{
|
||||
// Use all of this URI
|
||||
resuri.auth_ = CopyPart(resuri.scheme_, scheme_, GetSchemeStringLength());
|
||||
resuri.path_ = CopyPart(resuri.auth_, auth_, GetAuthStringLength());
|
||||
resuri.auth_ = CopyPart(resuri.scheme_, scheme_, GetSchemeStringLength());
|
||||
resuri.path_ = CopyPart(resuri.auth_, auth_, GetAuthStringLength());
|
||||
resuri.query_ = CopyPart(resuri.path_, path_, GetPathStringLength());
|
||||
resuri.frag_ = CopyPart(resuri.query_, query_, GetQueryStringLength());
|
||||
resuri.frag_ = CopyPart(resuri.query_, query_, GetQueryStringLength());
|
||||
resuri.RemoveDotSegments();
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
// Use the base scheme
|
||||
resuri.auth_ = CopyPart(resuri.scheme_, baseuri.scheme_, baseuri.GetSchemeStringLength());
|
||||
if (!(GetAuthStringLength() == 0)) {
|
||||
resuri.auth_ =
|
||||
CopyPart(resuri.scheme_, baseuri.scheme_, baseuri.GetSchemeStringLength());
|
||||
if(!(GetAuthStringLength() == 0))
|
||||
{
|
||||
// Use this auth, path, query
|
||||
resuri.path_ = CopyPart(resuri.auth_, auth_, GetAuthStringLength());
|
||||
resuri.path_ = CopyPart(resuri.auth_, auth_, GetAuthStringLength());
|
||||
resuri.query_ = CopyPart(resuri.path_, path_, GetPathStringLength());
|
||||
resuri.frag_ = CopyPart(resuri.query_, query_, GetQueryStringLength());
|
||||
resuri.frag_ = CopyPart(resuri.query_, query_, GetQueryStringLength());
|
||||
resuri.RemoveDotSegments();
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
// Use the base auth
|
||||
resuri.path_ = CopyPart(resuri.auth_, baseuri.auth_, baseuri.GetAuthStringLength());
|
||||
if (GetPathStringLength() == 0) {
|
||||
if(GetPathStringLength() == 0)
|
||||
{
|
||||
// Use the base path
|
||||
resuri.query_ = CopyPart(resuri.path_, baseuri.path_, baseuri.GetPathStringLength());
|
||||
if (GetQueryStringLength() == 0) {
|
||||
resuri.query_ =
|
||||
CopyPart(resuri.path_, baseuri.path_, baseuri.GetPathStringLength());
|
||||
if(GetQueryStringLength() == 0)
|
||||
{
|
||||
// Use the base query
|
||||
resuri.frag_ = CopyPart(resuri.query_, baseuri.query_, baseuri.GetQueryStringLength());
|
||||
} else {
|
||||
resuri.frag_ =
|
||||
CopyPart(resuri.query_, baseuri.query_, baseuri.GetQueryStringLength());
|
||||
}
|
||||
else
|
||||
{
|
||||
// Use this query
|
||||
resuri.frag_ = CopyPart(resuri.query_, query_, GetQueryStringLength());
|
||||
}
|
||||
} else {
|
||||
if (path_[0] == '/') {
|
||||
}
|
||||
else
|
||||
{
|
||||
if(path_[0] == '/')
|
||||
{
|
||||
// Absolute path - use all of this path
|
||||
resuri.query_ = CopyPart(resuri.path_, path_, GetPathStringLength());
|
||||
resuri.RemoveDotSegments();
|
||||
} else {
|
||||
// Relative path - append this path to base path after base path's last slash
|
||||
}
|
||||
else
|
||||
{
|
||||
// Relative path - append this path to base path after base path's last
|
||||
// slash
|
||||
size_t pos = 0;
|
||||
if (!(baseuri.GetAuthStringLength() == 0) && baseuri.GetPathStringLength() == 0) {
|
||||
if(!(baseuri.GetAuthStringLength() == 0) &&
|
||||
baseuri.GetPathStringLength() == 0)
|
||||
{
|
||||
resuri.path_[pos] = '/';
|
||||
pos++;
|
||||
}
|
||||
size_t lastslashpos = baseuri.GetPathStringLength();
|
||||
while (lastslashpos > 0) {
|
||||
if (baseuri.path_[lastslashpos - 1] == '/') break;
|
||||
while(lastslashpos > 0)
|
||||
{
|
||||
if(baseuri.path_[lastslashpos - 1] == '/')
|
||||
break;
|
||||
lastslashpos--;
|
||||
}
|
||||
std::memcpy(&resuri.path_[pos], baseuri.path_, lastslashpos * sizeof(Ch));
|
||||
@@ -228,74 +355,87 @@ public:
|
||||
//! Get the allocator of this GenericUri.
|
||||
Allocator& GetAllocator() { return *allocator_; }
|
||||
|
||||
private:
|
||||
private:
|
||||
// Allocate memory for a URI
|
||||
// Returns total amount allocated
|
||||
std::size_t Allocate(std::size_t len) {
|
||||
std::size_t Allocate(std::size_t len)
|
||||
{
|
||||
// Create own allocator if user did not supply.
|
||||
if (!allocator_)
|
||||
ownAllocator_ = allocator_ = RAPIDJSON_NEW(Allocator)();
|
||||
if(!allocator_)
|
||||
ownAllocator_ = allocator_ = RAPIDJSON_NEW(Allocator)();
|
||||
|
||||
// Allocate one block containing each part of the URI (5) plus base plus full URI, all null terminated.
|
||||
// Order: scheme, auth, path, query, frag, base, uri
|
||||
// Note need to set, increment, assign in 3 stages to avoid compiler warning bug.
|
||||
// Allocate one block containing each part of the URI (5) plus base plus full URI, all null
|
||||
// terminated. Order: scheme, auth, path, query, frag, base, uri Note need to set,
|
||||
// increment, assign in 3 stages to avoid compiler warning bug.
|
||||
size_t total = (3 * len + 7) * sizeof(Ch);
|
||||
scheme_ = static_cast<Ch*>(allocator_->Malloc(total));
|
||||
*scheme_ = '\0';
|
||||
auth_ = scheme_;
|
||||
scheme_ = static_cast<Ch*>(allocator_->Malloc(total));
|
||||
*scheme_ = '\0';
|
||||
auth_ = scheme_;
|
||||
auth_++;
|
||||
*auth_ = '\0';
|
||||
path_ = auth_;
|
||||
path_ = auth_;
|
||||
path_++;
|
||||
*path_ = '\0';
|
||||
query_ = path_;
|
||||
query_++;
|
||||
*query_ = '\0';
|
||||
frag_ = query_;
|
||||
frag_ = query_;
|
||||
frag_++;
|
||||
*frag_ = '\0';
|
||||
base_ = frag_;
|
||||
base_ = frag_;
|
||||
base_++;
|
||||
*base_ = '\0';
|
||||
uri_ = base_;
|
||||
uri_ = base_;
|
||||
uri_++;
|
||||
*uri_ = '\0';
|
||||
return total;
|
||||
}
|
||||
|
||||
// Free memory for a URI
|
||||
void Free() {
|
||||
if (scheme_) {
|
||||
void Free()
|
||||
{
|
||||
if(scheme_)
|
||||
{
|
||||
Allocator::Free(scheme_);
|
||||
scheme_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Parse a URI into constituent scheme, authority, path, query, & fragment parts
|
||||
// Supports URIs that match regex ^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))? as per
|
||||
// https://tools.ietf.org/html/rfc3986
|
||||
void Parse(const Ch* uri, std::size_t len) {
|
||||
// Supports URIs that match regex ^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))? as
|
||||
// per https://tools.ietf.org/html/rfc3986
|
||||
void Parse(const Ch* uri, std::size_t len)
|
||||
{
|
||||
std::size_t start = 0, pos1 = 0, pos2 = 0;
|
||||
Allocate(len);
|
||||
|
||||
// Look for scheme ([^:/?#]+):)?
|
||||
if (start < len) {
|
||||
while (pos1 < len) {
|
||||
if (uri[pos1] == ':') break;
|
||||
if(start < len)
|
||||
{
|
||||
while(pos1 < len)
|
||||
{
|
||||
if(uri[pos1] == ':')
|
||||
break;
|
||||
pos1++;
|
||||
}
|
||||
if (pos1 != len) {
|
||||
while (pos2 < len) {
|
||||
if (uri[pos2] == '/') break;
|
||||
if (uri[pos2] == '?') break;
|
||||
if (uri[pos2] == '#') break;
|
||||
if(pos1 != len)
|
||||
{
|
||||
while(pos2 < len)
|
||||
{
|
||||
if(uri[pos2] == '/')
|
||||
break;
|
||||
if(uri[pos2] == '?')
|
||||
break;
|
||||
if(uri[pos2] == '#')
|
||||
break;
|
||||
pos2++;
|
||||
}
|
||||
if (pos1 < pos2) {
|
||||
if(pos1 < pos2)
|
||||
{
|
||||
pos1++;
|
||||
std::memcpy(scheme_, &uri[start], pos1 * sizeof(Ch));
|
||||
scheme_[pos1] = '\0';
|
||||
start = pos1;
|
||||
start = pos1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -304,35 +444,45 @@ private:
|
||||
auth_ = scheme_ + GetSchemeStringLength();
|
||||
auth_++;
|
||||
*auth_ = '\0';
|
||||
if (start < len - 1 && uri[start] == '/' && uri[start + 1] == '/') {
|
||||
if(start < len - 1 && uri[start] == '/' && uri[start + 1] == '/')
|
||||
{
|
||||
pos2 = start + 2;
|
||||
while (pos2 < len) {
|
||||
if (uri[pos2] == '/') break;
|
||||
if (uri[pos2] == '?') break;
|
||||
if (uri[pos2] == '#') break;
|
||||
while(pos2 < len)
|
||||
{
|
||||
if(uri[pos2] == '/')
|
||||
break;
|
||||
if(uri[pos2] == '?')
|
||||
break;
|
||||
if(uri[pos2] == '#')
|
||||
break;
|
||||
pos2++;
|
||||
}
|
||||
std::memcpy(auth_, &uri[start], (pos2 - start) * sizeof(Ch));
|
||||
auth_[pos2 - start] = '\0';
|
||||
start = pos2;
|
||||
start = pos2;
|
||||
}
|
||||
// Look for path ([^?#]*)
|
||||
// Note need to set, increment, assign in 3 stages to avoid compiler warning bug.
|
||||
path_ = auth_ + GetAuthStringLength();
|
||||
path_++;
|
||||
*path_ = '\0';
|
||||
if (start < len) {
|
||||
if(start < len)
|
||||
{
|
||||
pos2 = start;
|
||||
while (pos2 < len) {
|
||||
if (uri[pos2] == '?') break;
|
||||
if (uri[pos2] == '#') break;
|
||||
while(pos2 < len)
|
||||
{
|
||||
if(uri[pos2] == '?')
|
||||
break;
|
||||
if(uri[pos2] == '#')
|
||||
break;
|
||||
pos2++;
|
||||
}
|
||||
if (start != pos2) {
|
||||
if(start != pos2)
|
||||
{
|
||||
std::memcpy(path_, &uri[start], (pos2 - start) * sizeof(Ch));
|
||||
path_[pos2 - start] = '\0';
|
||||
if (path_[0] == '/')
|
||||
RemoveDotSegments(); // absolute path - normalize
|
||||
if(path_[0] == '/')
|
||||
RemoveDotSegments(); // absolute path - normalize
|
||||
start = pos2;
|
||||
}
|
||||
}
|
||||
@@ -341,16 +491,20 @@ private:
|
||||
query_ = path_ + GetPathStringLength();
|
||||
query_++;
|
||||
*query_ = '\0';
|
||||
if (start < len && uri[start] == '?') {
|
||||
if(start < len && uri[start] == '?')
|
||||
{
|
||||
pos2 = start + 1;
|
||||
while (pos2 < len) {
|
||||
if (uri[pos2] == '#') break;
|
||||
while(pos2 < len)
|
||||
{
|
||||
if(uri[pos2] == '#')
|
||||
break;
|
||||
pos2++;
|
||||
}
|
||||
if (start != pos2) {
|
||||
if(start != pos2)
|
||||
{
|
||||
std::memcpy(query_, &uri[start], (pos2 - start) * sizeof(Ch));
|
||||
query_[pos2 - start] = '\0';
|
||||
start = pos2;
|
||||
start = pos2;
|
||||
}
|
||||
}
|
||||
// Look for fragment (#(.*))?
|
||||
@@ -358,7 +512,8 @@ private:
|
||||
frag_ = query_ + GetQueryStringLength();
|
||||
frag_++;
|
||||
*frag_ = '\0';
|
||||
if (start < len && uri[start] == '#') {
|
||||
if(start < len && uri[start] == '#')
|
||||
{
|
||||
std::memcpy(frag_, &uri[start], (len - start) * sizeof(Ch));
|
||||
frag_[len - start] = '\0';
|
||||
}
|
||||
@@ -371,36 +526,39 @@ private:
|
||||
}
|
||||
|
||||
// Reconstitute base
|
||||
void SetBase() {
|
||||
void SetBase()
|
||||
{
|
||||
Ch* next = base_;
|
||||
std::memcpy(next, scheme_, GetSchemeStringLength() * sizeof(Ch));
|
||||
next+= GetSchemeStringLength();
|
||||
next += GetSchemeStringLength();
|
||||
std::memcpy(next, auth_, GetAuthStringLength() * sizeof(Ch));
|
||||
next+= GetAuthStringLength();
|
||||
next += GetAuthStringLength();
|
||||
std::memcpy(next, path_, GetPathStringLength() * sizeof(Ch));
|
||||
next+= GetPathStringLength();
|
||||
next += GetPathStringLength();
|
||||
std::memcpy(next, query_, GetQueryStringLength() * sizeof(Ch));
|
||||
next+= GetQueryStringLength();
|
||||
next += GetQueryStringLength();
|
||||
*next = '\0';
|
||||
}
|
||||
|
||||
// Reconstitute uri
|
||||
void SetUri() {
|
||||
void SetUri()
|
||||
{
|
||||
Ch* next = uri_;
|
||||
std::memcpy(next, base_, GetBaseStringLength() * sizeof(Ch));
|
||||
next+= GetBaseStringLength();
|
||||
next += GetBaseStringLength();
|
||||
std::memcpy(next, frag_, GetFragStringLength() * sizeof(Ch));
|
||||
next+= GetFragStringLength();
|
||||
next += GetFragStringLength();
|
||||
*next = '\0';
|
||||
}
|
||||
|
||||
// Copy a part from one GenericUri to another
|
||||
// Return the pointer to the next part to be copied to
|
||||
Ch* CopyPart(Ch* to, Ch* from, std::size_t len) {
|
||||
Ch* CopyPart(Ch* to, Ch* from, std::size_t len)
|
||||
{
|
||||
RAPIDJSON_ASSERT(to != 0);
|
||||
RAPIDJSON_ASSERT(from != 0);
|
||||
std::memcpy(to, from, len * sizeof(Ch));
|
||||
to[len] = '\0';
|
||||
to[len] = '\0';
|
||||
Ch* next = to + len + 1;
|
||||
return next;
|
||||
}
|
||||
@@ -408,45 +566,58 @@ private:
|
||||
// Remove . and .. segments from the path_ member.
|
||||
// https://tools.ietf.org/html/rfc3986
|
||||
// This is done in place as we are only removing segments.
|
||||
void RemoveDotSegments() {
|
||||
void RemoveDotSegments()
|
||||
{
|
||||
std::size_t pathlen = GetPathStringLength();
|
||||
std::size_t pathpos = 0; // Position in path_
|
||||
std::size_t newpos = 0; // Position in new path_
|
||||
std::size_t pathpos = 0; // Position in path_
|
||||
std::size_t newpos = 0; // Position in new path_
|
||||
|
||||
// Loop through each segment in original path_
|
||||
while (pathpos < pathlen) {
|
||||
while(pathpos < pathlen)
|
||||
{
|
||||
// Get next segment, bounded by '/' or end
|
||||
size_t slashpos = 0;
|
||||
while ((pathpos + slashpos) < pathlen) {
|
||||
if (path_[pathpos + slashpos] == '/') break;
|
||||
while((pathpos + slashpos) < pathlen)
|
||||
{
|
||||
if(path_[pathpos + slashpos] == '/')
|
||||
break;
|
||||
slashpos++;
|
||||
}
|
||||
// Check for .. and . segments
|
||||
if (slashpos == 2 && path_[pathpos] == '.' && path_[pathpos + 1] == '.') {
|
||||
if(slashpos == 2 && path_[pathpos] == '.' && path_[pathpos + 1] == '.')
|
||||
{
|
||||
// Backup a .. segment in the new path_
|
||||
// We expect to find a previously added slash at the end or nothing
|
||||
RAPIDJSON_ASSERT(newpos == 0 || path_[newpos - 1] == '/');
|
||||
size_t lastslashpos = newpos;
|
||||
// Make sure we don't go beyond the start segment
|
||||
if (lastslashpos > 1) {
|
||||
if(lastslashpos > 1)
|
||||
{
|
||||
// Find the next to last slash and back up to it
|
||||
lastslashpos--;
|
||||
while (lastslashpos > 0) {
|
||||
if (path_[lastslashpos - 1] == '/') break;
|
||||
while(lastslashpos > 0)
|
||||
{
|
||||
if(path_[lastslashpos - 1] == '/')
|
||||
break;
|
||||
lastslashpos--;
|
||||
}
|
||||
// Set the new path_ position
|
||||
newpos = lastslashpos;
|
||||
}
|
||||
} else if (slashpos == 1 && path_[pathpos] == '.') {
|
||||
}
|
||||
else if(slashpos == 1 && path_[pathpos] == '.')
|
||||
{
|
||||
// Discard . segment, leaves new path_ unchanged
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
// Move any other kind of segment to the new path_
|
||||
RAPIDJSON_ASSERT(newpos <= pathpos);
|
||||
std::memmove(&path_[newpos], &path_[pathpos], slashpos * sizeof(Ch));
|
||||
newpos += slashpos;
|
||||
// Add slash if not at end
|
||||
if ((pathpos + slashpos) < pathlen) {
|
||||
if((pathpos + slashpos) < pathlen)
|
||||
{
|
||||
path_[newpos] = '/';
|
||||
newpos++;
|
||||
}
|
||||
@@ -465,8 +636,9 @@ private:
|
||||
Ch* query_; // Includes the ?
|
||||
Ch* frag_; // Includes the #
|
||||
|
||||
Allocator* allocator_; //!< The current allocator. It is either user-supplied or equal to ownAllocator_.
|
||||
Allocator* ownAllocator_; //!< Allocator owned by this Uri.
|
||||
Allocator* allocator_; //!< The current allocator. It is either user-supplied or equal to
|
||||
//!< ownAllocator_.
|
||||
Allocator* ownAllocator_; //!< Allocator owned by this Uri.
|
||||
};
|
||||
|
||||
//! GenericUri for Value (UTF-8, default allocator).
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,12 +8,12 @@ def __version__():
|
||||
hash = subprocess.check_output("git rev-parse HEAD", shell=True, text=True)[
|
||||
:hash_width
|
||||
]
|
||||
except:
|
||||
except Exception:
|
||||
hash = "0" * hash_width
|
||||
try:
|
||||
change_count = subprocess.check_output(
|
||||
f"git rev-list rocm-{rocm_version}..HEAD --count", shell=True, text=True
|
||||
).strip()
|
||||
except:
|
||||
except Exception:
|
||||
change_count = "0"
|
||||
return f"{rocm_version}.dev{change_count}+g{hash}"
|
||||
|
||||
@@ -14,43 +14,69 @@ Features:
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
def run_dependency_parser(args):
|
||||
from src.enhanced_ninja_parser import main as ninja_main
|
||||
|
||||
sys.argv = ["enhanced_ninja_parser.py"] + args
|
||||
ninja_main()
|
||||
|
||||
|
||||
def run_selective_test_filter(args):
|
||||
from src.selective_test_filter import main as filter_main
|
||||
|
||||
sys.argv = ["selective_test_filter.py"] + args
|
||||
filter_main()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Unified Ninja Dependency & Selective Testing Tool")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Unified Ninja Dependency & Selective Testing Tool"
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
# Dependency parsing
|
||||
parser_parse = subparsers.add_parser("parse", help="Parse build.ninja and generate dependency mapping")
|
||||
parser_parse = subparsers.add_parser(
|
||||
"parse", help="Parse build.ninja and generate dependency mapping"
|
||||
)
|
||||
parser_parse.add_argument("build_ninja", help="Path to build.ninja")
|
||||
parser_parse.add_argument("--ninja", help="Path to ninja executable", default="ninja")
|
||||
parser_parse.add_argument("--workspace-root", help="Path to workspace root", default=None)
|
||||
parser_parse.add_argument(
|
||||
"--ninja", help="Path to ninja executable", default="ninja"
|
||||
)
|
||||
parser_parse.add_argument(
|
||||
"--workspace-root", help="Path to workspace root", default=None
|
||||
)
|
||||
|
||||
# Selective testing
|
||||
parser_test = subparsers.add_parser("select", help="Selective test filtering between git refs")
|
||||
parser_test = subparsers.add_parser(
|
||||
"select", help="Selective test filtering between git refs"
|
||||
)
|
||||
parser_test.add_argument("depmap_json", help="Path to dependency mapping JSON")
|
||||
parser_test.add_argument("ref1", help="Source git ref")
|
||||
parser_test.add_argument("ref2", help="Target git ref")
|
||||
parser_test.add_argument("--all", action="store_true", help="Include all executables")
|
||||
parser_test.add_argument("--test-prefix", action="store_true", help="Only include executables starting with 'test_'")
|
||||
parser_test.add_argument("--output", help="Output JSON file", default="tests_to_run.json")
|
||||
parser_test.add_argument(
|
||||
"--all", action="store_true", help="Include all executables"
|
||||
)
|
||||
parser_test.add_argument(
|
||||
"--test-prefix",
|
||||
action="store_true",
|
||||
help="Only include executables starting with 'test_'",
|
||||
)
|
||||
parser_test.add_argument(
|
||||
"--output", help="Output JSON file", default="tests_to_run.json"
|
||||
)
|
||||
|
||||
# Code auditing
|
||||
parser_audit = subparsers.add_parser("audit", help="List all files and their dependent executables")
|
||||
parser_audit = subparsers.add_parser(
|
||||
"audit", help="List all files and their dependent executables"
|
||||
)
|
||||
parser_audit.add_argument("depmap_json", help="Path to dependency mapping JSON")
|
||||
|
||||
# Build optimization
|
||||
parser_opt = subparsers.add_parser("optimize", help="List affected executables for changed files")
|
||||
parser_opt = subparsers.add_parser(
|
||||
"optimize", help="List affected executables for changed files"
|
||||
)
|
||||
parser_opt.add_argument("depmap_json", help="Path to dependency mapping JSON")
|
||||
parser_opt.add_argument("changed_files", nargs="+", help="List of changed files")
|
||||
|
||||
@@ -73,9 +99,12 @@ def main():
|
||||
elif args.command == "audit":
|
||||
run_selective_test_filter([args.depmap_json, "--audit"])
|
||||
elif args.command == "optimize":
|
||||
run_selective_test_filter([args.depmap_json, "--optimize-build"] + args.changed_files)
|
||||
run_selective_test_filter(
|
||||
[args.depmap_json, "--optimize-build"] + args.changed_files
|
||||
)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -14,96 +14,100 @@ import re
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import threading
|
||||
|
||||
|
||||
class EnhancedNinjaDependencyParser:
|
||||
def __init__(self, build_file_path, ninja_executable="ninja"):
|
||||
self.build_file_path = build_file_path
|
||||
self.build_dir = os.path.dirname(build_file_path)
|
||||
self.ninja_executable = ninja_executable
|
||||
|
||||
|
||||
# Core data structures
|
||||
self.executable_to_objects = {} # exe -> [object_files]
|
||||
self.object_to_source = {} # object -> primary_source
|
||||
self.object_to_all_deps = {} # object -> [all_dependencies]
|
||||
self.object_to_source = {} # object -> primary_source
|
||||
self.object_to_all_deps = {} # object -> [all_dependencies]
|
||||
self.file_to_executables = defaultdict(set) # file -> {executables}
|
||||
|
||||
|
||||
# Thread safety
|
||||
self.lock = threading.Lock()
|
||||
|
||||
|
||||
def parse_dependencies(self):
|
||||
"""Main method to parse all dependencies."""
|
||||
print(f"Parsing ninja dependencies from: {self.build_file_path}")
|
||||
|
||||
|
||||
# Step 1: Parse build file for executable -> object mappings
|
||||
self._parse_build_file()
|
||||
|
||||
|
||||
# Step 2: Get all object files and their dependencies
|
||||
print(f"Found {len(self.object_to_source)} object files")
|
||||
print("Extracting detailed dependencies for all object files...")
|
||||
self._extract_object_dependencies()
|
||||
|
||||
|
||||
# Step 3: Build the final file -> executables mapping
|
||||
self._build_file_to_executable_mapping()
|
||||
|
||||
|
||||
def _parse_build_file(self):
|
||||
"""Parse the ninja build file to extract executable -> object mappings."""
|
||||
print("Parsing ninja build file...")
|
||||
|
||||
with open(self.build_file_path, 'r') as f:
|
||||
|
||||
with open(self.build_file_path, "r") as f:
|
||||
content = f.read()
|
||||
# Parse executable build rules
|
||||
exe_pattern = r'^build (bin/[^:]+):\s+\S+\s+([^|]+)'
|
||||
obj_pattern = r'^build ([^:]+\.(?:cpp|cu|hip)\.o):\s+\S+\s+([^\s|]+)'
|
||||
|
||||
lines = content.split('\n')
|
||||
|
||||
# Parse executable build rules
|
||||
exe_pattern = r"^build (bin/[^:]+):\s+\S+\s+([^|]+)"
|
||||
obj_pattern = r"^build ([^:]+\.(?:cpp|cu|hip)\.o):\s+\S+\s+([^\s|]+)"
|
||||
|
||||
lines = content.split("\n")
|
||||
|
||||
for line in lines:
|
||||
# Match executable rules
|
||||
exe_match = re.match(exe_pattern, line)
|
||||
if exe_match and ('EXECUTABLE' in line or 'test_' in exe_match.group(1) or 'example_' in exe_match.group(1)):
|
||||
if exe_match and (
|
||||
"EXECUTABLE" in line
|
||||
or "test_" in exe_match.group(1)
|
||||
or "example_" in exe_match.group(1)
|
||||
):
|
||||
exe = exe_match.group(1)
|
||||
deps_part = exe_match.group(2).strip()
|
||||
|
||||
|
||||
object_files = []
|
||||
for dep in deps_part.split():
|
||||
if dep.endswith('.o') and not dep.startswith('/'):
|
||||
if dep.endswith(".o") and not dep.startswith("/"):
|
||||
object_files.append(dep)
|
||||
|
||||
|
||||
self.executable_to_objects[exe] = object_files
|
||||
continue
|
||||
|
||||
|
||||
# Match object compilation rules
|
||||
obj_match = re.match(obj_pattern, line)
|
||||
if obj_match:
|
||||
object_file = obj_match.group(1)
|
||||
source_file = obj_match.group(2)
|
||||
self.object_to_source[object_file] = source_file
|
||||
|
||||
|
||||
print(f"Found {len(self.executable_to_objects)} executables")
|
||||
print(f"Found {len(self.object_to_source)} object-to-source mappings")
|
||||
|
||||
|
||||
def _extract_object_dependencies(self):
|
||||
"""Extract detailed dependencies for all object files using ninja -t deps."""
|
||||
object_files = list(self.object_to_source.keys())
|
||||
# Process object files in parallel for better performance
|
||||
# Process object files in parallel for better performance
|
||||
if not object_files:
|
||||
print("No object files found - skipping dependency extraction")
|
||||
return
|
||||
|
||||
|
||||
max_workers = min(16, len(object_files)) # Limit concurrent processes
|
||||
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all object files for processing
|
||||
future_to_obj = {
|
||||
executor.submit(self._get_object_dependencies, obj): obj
|
||||
executor.submit(self._get_object_dependencies, obj): obj
|
||||
for obj in object_files
|
||||
}
|
||||
# Process completed futures
|
||||
# Process completed futures
|
||||
completed = 0
|
||||
for future in as_completed(future_to_obj):
|
||||
obj_file = future_to_obj[future]
|
||||
@@ -113,52 +117,52 @@ class EnhancedNinjaDependencyParser:
|
||||
self.object_to_all_deps[obj_file] = dependencies
|
||||
completed += 1
|
||||
if completed % 100 == 0:
|
||||
print(f"Processed {completed}/{len(object_files)} object files...")
|
||||
print(
|
||||
f"Processed {completed}/{len(object_files)} object files..."
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error processing {obj_file}: {e}")
|
||||
|
||||
print(f"Completed dependency extraction for {len(self.object_to_all_deps)} object files")
|
||||
|
||||
|
||||
print(
|
||||
f"Completed dependency extraction for {len(self.object_to_all_deps)} object files"
|
||||
)
|
||||
|
||||
def _get_object_dependencies(self, object_file):
|
||||
"""Get all dependencies for a single object file using ninja -t deps."""
|
||||
try:
|
||||
# Run ninja -t deps for this object file
|
||||
cmd = [self.ninja_executable, "-t", "deps", object_file]
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
cwd=self.build_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30
|
||||
cmd, cwd=self.build_dir, capture_output=True, text=True, timeout=30
|
||||
)
|
||||
|
||||
|
||||
if result.returncode != 0:
|
||||
return []
|
||||
|
||||
|
||||
dependencies = []
|
||||
lines = result.stdout.strip().split('\n')
|
||||
|
||||
lines = result.stdout.strip().split("\n")
|
||||
|
||||
for line in lines[1:]: # Skip first line with metadata
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'):
|
||||
if line and not line.startswith("#"):
|
||||
# Convert absolute paths to relative paths from workspace root
|
||||
dep_file = line
|
||||
ws_root = getattr(self, "workspace_root", "..")
|
||||
ws_prefix = ws_root.rstrip("/") + "/"
|
||||
if dep_file.startswith(ws_prefix):
|
||||
dep_file = dep_file[len(ws_prefix):]
|
||||
dep_file = dep_file[len(ws_prefix) :]
|
||||
dependencies.append(dep_file)
|
||||
|
||||
|
||||
return dependencies
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error getting dependencies for {object_file}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def _build_file_to_executable_mapping(self):
|
||||
"""Build the final mapping from files to executables."""
|
||||
print("Building file-to-executable mapping...")
|
||||
|
||||
|
||||
for exe, object_files in self.executable_to_objects.items():
|
||||
for obj_file in object_files:
|
||||
# Add all dependencies of this object file
|
||||
@@ -167,106 +171,135 @@ class EnhancedNinjaDependencyParser:
|
||||
# Filter out system files and focus on project files
|
||||
if self._is_project_file(dep_file):
|
||||
self.file_to_executables[dep_file].add(exe)
|
||||
|
||||
|
||||
print(f"Built mapping for {len(self.file_to_executables)} files")
|
||||
|
||||
|
||||
# Show statistics
|
||||
multi_exe_files = {f: exes for f, exes in self.file_to_executables.items() if len(exes) > 1}
|
||||
multi_exe_files = {
|
||||
f: exes for f, exes in self.file_to_executables.items() if len(exes) > 1
|
||||
}
|
||||
print(f"Files used by multiple executables: {len(multi_exe_files)}")
|
||||
|
||||
|
||||
if multi_exe_files:
|
||||
print("Sample files with multiple dependencies:")
|
||||
for f, exes in sorted(multi_exe_files.items())[:5]:
|
||||
print(f" {f}: {len(exes)} executables")
|
||||
|
||||
|
||||
def _is_project_file(self, file_path):
|
||||
"""Determine if a file is part of the project (not system files)."""
|
||||
# Include files that are clearly part of the project
|
||||
if any(file_path.startswith(prefix) for prefix in [
|
||||
'include/', 'library/', 'test/', 'example/', 'src/', 'profiler/',
|
||||
'build/include/', 'build/_deps/gtest', 'client_example', 'codegen', 'tile_engine'
|
||||
]):
|
||||
if any(
|
||||
file_path.startswith(prefix)
|
||||
for prefix in [
|
||||
"include/",
|
||||
"library/",
|
||||
"test/",
|
||||
"example/",
|
||||
"src/",
|
||||
"profiler/",
|
||||
"build/include/",
|
||||
"build/_deps/gtest",
|
||||
"client_example",
|
||||
"codegen",
|
||||
"tile_engine",
|
||||
]
|
||||
):
|
||||
return True
|
||||
|
||||
|
||||
# Exclude system files
|
||||
if any(file_path.startswith(prefix) for prefix in [
|
||||
'/usr/', '/opt/rocm', '/lib/', '/system/', '/local/'
|
||||
]):
|
||||
if any(
|
||||
file_path.startswith(prefix)
|
||||
for prefix in ["/usr/", "/opt/rocm", "/lib/", "/system/", "/local/"]
|
||||
):
|
||||
return False
|
||||
|
||||
|
||||
# Include files with common source/header extensions
|
||||
if file_path.endswith(('.cpp', '.hpp', '.h', '.c', '.cc', '.cxx', '.cu', '.hip', '.inc')):
|
||||
if file_path.endswith(
|
||||
(".cpp", ".hpp", ".h", ".c", ".cc", ".cxx", ".cu", ".hip", ".inc")
|
||||
):
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def export_to_csv(self, output_file):
|
||||
"""Export the file-to-executable mapping to CSV with proper comma separation."""
|
||||
print(f"Exporting mapping to {output_file}")
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
f.write("source_file,executables\n")
|
||||
for file_path in sorted(self.file_to_executables.keys()):
|
||||
executables = sorted(self.file_to_executables[file_path])
|
||||
# Use semicolon to separate multiple executables within the field
|
||||
exe_list = ';'.join(executables)
|
||||
exe_list = ";".join(executables)
|
||||
f.write(f'"{file_path}","{exe_list}"\n')
|
||||
|
||||
|
||||
def export_to_json(self, output_file):
|
||||
"""Export the complete mapping to JSON."""
|
||||
print(f"Exporting complete mapping to {output_file}")
|
||||
|
||||
|
||||
# Build reverse mapping (executable -> files)
|
||||
exe_to_files = defaultdict(set)
|
||||
for file_path, exes in self.file_to_executables.items():
|
||||
for exe in exes:
|
||||
exe_to_files[exe].add(file_path)
|
||||
|
||||
|
||||
mapping_data = {
|
||||
'file_to_executables': {
|
||||
file_path: list(exes) for file_path, exes in self.file_to_executables.items()
|
||||
"file_to_executables": {
|
||||
file_path: list(exes)
|
||||
for file_path, exes in self.file_to_executables.items()
|
||||
},
|
||||
'executable_to_files': {
|
||||
"executable_to_files": {
|
||||
exe: sorted(files) for exe, files in exe_to_files.items()
|
||||
},
|
||||
'statistics': {
|
||||
'total_files': len(self.file_to_executables),
|
||||
'total_executables': len(self.executable_to_objects),
|
||||
'total_object_files': len(self.object_to_source),
|
||||
'files_with_multiple_executables': len([f for f, exes in self.file_to_executables.items() if len(exes) > 1])
|
||||
}
|
||||
"statistics": {
|
||||
"total_files": len(self.file_to_executables),
|
||||
"total_executables": len(self.executable_to_objects),
|
||||
"total_object_files": len(self.object_to_source),
|
||||
"files_with_multiple_executables": len(
|
||||
[f for f, exes in self.file_to_executables.items() if len(exes) > 1]
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(mapping_data, f, indent=2)
|
||||
|
||||
|
||||
def print_summary(self):
|
||||
"""Print a summary of the parsed dependencies."""
|
||||
"""Print a summary of the parsed dependencies."""
|
||||
print("\n=== Enhanced Dependency Mapping Summary ===")
|
||||
print(f"Total executables: {len(self.executable_to_objects)}")
|
||||
print(f"Total files mapped: {len(self.file_to_executables)}")
|
||||
print(f"Total object files processed: {len(self.object_to_all_deps)}")
|
||||
|
||||
|
||||
# Files by type
|
||||
cpp_files = sum(1 for f in self.file_to_executables.keys() if f.endswith('.cpp'))
|
||||
hpp_files = sum(1 for f in self.file_to_executables.keys() if f.endswith('.hpp'))
|
||||
h_files = sum(1 for f in self.file_to_executables.keys() if f.endswith('.h'))
|
||||
|
||||
print(f"\nFile types:")
|
||||
cpp_files = sum(
|
||||
1 for f in self.file_to_executables.keys() if f.endswith(".cpp")
|
||||
)
|
||||
hpp_files = sum(
|
||||
1 for f in self.file_to_executables.keys() if f.endswith(".hpp")
|
||||
)
|
||||
h_files = sum(1 for f in self.file_to_executables.keys() if f.endswith(".h"))
|
||||
|
||||
print("\nFile types:")
|
||||
print(f" .cpp files: {cpp_files}")
|
||||
print(f" .hpp files: {hpp_files}")
|
||||
print(f" .h files: {h_files}")
|
||||
|
||||
|
||||
# Multi-executable files
|
||||
multi_exe_files = {f: exes for f, exes in self.file_to_executables.items() if len(exes) > 1}
|
||||
multi_exe_files = {
|
||||
f: exes for f, exes in self.file_to_executables.items() if len(exes) > 1
|
||||
}
|
||||
print(f"\nFiles used by multiple executables: {len(multi_exe_files)}")
|
||||
|
||||
|
||||
if multi_exe_files:
|
||||
print("\nTop files with most dependencies:")
|
||||
sorted_multi = sorted(multi_exe_files.items(), key=lambda x: len(x[1]), reverse=True)
|
||||
sorted_multi = sorted(
|
||||
multi_exe_files.items(), key=lambda x: len(x[1]), reverse=True
|
||||
)
|
||||
for file_path, exes in sorted_multi[:10]:
|
||||
print(f" {file_path}: {len(exes)} executables")
|
||||
|
||||
|
||||
def main():
|
||||
# Accept: build_file, ninja_path, workspace_root
|
||||
default_workspace_root = ".."
|
||||
@@ -304,15 +337,16 @@ def main():
|
||||
|
||||
# Export results
|
||||
output_dir = os.path.dirname(build_file)
|
||||
csv_file = os.path.join(output_dir, 'enhanced_file_executable_mapping.csv')
|
||||
json_file = os.path.join(output_dir, 'enhanced_dependency_mapping.json')
|
||||
csv_file = os.path.join(output_dir, "enhanced_file_executable_mapping.csv")
|
||||
json_file = os.path.join(output_dir, "enhanced_dependency_mapping.json")
|
||||
|
||||
parser.export_to_csv(csv_file)
|
||||
parser.export_to_json(json_file)
|
||||
|
||||
print(f"\nResults exported to:")
|
||||
print("\nResults exported to:")
|
||||
print(f" CSV: {csv_file}")
|
||||
print(f" JSON: {json_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -30,12 +30,15 @@ import subprocess
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
def get_changed_files(ref1, ref2):
|
||||
"""Return a set of files changed between two git refs."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "diff", "--name-only", ref1, ref2],
|
||||
capture_output=True, text=True, check=True
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
files = set(line.strip() for line in result.stdout.splitlines() if line.strip())
|
||||
return files
|
||||
@@ -43,6 +46,7 @@ def get_changed_files(ref1, ref2):
|
||||
print(f"Error running git diff: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def load_depmap(depmap_json):
|
||||
"""Load the dependency mapping JSON."""
|
||||
with open(depmap_json, "r") as f:
|
||||
@@ -52,6 +56,7 @@ def load_depmap(depmap_json):
|
||||
return data["file_to_executables"]
|
||||
return data
|
||||
|
||||
|
||||
def select_tests(file_to_executables, changed_files, filter_mode):
|
||||
"""Return a set of test executables affected by changed files."""
|
||||
affected = set()
|
||||
@@ -64,6 +69,7 @@ def select_tests(file_to_executables, changed_files, filter_mode):
|
||||
affected.add(exe)
|
||||
return sorted(affected)
|
||||
|
||||
|
||||
def main():
|
||||
if "--audit" in sys.argv:
|
||||
if len(sys.argv) < 2:
|
||||
@@ -81,7 +87,9 @@ def main():
|
||||
|
||||
if "--optimize-build" in sys.argv:
|
||||
if len(sys.argv) < 3:
|
||||
print("Usage: python selective_test_filter.py <depmap_json> --optimize-build <changed_file1> [<changed_file2> ...]")
|
||||
print(
|
||||
"Usage: python selective_test_filter.py <depmap_json> --optimize-build <changed_file1> [<changed_file2> ...]"
|
||||
)
|
||||
sys.exit(1)
|
||||
depmap_json = sys.argv[1]
|
||||
changed_files = set(sys.argv[sys.argv.index("--optimize-build") + 1 :])
|
||||
@@ -100,7 +108,9 @@ def main():
|
||||
sys.exit(0)
|
||||
|
||||
if len(sys.argv) < 4:
|
||||
print("Usage: python selective_test_filter.py <depmap_json> <ref1> <ref2> [--all | --test-prefix] [--output <output_json>]")
|
||||
print(
|
||||
"Usage: python selective_test_filter.py <depmap_json> <ref1> <ref2> [--all | --test-prefix] [--output <output_json>]"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
depmap_json = sys.argv[1]
|
||||
@@ -131,9 +141,12 @@ def main():
|
||||
tests = select_tests(file_to_executables, changed_files, filter_mode)
|
||||
|
||||
with open(output_json, "w") as f:
|
||||
json.dump({"tests_to_run": tests, "changed_files": sorted(changed_files)}, f, indent=2)
|
||||
json.dump(
|
||||
{"tests_to_run": tests, "changed_files": sorted(changed_files)}, f, indent=2
|
||||
)
|
||||
|
||||
print(f"Exported {len(tests)} tests to run to {output_json}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -12,38 +12,38 @@ import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Iterator
|
||||
from typing import Dict, List, Optional, Iterator
|
||||
|
||||
|
||||
class BuildTarget:
|
||||
"""Represents a single build target with timing information."""
|
||||
|
||||
|
||||
def __init__(self, start_time: int, end_time: int, output_name: str, cmd_hash: str):
|
||||
self.start_time = int(start_time)
|
||||
self.end_time = int(end_time)
|
||||
self.cmd_hash = cmd_hash
|
||||
self.duration = self.end_time - self.start_time
|
||||
self.targets = [output_name] # List of target names for this command hash
|
||||
|
||||
|
||||
@property
|
||||
def category(self) -> str:
|
||||
"""Categorize the build target based on file extension."""
|
||||
# Use the first target for categorization
|
||||
primary_target = self.targets[0] if self.targets else ""
|
||||
ext = Path(primary_target).suffix.lower()
|
||||
if ext in ['.o', '.obj']:
|
||||
return 'compile'
|
||||
elif ext in ['.a', '.lib']:
|
||||
return 'archive'
|
||||
elif ext in ['.so', '.dll', '.dylib']:
|
||||
return 'link_shared'
|
||||
elif ext in ['.exe', '.out']:
|
||||
return 'link_executable'
|
||||
elif 'test' in primary_target.lower():
|
||||
return 'test'
|
||||
if ext in [".o", ".obj"]:
|
||||
return "compile"
|
||||
elif ext in [".a", ".lib"]:
|
||||
return "archive"
|
||||
elif ext in [".so", ".dll", ".dylib"]:
|
||||
return "link_shared"
|
||||
elif ext in [".exe", ".out"]:
|
||||
return "link_executable"
|
||||
elif "test" in primary_target.lower():
|
||||
return "test"
|
||||
else:
|
||||
return 'other'
|
||||
|
||||
return "other"
|
||||
|
||||
@property
|
||||
def output_name(self) -> str:
|
||||
"""Get the primary output name (for backward compatibility)."""
|
||||
@@ -52,11 +52,11 @@ class BuildTarget:
|
||||
|
||||
class ThreadScheduler:
|
||||
"""Simulates thread allocation for parallelism analysis."""
|
||||
|
||||
|
||||
def __init__(self, legacy_mode: bool = False):
|
||||
self.workers: List[int] = []
|
||||
self.legacy_mode = legacy_mode
|
||||
|
||||
|
||||
def allocate_thread(self, target: BuildTarget) -> int:
|
||||
"""Allocate a thread for the given target."""
|
||||
if self.legacy_mode:
|
||||
@@ -73,7 +73,7 @@ class ThreadScheduler:
|
||||
if worker_end_time <= target.start_time:
|
||||
self.workers[i] = target.end_time
|
||||
return i
|
||||
|
||||
|
||||
# No available worker, create a new one
|
||||
self.workers.append(target.end_time)
|
||||
return len(self.workers) - 1
|
||||
@@ -81,62 +81,67 @@ class ThreadScheduler:
|
||||
|
||||
class NinjaLogParser:
|
||||
"""Parser for ninja build log files."""
|
||||
|
||||
|
||||
def __init__(self, show_all_builds: bool = False):
|
||||
self.show_all_builds = show_all_builds
|
||||
|
||||
|
||||
def parse_log_file(self, log_path: str) -> List[BuildTarget]:
|
||||
"""Parse the ninja log file and return build targets."""
|
||||
if not os.path.exists(log_path):
|
||||
raise FileNotFoundError(f"Ninja log file not found: {log_path}")
|
||||
|
||||
with open(log_path, 'r', encoding='utf-8') as file:
|
||||
|
||||
with open(log_path, "r", encoding="utf-8") as file:
|
||||
lines = file.readlines()
|
||||
|
||||
|
||||
if not lines:
|
||||
raise ValueError("Empty ninja log file")
|
||||
|
||||
|
||||
# Parse and validate header
|
||||
header = lines[0].strip()
|
||||
version_match = re.match(r'^# ninja log v(\d+)$', header)
|
||||
version_match = re.match(r"^# ninja log v(\d+)$", header)
|
||||
if not version_match:
|
||||
raise ValueError(f"Invalid ninja log header: {header}")
|
||||
|
||||
|
||||
version = int(version_match.group(1))
|
||||
if version < 5:
|
||||
raise ValueError(f"Unsupported ninja log version: {version}")
|
||||
|
||||
|
||||
# Skip additional header line for version 6
|
||||
start_line = 2 if version > 5 else 1
|
||||
|
||||
|
||||
targets: Dict[str, BuildTarget] = {}
|
||||
last_end_time = 0
|
||||
|
||||
|
||||
for line_num, line in enumerate(lines[start_line:], start=start_line + 1):
|
||||
line = line.strip()
|
||||
|
||||
|
||||
# Skip empty lines and comments
|
||||
if not line or line.startswith('#'):
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
parts = line.split('\t')
|
||||
|
||||
parts = line.split("\t")
|
||||
if len(parts) < 5:
|
||||
print(f"Warning: Skipping malformed line {line_num}: {line}", file=sys.stderr)
|
||||
print(
|
||||
f"Warning: Skipping malformed line {line_num}: {line}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
start_time, end_time, _, output_name, cmd_hash = parts[:5]
|
||||
start_time, end_time = int(start_time), int(end_time)
|
||||
|
||||
|
||||
# Handle incremental builds
|
||||
if not self.show_all_builds and end_time < last_end_time:
|
||||
targets.clear()
|
||||
|
||||
|
||||
last_end_time = end_time
|
||||
|
||||
|
||||
# Group targets by command hash
|
||||
if cmd_hash not in targets:
|
||||
targets[cmd_hash] = BuildTarget(start_time, end_time, output_name, cmd_hash)
|
||||
targets[cmd_hash] = BuildTarget(
|
||||
start_time, end_time, output_name, cmd_hash
|
||||
)
|
||||
else:
|
||||
# Update with the latest timing and add output
|
||||
existing = targets[cmd_hash]
|
||||
@@ -144,223 +149,260 @@ class NinjaLogParser:
|
||||
existing.end_time = max(existing.end_time, end_time)
|
||||
existing.duration = existing.end_time - existing.start_time
|
||||
existing.targets.append(output_name)
|
||||
|
||||
|
||||
except (ValueError, IndexError) as e:
|
||||
print(f"Warning: Error parsing line {line_num}: {e}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
|
||||
return sorted(targets.values(), key=lambda t: t.end_time, reverse=True)
|
||||
|
||||
|
||||
class FTimeTraceReader:
|
||||
"""Reads and processes Clang -ftime-trace JSON files."""
|
||||
|
||||
|
||||
def __init__(self, granularity_us: int = 50000):
|
||||
self.granularity_us = granularity_us
|
||||
|
||||
|
||||
def read_trace_file(self, trace_path: str) -> Optional[Dict]:
|
||||
"""Read and parse a Clang time trace file."""
|
||||
try:
|
||||
with open(trace_path, 'r', encoding='utf-8') as f:
|
||||
with open(trace_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError, IOError):
|
||||
return None
|
||||
|
||||
|
||||
def filter_events(self, trace_data: Dict) -> List[Dict]:
|
||||
"""Filter trace events based on criteria."""
|
||||
if 'traceEvents' not in trace_data:
|
||||
if "traceEvents" not in trace_data:
|
||||
return []
|
||||
|
||||
|
||||
filtered_events = []
|
||||
for event in trace_data['traceEvents']:
|
||||
for event in trace_data["traceEvents"]:
|
||||
# Only include complete events (ph=X) that meet duration threshold
|
||||
if (event.get('ph') == 'X' and
|
||||
event.get('dur', 0) >= self.granularity_us and
|
||||
not event.get('name', '').startswith('Total')):
|
||||
if (
|
||||
event.get("ph") == "X"
|
||||
and event.get("dur", 0) >= self.granularity_us
|
||||
and not event.get("name", "").startswith("Total")
|
||||
):
|
||||
filtered_events.append(event)
|
||||
|
||||
|
||||
return filtered_events
|
||||
|
||||
def adjust_event_timing(self, event: Dict, target: BuildTarget, pid: int, tid: int) -> Dict:
|
||||
|
||||
def adjust_event_timing(
|
||||
self, event: Dict, target: BuildTarget, pid: int, tid: int
|
||||
) -> Dict:
|
||||
"""Adjust event timing to align with ninja build timing."""
|
||||
ninja_duration_us = target.duration * 1000
|
||||
|
||||
|
||||
# Validate event duration against ninja timing
|
||||
if event.get('dur', 0) > ninja_duration_us:
|
||||
print(f"Warning: Clang trace event duration ({event['dur']}μs) exceeds "
|
||||
f"ninja duration ({ninja_duration_us}μs) for {target.output_name}",
|
||||
file=sys.stderr)
|
||||
if event.get("dur", 0) > ninja_duration_us:
|
||||
print(
|
||||
f"Warning: Clang trace event duration ({event['dur']}μs) exceeds "
|
||||
f"ninja duration ({ninja_duration_us}μs) for {target.output_name}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# Adjust event timing
|
||||
adjusted_event = event.copy()
|
||||
adjusted_event['pid'] = pid
|
||||
adjusted_event['tid'] = tid
|
||||
adjusted_event['ts'] += target.start_time * 1000 # Offset by ninja start time
|
||||
|
||||
adjusted_event["pid"] = pid
|
||||
adjusted_event["tid"] = tid
|
||||
adjusted_event["ts"] += target.start_time * 1000 # Offset by ninja start time
|
||||
|
||||
return adjusted_event
|
||||
|
||||
|
||||
class ChromeTraceGenerator:
|
||||
"""Generates Chrome tracing format from build targets."""
|
||||
|
||||
def __init__(self, process_id: int = 1, embed_ftime_traces: bool = False,
|
||||
granularity_us: int = 50000, ninja_log_dir: Optional[str] = None,
|
||||
legacy_format: bool = False):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
process_id: int = 1,
|
||||
embed_ftime_traces: bool = False,
|
||||
granularity_us: int = 50000,
|
||||
ninja_log_dir: Optional[str] = None,
|
||||
legacy_format: bool = False,
|
||||
):
|
||||
self.process_id = process_id
|
||||
self.scheduler = ThreadScheduler(legacy_mode=legacy_format)
|
||||
self.embed_ftime_traces = embed_ftime_traces
|
||||
self.ninja_log_dir = ninja_log_dir
|
||||
self.ftime_reader = FTimeTraceReader(granularity_us) if embed_ftime_traces else None
|
||||
self.ftime_reader = (
|
||||
FTimeTraceReader(granularity_us) if embed_ftime_traces else None
|
||||
)
|
||||
self.legacy_format = legacy_format
|
||||
|
||||
|
||||
def find_ftime_trace_files(self, target: BuildTarget) -> List[str]:
|
||||
"""Find Clang -ftime-trace files for a build target."""
|
||||
if not self.ninja_log_dir:
|
||||
return []
|
||||
|
||||
|
||||
trace_files = []
|
||||
|
||||
|
||||
# Look for .json files adjacent to object files
|
||||
obj_path = Path(self.ninja_log_dir) / target.output_name
|
||||
json_path = obj_path.with_suffix('.json')
|
||||
|
||||
json_path = obj_path.with_suffix(".json")
|
||||
|
||||
if json_path.exists():
|
||||
trace_files.append(str(json_path))
|
||||
|
||||
|
||||
return trace_files
|
||||
|
||||
|
||||
def generate_ftime_events(self, target: BuildTarget, tid: int) -> Iterator[Dict]:
|
||||
"""Generate Clang -ftime-trace events for a target."""
|
||||
if not self.embed_ftime_traces or not self.ftime_reader:
|
||||
return
|
||||
|
||||
|
||||
trace_files = self.find_ftime_trace_files(target)
|
||||
|
||||
|
||||
for trace_file in trace_files:
|
||||
trace_data = self.ftime_reader.read_trace_file(trace_file)
|
||||
if not trace_data:
|
||||
continue
|
||||
|
||||
|
||||
filtered_events = self.ftime_reader.filter_events(trace_data)
|
||||
|
||||
|
||||
for event in filtered_events:
|
||||
adjusted_event = self.ftime_reader.adjust_event_timing(
|
||||
event, target, self.process_id, tid
|
||||
)
|
||||
if adjusted_event:
|
||||
yield adjusted_event
|
||||
|
||||
|
||||
def generate_trace_events(self, targets: List[BuildTarget]) -> List[Dict]:
|
||||
"""Generate Chrome trace events from build targets."""
|
||||
events = []
|
||||
|
||||
|
||||
for target in targets:
|
||||
thread_id = self.scheduler.allocate_thread(target)
|
||||
|
||||
|
||||
# Add main ninja build event
|
||||
if self.legacy_format:
|
||||
# Legacy format: join multiple targets with commas, use "targets" category, empty args
|
||||
target_name = ', '.join(target.targets) if len(target.targets) > 1 else target.output_name
|
||||
target_name = (
|
||||
", ".join(target.targets)
|
||||
if len(target.targets) > 1
|
||||
else target.output_name
|
||||
)
|
||||
ninja_event = {
|
||||
'name': target_name,
|
||||
'cat': 'targets',
|
||||
'ph': 'X', # Complete event
|
||||
'ts': target.start_time * 1000, # Convert to microseconds
|
||||
'dur': target.duration * 1000, # Convert to microseconds
|
||||
'pid': self.process_id,
|
||||
'tid': thread_id,
|
||||
'args': {}
|
||||
"name": target_name,
|
||||
"cat": "targets",
|
||||
"ph": "X", # Complete event
|
||||
"ts": target.start_time * 1000, # Convert to microseconds
|
||||
"dur": target.duration * 1000, # Convert to microseconds
|
||||
"pid": self.process_id,
|
||||
"tid": thread_id,
|
||||
"args": {},
|
||||
}
|
||||
else:
|
||||
# New format: smart categorization, detailed args
|
||||
ninja_event = {
|
||||
'name': target.output_name,
|
||||
'cat': target.category,
|
||||
'ph': 'X', # Complete event
|
||||
'ts': target.start_time * 1000, # Convert to microseconds
|
||||
'dur': target.duration * 1000, # Convert to microseconds
|
||||
'pid': self.process_id,
|
||||
'tid': thread_id,
|
||||
'args': {
|
||||
'output': target.output_name,
|
||||
'duration_ms': target.duration,
|
||||
'cmd_hash': target.cmd_hash
|
||||
}
|
||||
"name": target.output_name,
|
||||
"cat": target.category,
|
||||
"ph": "X", # Complete event
|
||||
"ts": target.start_time * 1000, # Convert to microseconds
|
||||
"dur": target.duration * 1000, # Convert to microseconds
|
||||
"pid": self.process_id,
|
||||
"tid": thread_id,
|
||||
"args": {
|
||||
"output": target.output_name,
|
||||
"duration_ms": target.duration,
|
||||
"cmd_hash": target.cmd_hash,
|
||||
},
|
||||
}
|
||||
events.append(ninja_event)
|
||||
|
||||
|
||||
# Add embedded Clang -ftime-trace events
|
||||
if self.embed_ftime_traces:
|
||||
ftime_events = list(self.generate_ftime_events(target, thread_id))
|
||||
events.extend(ftime_events)
|
||||
|
||||
|
||||
if ftime_events:
|
||||
print(f"Embedded {len(ftime_events)} -ftime-trace events for {target.output_name}",
|
||||
file=sys.stderr)
|
||||
|
||||
print(
|
||||
f"Embedded {len(ftime_events)} -ftime-trace events for {target.output_name}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
|
||||
class BuildAnalyzer:
|
||||
"""Analyzes build performance and provides statistics."""
|
||||
|
||||
|
||||
def __init__(self, targets: List[BuildTarget]):
|
||||
self.targets = targets
|
||||
|
||||
|
||||
def get_build_summary(self) -> Dict:
|
||||
"""Generate build performance summary."""
|
||||
if not self.targets:
|
||||
return {}
|
||||
|
||||
|
||||
total_duration = sum(t.duration for t in self.targets)
|
||||
total_targets = len(self.targets)
|
||||
|
||||
|
||||
# Category statistics
|
||||
category_stats = {}
|
||||
for target in self.targets:
|
||||
cat = target.category
|
||||
if cat not in category_stats:
|
||||
category_stats[cat] = {'count': 0, 'total_time': 0}
|
||||
category_stats[cat]['count'] += 1
|
||||
category_stats[cat]['total_time'] += target.duration
|
||||
|
||||
category_stats[cat] = {"count": 0, "total_time": 0}
|
||||
category_stats[cat]["count"] += 1
|
||||
category_stats[cat]["total_time"] += target.duration
|
||||
|
||||
# Top slowest targets
|
||||
slowest_targets = sorted(self.targets, key=lambda t: t.duration, reverse=True)[:10]
|
||||
|
||||
slowest_targets = sorted(self.targets, key=lambda t: t.duration, reverse=True)[
|
||||
:10
|
||||
]
|
||||
|
||||
return {
|
||||
'total_targets': total_targets,
|
||||
'total_duration_ms': total_duration,
|
||||
'total_duration_sec': total_duration / 1000,
|
||||
'average_duration_ms': total_duration / total_targets if total_targets > 0 else 0,
|
||||
'category_stats': category_stats,
|
||||
'slowest_targets': [
|
||||
{'name': t.output_name, 'duration_ms': t.duration, 'category': t.category}
|
||||
"total_targets": total_targets,
|
||||
"total_duration_ms": total_duration,
|
||||
"total_duration_sec": total_duration / 1000,
|
||||
"average_duration_ms": total_duration / total_targets
|
||||
if total_targets > 0
|
||||
else 0,
|
||||
"category_stats": category_stats,
|
||||
"slowest_targets": [
|
||||
{
|
||||
"name": t.output_name,
|
||||
"duration_ms": t.duration,
|
||||
"category": t.category,
|
||||
}
|
||||
for t in slowest_targets
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def print_summary(self):
|
||||
"""Print build summary to stderr."""
|
||||
summary = self.get_build_summary()
|
||||
if not summary:
|
||||
print("No build data available", file=sys.stderr)
|
||||
return
|
||||
|
||||
print(f"\n=== Build Summary ===", file=sys.stderr)
|
||||
|
||||
print("\n=== Build Summary ===", file=sys.stderr)
|
||||
print(f"Total targets: {summary['total_targets']}", file=sys.stderr)
|
||||
print(f"Total time: {summary['total_duration_sec']:.2f}s", file=sys.stderr)
|
||||
print(f"Average time per target: {summary['average_duration_ms']:.2f}ms", file=sys.stderr)
|
||||
|
||||
print(f"\nBy category:", file=sys.stderr)
|
||||
for category, stats in summary['category_stats'].items():
|
||||
avg_time = stats['total_time'] / stats['count'] if stats['count'] > 0 else 0
|
||||
print(f" {category:15} {stats['count']:6} targets "
|
||||
f"{stats['total_time']/1000:8.2f}s "
|
||||
f"(avg: {avg_time/1000:.3f}s)", file=sys.stderr)
|
||||
|
||||
print(f"\nSlowest targets:", file=sys.stderr)
|
||||
for i, target in enumerate(summary['slowest_targets'][:5], 1):
|
||||
print(f" {i:2}. {target['name']} ({target['duration_ms']}ms, {target['category']})", file=sys.stderr)
|
||||
print(
|
||||
f"Average time per target: {summary['average_duration_ms']:.2f}ms",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
print("\nBy category:", file=sys.stderr)
|
||||
for category, stats in summary["category_stats"].items():
|
||||
avg_time = stats["total_time"] / stats["count"] if stats["count"] > 0 else 0
|
||||
print(
|
||||
f" {category:15} {stats['count']:6} targets "
|
||||
f"{stats['total_time'] / 1000:8.2f}s "
|
||||
f"(avg: {avg_time / 1000:.3f}s)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
print("\nSlowest targets:", file=sys.stderr)
|
||||
for i, target in enumerate(summary["slowest_targets"][:5], 1):
|
||||
print(
|
||||
f" {i:2}. {target['name']} ({target['duration_ms']}ms, {target['category']})",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
|
||||
def create_argument_parser() -> argparse.ArgumentParser:
|
||||
@@ -376,57 +418,48 @@ Examples:
|
||||
%(prog)s build/.ninja_log --show-all # Include all builds
|
||||
%(prog)s build/.ninja_log --embed-ftime-trace # Include Clang timing data
|
||||
%(prog)s build/.ninja_log --granularity 10000 # Custom granularity threshold
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
'ninja_logs',
|
||||
nargs='+', # Accept one or more ninja log files
|
||||
help='Path(s) to the .ninja_log file(s)'
|
||||
"ninja_logs",
|
||||
nargs="+", # Accept one or more ninja log files
|
||||
help="Path(s) to the .ninja_log file(s)",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument("-o", "--output", help="Output file (default: stdout)")
|
||||
|
||||
parser.add_argument(
|
||||
'-o', '--output',
|
||||
help='Output file (default: stdout)'
|
||||
"--show-all", action="store_true", help="Show all builds, not just the last one"
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
'--show-all',
|
||||
action='store_true',
|
||||
help='Show all builds, not just the last one'
|
||||
"--summary", action="store_true", help="Print build summary to stderr"
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
'--summary',
|
||||
action='store_true',
|
||||
help='Print build summary to stderr'
|
||||
"--pretty", action="store_true", help="Pretty-print JSON output"
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
'--pretty',
|
||||
action='store_true',
|
||||
help='Pretty-print JSON output'
|
||||
"--embed-ftime-trace",
|
||||
action="store_true",
|
||||
help="Embed Clang -ftime-trace JSON files found adjacent to targets",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
'--embed-ftime-trace',
|
||||
action='store_true',
|
||||
help='Embed Clang -ftime-trace JSON files found adjacent to targets'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--granularity',
|
||||
"--granularity",
|
||||
type=int,
|
||||
default=50000,
|
||||
help='Minimum duration for -ftime-trace events in microseconds (default: 50000)'
|
||||
help="Minimum duration for -ftime-trace events in microseconds (default: 50000)",
|
||||
)
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
'--legacy-format',
|
||||
action='store_true',
|
||||
help='Output in legacy format compatible with old ninjatracer (simple JSON array, all categories as "targets", empty args)'
|
||||
"--legacy-format",
|
||||
action="store_true",
|
||||
help='Output in legacy format compatible with old ninjatracer (simple JSON array, all categories as "targets", empty args)',
|
||||
)
|
||||
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -434,75 +467,79 @@ def main():
|
||||
"""Main entry point."""
|
||||
parser = create_argument_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
try:
|
||||
# Process multiple ninja log files
|
||||
all_events = []
|
||||
|
||||
|
||||
for pid, ninja_log_path in enumerate(args.ninja_logs):
|
||||
# Parse ninja log
|
||||
log_parser = NinjaLogParser(show_all_builds=args.show_all)
|
||||
targets = log_parser.parse_log_file(ninja_log_path)
|
||||
|
||||
|
||||
if not targets:
|
||||
print(f"No build targets found in ninja log: {ninja_log_path}", file=sys.stderr)
|
||||
print(
|
||||
f"No build targets found in ninja log: {ninja_log_path}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
|
||||
# Determine ninja log directory for -ftime-trace files
|
||||
ninja_log_dir = os.path.dirname(os.path.abspath(ninja_log_path)) if args.embed_ftime_trace else None
|
||||
|
||||
ninja_log_dir = (
|
||||
os.path.dirname(os.path.abspath(ninja_log_path))
|
||||
if args.embed_ftime_trace
|
||||
else None
|
||||
)
|
||||
|
||||
# Generate trace events for this log file
|
||||
trace_generator = ChromeTraceGenerator(
|
||||
process_id=pid, # Use different PID for each log file
|
||||
embed_ftime_traces=args.embed_ftime_trace,
|
||||
granularity_us=args.granularity,
|
||||
ninja_log_dir=ninja_log_dir,
|
||||
legacy_format=args.legacy_format
|
||||
legacy_format=args.legacy_format,
|
||||
)
|
||||
events = trace_generator.generate_trace_events(targets)
|
||||
all_events.extend(events)
|
||||
|
||||
|
||||
# Print summary if requested (for each log file)
|
||||
if args.summary:
|
||||
print(f"\n=== Summary for {ninja_log_path} ===", file=sys.stderr)
|
||||
analyzer = BuildAnalyzer(targets)
|
||||
analyzer.print_summary()
|
||||
|
||||
|
||||
if not all_events:
|
||||
print("No build targets found in any ninja log files", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
|
||||
# Output format logic
|
||||
if args.legacy_format:
|
||||
# Legacy format: always output simple JSON array
|
||||
json_kwargs = {'indent': 2} if args.pretty else {}
|
||||
json_kwargs = {"indent": 2} if args.pretty else {}
|
||||
json_output = json.dumps(all_events, **json_kwargs)
|
||||
elif args.output or args.pretty:
|
||||
# Enhanced format with metadata (when saving to file or pretty printing)
|
||||
trace_data = {
|
||||
'traceEvents': all_events,
|
||||
'displayTimeUnit': 'ms',
|
||||
'systemTraceEvents': 'SystemTraceData',
|
||||
'otherData': {
|
||||
'version': '1.0',
|
||||
'generator': 'ninja_json_converter.py'
|
||||
}
|
||||
"traceEvents": all_events,
|
||||
"displayTimeUnit": "ms",
|
||||
"systemTraceEvents": "SystemTraceData",
|
||||
"otherData": {"version": "1.0", "generator": "ninja_json_converter.py"},
|
||||
}
|
||||
json_kwargs = {'indent': 2} if args.pretty else {}
|
||||
json_kwargs = {"indent": 2} if args.pretty else {}
|
||||
json_output = json.dumps(trace_data, **json_kwargs)
|
||||
else:
|
||||
# Original format (simple JSON array to stdout)
|
||||
json_output = json.dumps(all_events)
|
||||
|
||||
|
||||
if args.output:
|
||||
with open(args.output, 'w') as f:
|
||||
with open(args.output, "w") as f:
|
||||
f.write(json_output)
|
||||
print(f"Trace written to {args.output}", file=sys.stderr)
|
||||
else:
|
||||
print(json_output)
|
||||
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
#!/usr/bin/env python3
|
||||
import os, io, argparse, datetime
|
||||
#import numpy as np
|
||||
import os
|
||||
import io
|
||||
import argparse
|
||||
import datetime
|
||||
|
||||
# import numpy as np
|
||||
import sqlalchemy
|
||||
from sqlalchemy.types import NVARCHAR, Float, Integer
|
||||
from sqlalchemy import text
|
||||
import pymysql
|
||||
import pandas as pd
|
||||
from sshtunnel import SSHTunnelForwarder
|
||||
|
||||
|
||||
def print_to_string(*args, **kwargs):
|
||||
output = io.StringIO()
|
||||
print(*args, file=output, **kwargs)
|
||||
@@ -15,15 +18,18 @@ def print_to_string(*args, **kwargs):
|
||||
output.close()
|
||||
return contents
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Parse results from tf benchmark runs')
|
||||
parser.add_argument('filename', type=str, help='Log file to prase or directory containing log files')
|
||||
parser = argparse.ArgumentParser(description="Parse results from tf benchmark runs")
|
||||
parser.add_argument(
|
||||
"filename", type=str, help="Log file to prase or directory containing log files"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
files = []
|
||||
if os.path.isdir(args.filename):
|
||||
all_files = os.listdir(args.filename)
|
||||
for name in all_files:
|
||||
if not 'log' in name:
|
||||
if "log" not in name:
|
||||
continue
|
||||
files.append(os.path.join(args.filename, name))
|
||||
else:
|
||||
@@ -31,62 +37,76 @@ def parse_args():
|
||||
args.files = files
|
||||
return args
|
||||
|
||||
|
||||
def get_log_params(logfile):
|
||||
print("logfile=",logfile)
|
||||
branch_name=' '
|
||||
node_id=' '
|
||||
gpu_arch=' '
|
||||
hip_vers=' '
|
||||
compute_units=0
|
||||
environment=' '
|
||||
rocm_vers=' '
|
||||
print("logfile=", logfile)
|
||||
branch_name = " "
|
||||
node_id = " "
|
||||
gpu_arch = " "
|
||||
hip_vers = " "
|
||||
compute_units = 0
|
||||
environment = " "
|
||||
rocm_vers = " "
|
||||
for line in open(logfile):
|
||||
if 'Branch name' in line:
|
||||
lst=line.split()
|
||||
branch_name=lst[2]
|
||||
if 'On branch' in line:
|
||||
lst=line.split()
|
||||
branch_name=lst[2]
|
||||
if 'Node name' in line:
|
||||
lst=line.split()
|
||||
node_id=lst[2]
|
||||
if 'GPU_arch' in line:
|
||||
lst=line.split()
|
||||
gpu_arch=lst[2]
|
||||
if 'HIP version' in line:
|
||||
lst=line.split()
|
||||
hip_vers=lst[2]
|
||||
if 'Compute Unit' in line:
|
||||
lst=line.split()
|
||||
compute_units=lst[2]
|
||||
if 'Environment type' in line:
|
||||
lst=line.split()
|
||||
environment=lst[2]
|
||||
if 'InstalledDir' in line:
|
||||
lst=line.split()
|
||||
rocm_vers=lst[1][lst[1].find('/opt/rocm-')+len('/opt/rocm-'):lst[1].rfind('/llvm/bin')]
|
||||
return branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment
|
||||
if "Branch name" in line:
|
||||
lst = line.split()
|
||||
branch_name = lst[2]
|
||||
if "On branch" in line:
|
||||
lst = line.split()
|
||||
branch_name = lst[2]
|
||||
if "Node name" in line:
|
||||
lst = line.split()
|
||||
node_id = lst[2]
|
||||
if "GPU_arch" in line:
|
||||
lst = line.split()
|
||||
gpu_arch = lst[2]
|
||||
if "HIP version" in line:
|
||||
lst = line.split()
|
||||
hip_vers = lst[2]
|
||||
if "Compute Unit" in line:
|
||||
lst = line.split()
|
||||
compute_units = lst[2]
|
||||
if "Environment type" in line:
|
||||
lst = line.split()
|
||||
environment = lst[2]
|
||||
if "InstalledDir" in line:
|
||||
lst = line.split()
|
||||
rocm_vers = lst[1][
|
||||
lst[1].find("/opt/rocm-") + len("/opt/rocm-") : lst[1].rfind(
|
||||
"/llvm/bin"
|
||||
)
|
||||
]
|
||||
return (
|
||||
branch_name,
|
||||
node_id,
|
||||
gpu_arch,
|
||||
compute_units,
|
||||
rocm_vers,
|
||||
hip_vers,
|
||||
environment,
|
||||
)
|
||||
|
||||
|
||||
def parse_logfile(logfile):
|
||||
glue=''
|
||||
res=[]
|
||||
tests=[]
|
||||
kernels=[]
|
||||
tflops=[]
|
||||
dtype=[]
|
||||
alayout=[]
|
||||
blayout=[]
|
||||
M=[]
|
||||
N=[]
|
||||
K=[]
|
||||
StrideA=[]
|
||||
StrideB=[]
|
||||
StrideC=[]
|
||||
if 'perf_gemm' in logfile and 'gemm_bilinear' not in logfile:
|
||||
glue = ""
|
||||
res = []
|
||||
tests = []
|
||||
kernels = []
|
||||
tflops = []
|
||||
dtype = []
|
||||
alayout = []
|
||||
blayout = []
|
||||
M = []
|
||||
N = []
|
||||
K = []
|
||||
StrideA = []
|
||||
StrideB = []
|
||||
StrideC = []
|
||||
if "perf_gemm" in logfile and "gemm_bilinear" not in logfile:
|
||||
for line in open(logfile):
|
||||
if 'Best Perf' in line:
|
||||
lst=line.split()
|
||||
if len(lst)>=37: #the line is complete
|
||||
if "Best Perf" in line:
|
||||
lst = line.split()
|
||||
if len(lst) >= 37: # the line is complete
|
||||
tests.append(glue.join(lst[5:30]))
|
||||
kernels.append(glue.join(lst[37:]))
|
||||
tflops.append(lst[33])
|
||||
@@ -99,7 +119,7 @@ def parse_logfile(logfile):
|
||||
StrideA.append(lst[23])
|
||||
StrideB.append(lst[26])
|
||||
StrideC.append(lst[29])
|
||||
elif len(lst)<37 and len(lst)>=33: #the tflops are available
|
||||
elif len(lst) < 37 and len(lst) >= 33: # the tflops are available
|
||||
tests.append(glue.join(lst[5:30]))
|
||||
kernels.append("N/A")
|
||||
tflops.append(lst[33])
|
||||
@@ -112,87 +132,141 @@ def parse_logfile(logfile):
|
||||
StrideA.append(lst[23])
|
||||
StrideB.append(lst[26])
|
||||
StrideC.append(lst[29])
|
||||
print("warning: incomplete line:",lst)
|
||||
elif len(lst)<33: #even the tflops are not available
|
||||
print("warning: incomplete line:", lst)
|
||||
elif len(lst) < 33: # even the tflops are not available
|
||||
print("Error in ckProfiler output!")
|
||||
print("warning: incomplete line=",lst)
|
||||
#sort results
|
||||
#sorted_tests = sorted(tests)
|
||||
res = [x for _,x in sorted(zip(tests,tflops))]
|
||||
#sorted_kernels = [x for _,x in sorted(zip(tests,kernels))]
|
||||
test_list=list(range(1,len(tests)+1))
|
||||
#parse conv_fwd and conv_bwd performance tests:
|
||||
elif 'conv_fwd' in logfile or 'conv_bwd' in logfile:
|
||||
print("warning: incomplete line=", lst)
|
||||
# sort results
|
||||
# sorted_tests = sorted(tests)
|
||||
res = [x for _, x in sorted(zip(tests, tflops))]
|
||||
# sorted_kernels = [x for _,x in sorted(zip(tests,kernels))]
|
||||
# test_list = list(range(1, len(tests) + 1))
|
||||
# parse conv_fwd and conv_bwd performance tests:
|
||||
elif "conv_fwd" in logfile or "conv_bwd" in logfile:
|
||||
for line in open(logfile):
|
||||
if 'tflops:' in line:
|
||||
lst=line.split()
|
||||
if "tflops:" in line:
|
||||
lst = line.split()
|
||||
res.append(lst[1])
|
||||
#parse all other performance tests:
|
||||
elif 'resnet50' in logfile or 'batched_gemm' in logfile or 'grouped_gemm' in logfile or 'gemm_bilinear' in logfile or 'reduction' in logfile:
|
||||
# parse all other performance tests:
|
||||
elif (
|
||||
"resnet50" in logfile
|
||||
or "batched_gemm" in logfile
|
||||
or "grouped_gemm" in logfile
|
||||
or "gemm_bilinear" in logfile
|
||||
or "reduction" in logfile
|
||||
):
|
||||
for line in open(logfile):
|
||||
if 'Best Perf' in line:
|
||||
lst=line.split()
|
||||
if "Best Perf" in line:
|
||||
lst = line.split()
|
||||
res.append(lst[4])
|
||||
elif 'onnx_gemm' in logfile:
|
||||
elif "onnx_gemm" in logfile:
|
||||
for line in open(logfile):
|
||||
if 'Best Perf' in line:
|
||||
lst=line.split()
|
||||
if "Best Perf" in line:
|
||||
lst = line.split()
|
||||
res.append(lst[33])
|
||||
elif 'splitK_gemm' in logfile or 'mixed_gemm' in logfile:
|
||||
elif "splitK_gemm" in logfile or "mixed_gemm" in logfile:
|
||||
for line in open(logfile):
|
||||
if 'Best Perf' in line:
|
||||
lst=line.split()
|
||||
if "Best Perf" in line:
|
||||
lst = line.split()
|
||||
res.append(lst[36])
|
||||
elif 'perf_fmha' in logfile:
|
||||
elif "perf_fmha" in logfile:
|
||||
for line in open(logfile):
|
||||
if 'TFlops' in line:
|
||||
lst=line.split()
|
||||
line_dict=dict(zip(lst[1:],lst))
|
||||
res.append(line_dict['TFlops,'])
|
||||
elif 'perf_tile_gemm_basic' in logfile or 'perf_tile_gemm_mem_pipeline' in logfile:
|
||||
if "TFlops" in line:
|
||||
lst = line.split()
|
||||
line_dict = dict(zip(lst[1:], lst))
|
||||
res.append(line_dict["TFlops,"])
|
||||
elif "perf_tile_gemm_basic" in logfile or "perf_tile_gemm_mem_pipeline" in logfile:
|
||||
for line in open(logfile):
|
||||
if 'TFlops' in line:
|
||||
lst=line.split()
|
||||
line_dict=dict(zip(lst[1:],lst))
|
||||
res.append(line_dict['TFlops,'])
|
||||
if "TFlops" in line:
|
||||
lst = line.split()
|
||||
line_dict = dict(zip(lst[1:], lst))
|
||||
res.append(line_dict["TFlops,"])
|
||||
return res
|
||||
|
||||
|
||||
def get_baseline(table, connection):
|
||||
query = text('''SELECT * from '''+table+''' WHERE Datetime = (SELECT MAX(Datetime) FROM '''+table+''' where Branch_ID='develop' );''')
|
||||
query = text(
|
||||
"""SELECT * from """
|
||||
+ table
|
||||
+ """ WHERE Datetime = (SELECT MAX(Datetime) FROM """
|
||||
+ table
|
||||
+ """ where Branch_ID='develop' );"""
|
||||
)
|
||||
return pd.read_sql(query, connection)
|
||||
|
||||
def store_new_test_result(table_name, test_results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, connection):
|
||||
params=[str(branch_name),str(node_id),str(gpu_arch),compute_units,str(rocm_vers),str(hip_vers),str(environment),str(datetime.datetime.now())]
|
||||
df=pd.DataFrame(data=[params],columns=['Branch_ID','Node_ID','GPU_arch','Compute Units','ROCM_version','HIP_version','Environment','Datetime'])
|
||||
df_add=pd.DataFrame(data=[test_results],columns=testlist)
|
||||
df=pd.concat([df,df_add],axis=1)
|
||||
#print("new test results dataframe:",df)
|
||||
df.to_sql(table_name,connection,if_exists='append',index=False)
|
||||
|
||||
def store_new_test_result(
|
||||
table_name,
|
||||
test_results,
|
||||
testlist,
|
||||
branch_name,
|
||||
node_id,
|
||||
gpu_arch,
|
||||
compute_units,
|
||||
rocm_vers,
|
||||
hip_vers,
|
||||
environment,
|
||||
connection,
|
||||
):
|
||||
params = [
|
||||
str(branch_name),
|
||||
str(node_id),
|
||||
str(gpu_arch),
|
||||
compute_units,
|
||||
str(rocm_vers),
|
||||
str(hip_vers),
|
||||
str(environment),
|
||||
str(datetime.datetime.now()),
|
||||
]
|
||||
df = pd.DataFrame(
|
||||
data=[params],
|
||||
columns=[
|
||||
"Branch_ID",
|
||||
"Node_ID",
|
||||
"GPU_arch",
|
||||
"Compute Units",
|
||||
"ROCM_version",
|
||||
"HIP_version",
|
||||
"Environment",
|
||||
"Datetime",
|
||||
],
|
||||
)
|
||||
df_add = pd.DataFrame(data=[test_results], columns=testlist)
|
||||
df = pd.concat([df, df_add], axis=1)
|
||||
# print("new test results dataframe:",df)
|
||||
df.to_sql(table_name, connection, if_exists="append", index=False)
|
||||
return 0
|
||||
|
||||
def compare_test_to_baseline(baseline,test,testlist):
|
||||
regression=0
|
||||
|
||||
def compare_test_to_baseline(baseline, test, testlist):
|
||||
regression = 0
|
||||
if not baseline.empty:
|
||||
base=baseline[testlist].to_numpy(dtype='float')
|
||||
base_list=base[0]
|
||||
ave_perf=0
|
||||
base = baseline[testlist].to_numpy(dtype="float")
|
||||
base_list = base[0]
|
||||
ave_perf = 0
|
||||
for i in range(len(base_list)):
|
||||
# success criterion:
|
||||
if base_list[i]>1.01*float(test[i]):
|
||||
print("test # ",i,"shows regression by {:.3f}%".format(
|
||||
(float(test[i])-base_list[i])/base_list[i]*100))
|
||||
regression=1
|
||||
if base_list[i]>0: ave_perf=ave_perf+float(test[i])/base_list[i]
|
||||
if regression==0:
|
||||
if base_list[i] > 1.01 * float(test[i]):
|
||||
print(
|
||||
"test # ",
|
||||
i,
|
||||
"shows regression by {:.3f}%".format(
|
||||
(float(test[i]) - base_list[i]) / base_list[i] * 100
|
||||
),
|
||||
)
|
||||
regression = 1
|
||||
if base_list[i] > 0:
|
||||
ave_perf = ave_perf + float(test[i]) / base_list[i]
|
||||
if regression == 0:
|
||||
print("no regressions found")
|
||||
ave_perf=ave_perf/len(base_list)
|
||||
print("average performance relative to baseline:",ave_perf)
|
||||
ave_perf = ave_perf / len(base_list)
|
||||
print("average performance relative to baseline:", ave_perf)
|
||||
else:
|
||||
print("could not find a baseline")
|
||||
return regression
|
||||
|
||||
'''
|
||||
|
||||
"""
|
||||
def post_test_params(tlist,connection):
|
||||
sorted_dtypes = [x for _,x in sorted(zip(tests,dtype))]
|
||||
sorted_alayout = [x for _,x in sorted(zip(tests,alayout))]
|
||||
@@ -223,29 +297,38 @@ def post_test_params(tlist,connection):
|
||||
'StrideC': Integer()
|
||||
}
|
||||
df.to_sql("ck_gemm_test_params",connection,if_exists='replace',index=False, dtype=dtypes)
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
results=[]
|
||||
tflops_base=[]
|
||||
testlist=[]
|
||||
#parse the test parameters from the logfile
|
||||
results = []
|
||||
tflops_base = []
|
||||
testlist = []
|
||||
# parse the test parameters from the logfile
|
||||
for filename in args.files:
|
||||
branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment = get_log_params(filename)
|
||||
(
|
||||
branch_name,
|
||||
node_id,
|
||||
gpu_arch,
|
||||
compute_units,
|
||||
rocm_vers,
|
||||
hip_vers,
|
||||
environment,
|
||||
) = get_log_params(filename)
|
||||
|
||||
print("Branch name:",branch_name)
|
||||
print("Node name:",node_id)
|
||||
print("GPU_arch:",gpu_arch)
|
||||
print("Compute units:",compute_units)
|
||||
print("ROCM_version:",rocm_vers)
|
||||
print("HIP_version:",hip_vers)
|
||||
print("Environment:",environment)
|
||||
#parse results, get the Tflops value for "Best Perf" kernels
|
||||
results=parse_logfile(filename)
|
||||
print("Branch name:", branch_name)
|
||||
print("Node name:", node_id)
|
||||
print("GPU_arch:", gpu_arch)
|
||||
print("Compute units:", compute_units)
|
||||
print("ROCM_version:", rocm_vers)
|
||||
print("HIP_version:", hip_vers)
|
||||
print("Environment:", environment)
|
||||
# parse results, get the Tflops value for "Best Perf" kernels
|
||||
results = parse_logfile(filename)
|
||||
|
||||
print("Number of tests:",len(results))
|
||||
sql_hostname = '127.0.0.1'
|
||||
print("Number of tests:", len(results))
|
||||
sql_hostname = "127.0.0.1"
|
||||
sql_username = os.environ["dbuser"]
|
||||
sql_password = os.environ["dbpassword"]
|
||||
sql_main_database = os.environ["ck_perf_db"]
|
||||
@@ -256,127 +339,147 @@ def main():
|
||||
ssh_pass = os.environ["dbsshpassword"]
|
||||
|
||||
with SSHTunnelForwarder(
|
||||
(ssh_host, ssh_port),
|
||||
ssh_username=ssh_user,
|
||||
ssh_password=ssh_pass,
|
||||
remote_bind_address=(sql_hostname, sql_port)) as tunnel:
|
||||
|
||||
sqlEngine = sqlalchemy.create_engine('mysql+pymysql://{0}:{1}@{2}:{3}/{4}'.
|
||||
format(sql_username, sql_password, sql_hostname, tunnel.local_bind_port, sql_main_database))
|
||||
(ssh_host, ssh_port),
|
||||
ssh_username=ssh_user,
|
||||
ssh_password=ssh_pass,
|
||||
remote_bind_address=(sql_hostname, sql_port),
|
||||
) as tunnel:
|
||||
sqlEngine = sqlalchemy.create_engine(
|
||||
"mysql+pymysql://{0}:{1}@{2}:{3}/{4}".format(
|
||||
sql_username,
|
||||
sql_password,
|
||||
sql_hostname,
|
||||
tunnel.local_bind_port,
|
||||
sql_main_database,
|
||||
)
|
||||
)
|
||||
conn = sqlEngine.connect()
|
||||
|
||||
#save gemm performance tests:
|
||||
if 'perf_gemm' in filename and 'gemm_bilinear' not in filename:
|
||||
#write the ck_gemm_test_params table only needed once the test set changes
|
||||
#post_test_params(test_list,conn)
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_gemm_tflops"
|
||||
if 'batched_gemm' in filename:
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_batched_gemm_tflops"
|
||||
if 'grouped_gemm' in filename:
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_grouped_gemm_tflops"
|
||||
if 'perf_conv_fwd' in filename:
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_conv_fwd_tflops"
|
||||
if 'perf_conv_bwd_data' in filename:
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_conv_bwd_data_tflops"
|
||||
if 'grouped_conv_fwd' in filename:
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_grouped_conv_fwd_tflops"
|
||||
if 'grouped_conv_bwd_data' in filename:
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_grouped_conv_bwd_data_tflops"
|
||||
if 'grouped_conv_bwd_weight' in filename:
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_grouped_conv_bwd_weight_tflops"
|
||||
if 'gemm_bilinear' in filename:
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_gemm_bilinear_tflops"
|
||||
if 'reduction' in filename:
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_reduction_GBps"
|
||||
if 'resnet50_N4' in filename:
|
||||
for i in range(1,50):
|
||||
testlist.append("Layer%i"%i)
|
||||
table_name="ck_resnet50_N4_tflops"
|
||||
if 'resnet50_N256' in filename:
|
||||
for i in range(1,50):
|
||||
testlist.append("Layer%i"%i)
|
||||
table_name="ck_resnet50_N256_tflops"
|
||||
if 'onnx_gemm' in filename:
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_onnx_gemm_tflops"
|
||||
if 'splitK_gemm' in filename:
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_splitK_gemm_tflops"
|
||||
if 'mixed_gemm' in filename:
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_mixed_gemm_tflops"
|
||||
if 'fmha_fwd' in filename:
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_fmha_fwd_tflops"
|
||||
if 'fmha_bwd' in filename:
|
||||
for i in range(1,len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_fmha_bwd_tflops"
|
||||
if 'gemm_basic_fp16' in filename:
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_basic_fp16_tflops"
|
||||
if 'gemm_mem_pipeline_fp16' in filename:
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_mem_pipeline_fp16_tflops"
|
||||
if 'gemm_basic_bf16' in filename:
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_basic_bf16_tflops"
|
||||
if 'gemm_mem_pipeline_bf16' in filename:
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_mem_pipeline_bf16_tflops"
|
||||
if 'gemm_basic_fp8' in filename:
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_basic_fp8_tflops"
|
||||
if 'gemm_mem_pipeline_fp8' in filename:
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_mem_pipeline_fp8_tflops"
|
||||
if 'gemm_basic_bf8' in filename:
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_basic_bf8_tflops"
|
||||
if 'gemm_mem_pipeline_bf8' in filename:
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_mem_pipeline_bf8_tflops"
|
||||
# save gemm performance tests:
|
||||
if "perf_gemm" in filename and "gemm_bilinear" not in filename:
|
||||
# write the ck_gemm_test_params table only needed once the test set changes
|
||||
# post_test_params(test_list,conn)
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_gemm_tflops"
|
||||
if "batched_gemm" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_batched_gemm_tflops"
|
||||
if "grouped_gemm" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_grouped_gemm_tflops"
|
||||
if "perf_conv_fwd" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_conv_fwd_tflops"
|
||||
if "perf_conv_bwd_data" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_conv_bwd_data_tflops"
|
||||
if "grouped_conv_fwd" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_grouped_conv_fwd_tflops"
|
||||
if "grouped_conv_bwd_data" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_grouped_conv_bwd_data_tflops"
|
||||
if "grouped_conv_bwd_weight" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_grouped_conv_bwd_weight_tflops"
|
||||
if "gemm_bilinear" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_gemm_bilinear_tflops"
|
||||
if "reduction" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_reduction_GBps"
|
||||
if "resnet50_N4" in filename:
|
||||
for i in range(1, 50):
|
||||
testlist.append("Layer%i" % i)
|
||||
table_name = "ck_resnet50_N4_tflops"
|
||||
if "resnet50_N256" in filename:
|
||||
for i in range(1, 50):
|
||||
testlist.append("Layer%i" % i)
|
||||
table_name = "ck_resnet50_N256_tflops"
|
||||
if "onnx_gemm" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_onnx_gemm_tflops"
|
||||
if "splitK_gemm" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_splitK_gemm_tflops"
|
||||
if "mixed_gemm" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_mixed_gemm_tflops"
|
||||
if "fmha_fwd" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_fmha_fwd_tflops"
|
||||
if "fmha_bwd" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_fmha_bwd_tflops"
|
||||
if "gemm_basic_fp16" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_tile_gemm_basic_fp16_tflops"
|
||||
if "gemm_mem_pipeline_fp16" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_tile_gemm_mem_pipeline_fp16_tflops"
|
||||
if "gemm_basic_bf16" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_tile_gemm_basic_bf16_tflops"
|
||||
if "gemm_mem_pipeline_bf16" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_tile_gemm_mem_pipeline_bf16_tflops"
|
||||
if "gemm_basic_fp8" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_tile_gemm_basic_fp8_tflops"
|
||||
if "gemm_mem_pipeline_fp8" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_tile_gemm_mem_pipeline_fp8_tflops"
|
||||
if "gemm_basic_bf8" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_tile_gemm_basic_bf8_tflops"
|
||||
if "gemm_mem_pipeline_bf8" in filename:
|
||||
for i in range(1, len(results) + 1):
|
||||
testlist.append("Test%i" % i)
|
||||
table_name = "ck_tile_gemm_mem_pipeline_bf8_tflops"
|
||||
|
||||
tflops_base = get_baseline(table_name,conn)
|
||||
store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, sqlEngine)
|
||||
tflops_base = get_baseline(table_name, conn)
|
||||
store_new_test_result(
|
||||
table_name,
|
||||
results,
|
||||
testlist,
|
||||
branch_name,
|
||||
node_id,
|
||||
gpu_arch,
|
||||
compute_units,
|
||||
rocm_vers,
|
||||
hip_vers,
|
||||
environment,
|
||||
sqlEngine,
|
||||
)
|
||||
conn.close()
|
||||
|
||||
#compare the results to the baseline if baseline exists
|
||||
regression=0
|
||||
regression=compare_test_to_baseline(tflops_base,results,testlist)
|
||||
# compare the results to the baseline if baseline exists
|
||||
regression = 0
|
||||
regression = compare_test_to_baseline(tflops_base, results, testlist)
|
||||
return regression
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -2,18 +2,6 @@
|
||||
# Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# Get list of staged files
|
||||
STAGED_FILES=$(git diff --cached --name-only)
|
||||
|
||||
# Check if any staged file is under include/ck_tile/ or example/ck_tile/
|
||||
if echo "$STAGED_FILES" | grep -qE '^(include/ck_tile/|example/ck_tile/)'; then
|
||||
echo "Detected changes in ck_tile-related files. Running remod.py..."
|
||||
|
||||
# Run remod.py in both required locations
|
||||
(cd include/ck_tile/ && python3 remod.py)
|
||||
(cd example/ck_tile/ && python3 remod.py)
|
||||
|
||||
echo "remod.py completed."
|
||||
else
|
||||
echo "No changes in ck_tile-related files. Skipping remod.py."
|
||||
fi
|
||||
# Run remod.py in both required locations
|
||||
(cd include/ck_tile/ && python3 remod.py)
|
||||
(cd example/ck_tile/ && python3 remod.py)
|
||||
|
||||
@@ -71,7 +71,7 @@ def tuples(filename):
|
||||
try:
|
||||
m, n, k = map(int, line)
|
||||
lines.append((m, n, k))
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
return lines
|
||||
|
||||
@@ -163,19 +163,19 @@ def run_shape(shape, profiler_bin, op_name, dtype, layout):
|
||||
m, n, k = shape
|
||||
try:
|
||||
op = OPs[op_name]
|
||||
except:
|
||||
except KeyError:
|
||||
raise AssertionError(f"Invalid operator {op_name}")
|
||||
name_arg = op.name
|
||||
op_wrapper = op.value()
|
||||
|
||||
try:
|
||||
dtype_arg = str(op_wrapper.dtype[dtype].value)
|
||||
except:
|
||||
except KeyError:
|
||||
raise AssertionError(f"Invalid dtype for {op_name}: {dtype}")
|
||||
|
||||
try:
|
||||
layout_wrapper = op_wrapper.layout[layout]
|
||||
except:
|
||||
except KeyError:
|
||||
raise AssertionError(f"Invalid layout for {op_name}: {layout}")
|
||||
layout_arg = str(layout_wrapper.value)
|
||||
# verification: no, initialization: decimal, print tensor: no, time kernel: yes
|
||||
@@ -286,7 +286,9 @@ def main():
|
||||
try:
|
||||
from tqdm import tqdm as iterate
|
||||
except ImportError:
|
||||
iterate = lambda x: x
|
||||
|
||||
def iterate(x):
|
||||
return x
|
||||
|
||||
for s in iterate(shapes):
|
||||
run_shape_stdout_lines = run_shape(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,7 +9,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/pool.hpp"
|
||||
#include "ck_tile/ops/pooling.hpp"
|
||||
#include "ck_tile/host/reference/reference_pool.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,28 +10,37 @@ and saves them as CSV files that can be read by the shell script.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import itertools
|
||||
import argparse
|
||||
|
||||
def generate_2d_configs(mode='full'):
|
||||
|
||||
def generate_2d_configs(mode="full"):
|
||||
"""Generate all 2D model configuration combinations
|
||||
|
||||
|
||||
Args:
|
||||
mode: 'small' for minimal set (~50 configs), 'half' for reduced set (~250 configs), 'full' for comprehensive set (~500 configs)
|
||||
"""
|
||||
|
||||
|
||||
# Define parameter ranges
|
||||
models_2d = [
|
||||
'resnet18', 'resnet34', 'resnet50',
|
||||
'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small',
|
||||
'vgg11', 'vgg16', 'vgg19',
|
||||
'alexnet', 'googlenet',
|
||||
'densenet121', 'densenet161',
|
||||
'squeezenet1_0', 'squeezenet1_1',
|
||||
'shufflenet_v2_x1_0'
|
||||
"resnet18",
|
||||
"resnet34",
|
||||
"resnet50",
|
||||
"mobilenet_v2",
|
||||
"mobilenet_v3_large",
|
||||
"mobilenet_v3_small",
|
||||
"vgg11",
|
||||
"vgg16",
|
||||
"vgg19",
|
||||
"alexnet",
|
||||
"googlenet",
|
||||
"densenet121",
|
||||
"densenet161",
|
||||
"squeezenet1_0",
|
||||
"squeezenet1_1",
|
||||
"shufflenet_v2_x1_0",
|
||||
]
|
||||
|
||||
if mode == 'small':
|
||||
|
||||
if mode == "small":
|
||||
# Minimal set for quick testing
|
||||
batch_sizes = [1, 8] # Just two batch sizes
|
||||
# Very limited input dimensions - only 2 key sizes
|
||||
@@ -41,12 +50,12 @@ def generate_2d_configs(mode='full'):
|
||||
]
|
||||
# Use only first 3 models for minimal testing
|
||||
models_2d = models_2d[:3] # Only resnet18, resnet34, resnet50
|
||||
elif mode == 'half':
|
||||
elif mode == "half":
|
||||
# Reduced set for faster testing
|
||||
batch_sizes = [1, 8, 32] # Small, medium, large
|
||||
# Reduced input dimensions - 5 key sizes
|
||||
input_dims = [
|
||||
(64, 64), # Small
|
||||
(64, 64), # Small
|
||||
(224, 224), # Standard (most common)
|
||||
(512, 512), # Large
|
||||
(224, 320), # Rectangular
|
||||
@@ -57,18 +66,23 @@ def generate_2d_configs(mode='full'):
|
||||
batch_sizes = [1, 4, 8, 16, 32]
|
||||
# More dimensions but skip some redundant ones
|
||||
input_dims = [
|
||||
(64, 64), (128, 128), (224, 224), (256, 256), (512, 512), # Square
|
||||
(224, 320), (320, 224), # Rectangular (reduced from 4)
|
||||
(64, 64),
|
||||
(128, 128),
|
||||
(224, 224),
|
||||
(256, 256),
|
||||
(512, 512), # Square
|
||||
(224, 320),
|
||||
(320, 224), # Rectangular (reduced from 4)
|
||||
(227, 227), # AlexNet preferred
|
||||
(299, 299) # Inception preferred
|
||||
(299, 299), # Inception preferred
|
||||
]
|
||||
|
||||
precisions = ['fp32'] #, 'fp16', 'bf16']
|
||||
|
||||
precisions = ["fp32"] # , 'fp16', 'bf16']
|
||||
channels = [3] # Most models expect RGB
|
||||
|
||||
|
||||
configs = []
|
||||
config_id = 1
|
||||
|
||||
|
||||
# Generate all combinations (but limit to reasonable subset)
|
||||
for model in models_2d:
|
||||
for batch_size in batch_sizes:
|
||||
@@ -77,36 +91,37 @@ def generate_2d_configs(mode='full'):
|
||||
# Skip some combinations to keep dataset manageable
|
||||
if batch_size > 16 and height > 256:
|
||||
continue # Skip large batch + large image combinations
|
||||
if precision != 'fp32' and batch_size < 8:
|
||||
if precision != "fp32" and batch_size < 8:
|
||||
continue # Skip mixed precision with tiny batches
|
||||
|
||||
|
||||
config_name = f"{model}_b{batch_size}_{height}x{width}_{precision}"
|
||||
|
||||
|
||||
config = {
|
||||
'config_name': config_name,
|
||||
'model': model,
|
||||
'batch_size': batch_size,
|
||||
'channels': channels[0],
|
||||
'height': height,
|
||||
'width': width,
|
||||
'precision': precision
|
||||
"config_name": config_name,
|
||||
"model": model,
|
||||
"batch_size": batch_size,
|
||||
"channels": channels[0],
|
||||
"height": height,
|
||||
"width": width,
|
||||
"precision": precision,
|
||||
}
|
||||
|
||||
|
||||
configs.append(config)
|
||||
config_id += 1
|
||||
|
||||
|
||||
return configs
|
||||
|
||||
def generate_3d_configs(mode='full'):
|
||||
|
||||
def generate_3d_configs(mode="full"):
|
||||
"""Generate all 3D model configuration combinations
|
||||
|
||||
|
||||
Args:
|
||||
mode: 'small' for minimal set (~10 configs), 'half' for reduced set (~50 configs), 'full' for comprehensive set (~100 configs)
|
||||
"""
|
||||
|
||||
models_3d = ['r3d_18', 'mc3_18', 'r2plus1d_18']
|
||||
|
||||
if mode == 'small':
|
||||
|
||||
models_3d = ["r3d_18", "mc3_18", "r2plus1d_18"]
|
||||
|
||||
if mode == "small":
|
||||
# Minimal set for quick testing
|
||||
batch_sizes = [1, 4] # Just two batch sizes
|
||||
temporal_sizes = [8] # Only smallest temporal size
|
||||
@@ -116,7 +131,7 @@ def generate_3d_configs(mode='full'):
|
||||
]
|
||||
# Use only first model for minimal testing
|
||||
models_3d = models_3d[:1] # Only r3d_18
|
||||
elif mode == 'half':
|
||||
elif mode == "half":
|
||||
# Reduced set for faster testing
|
||||
batch_sizes = [1, 4, 8] # Skip batch_size=2
|
||||
temporal_sizes = [8, 16] # Skip 32 (most expensive)
|
||||
@@ -124,7 +139,7 @@ def generate_3d_configs(mode='full'):
|
||||
input_dims = [
|
||||
(112, 112), # Small (common for video)
|
||||
(224, 224), # Standard
|
||||
(224, 320) # Rectangular
|
||||
(224, 320), # Rectangular
|
||||
]
|
||||
else: # full mode
|
||||
# More comprehensive but still reasonable
|
||||
@@ -132,15 +147,18 @@ def generate_3d_configs(mode='full'):
|
||||
temporal_sizes = [8, 16, 32]
|
||||
# More dimensions
|
||||
input_dims = [
|
||||
(112, 112), (224, 224), (256, 256), # Standard sizes
|
||||
(224, 320), (320, 224) # Rectangular
|
||||
(112, 112),
|
||||
(224, 224),
|
||||
(256, 256), # Standard sizes
|
||||
(224, 320),
|
||||
(320, 224), # Rectangular
|
||||
]
|
||||
|
||||
precisions = ['fp32'] #, 'fp16'] # Skip bf16 for 3D to reduce combinations
|
||||
|
||||
precisions = ["fp32"] # , 'fp16'] # Skip bf16 for 3D to reduce combinations
|
||||
channels = [3]
|
||||
|
||||
|
||||
configs = []
|
||||
|
||||
|
||||
for model in models_3d:
|
||||
for batch_size in batch_sizes:
|
||||
for temporal_size in temporal_sizes:
|
||||
@@ -151,75 +169,97 @@ def generate_3d_configs(mode='full'):
|
||||
continue
|
||||
if batch_size > 2 and height > 224:
|
||||
continue
|
||||
|
||||
|
||||
config_name = f"{model}_b{batch_size}_t{temporal_size}_{height}x{width}_{precision}"
|
||||
|
||||
|
||||
config = {
|
||||
'config_name': config_name,
|
||||
'model': model,
|
||||
'batch_size': batch_size,
|
||||
'channels': channels[0],
|
||||
'temporal_size': temporal_size,
|
||||
'height': height,
|
||||
'width': width,
|
||||
'precision': precision
|
||||
}
|
||||
|
||||
"config_name": config_name,
|
||||
"model": model,
|
||||
"batch_size": batch_size,
|
||||
"channels": channels[0],
|
||||
"temporal_size": temporal_size,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"precision": precision,
|
||||
}
|
||||
|
||||
configs.append(config)
|
||||
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def save_configs_to_csv(configs, filename, config_type):
|
||||
"""Save configurations to CSV file"""
|
||||
|
||||
|
||||
if not configs:
|
||||
print(f"No {config_type} configurations generated")
|
||||
return
|
||||
|
||||
|
||||
fieldnames = list(configs[0].keys())
|
||||
|
||||
with open(filename, 'w', newline='\n', encoding='utf-8') as csvfile:
|
||||
|
||||
with open(filename, "w", newline="\n", encoding="utf-8") as csvfile:
|
||||
csvfile.write(f"# {config_type} Model Configurations\n")
|
||||
csvfile.write(f"# Generated {len(configs)} configurations\n")
|
||||
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator='\n')
|
||||
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n")
|
||||
writer.writeheader()
|
||||
|
||||
|
||||
for config in configs:
|
||||
writer.writerow(config)
|
||||
|
||||
|
||||
print(f"Generated {len(configs)} {config_type} configurations → {filename}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Generate model configuration combinations')
|
||||
parser.add_argument('--output-2d', type=str, default='model_configs_2d.csv',
|
||||
help='Output file for 2D configurations')
|
||||
parser.add_argument('--output-3d', type=str, default='model_configs_3d.csv',
|
||||
help='Output file for 3D configurations')
|
||||
parser.add_argument('--mode', choices=['small', 'half', 'full'], default='full',
|
||||
help='Configuration mode: small (~60 total), half (~300 total) or full (~600 total) (default: half)')
|
||||
parser.add_argument('--limit', type=int,
|
||||
help='Limit number of configurations per type (for testing)')
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate model configuration combinations"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-2d",
|
||||
type=str,
|
||||
default="model_configs_2d.csv",
|
||||
help="Output file for 2D configurations",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-3d",
|
||||
type=str,
|
||||
default="model_configs_3d.csv",
|
||||
help="Output file for 3D configurations",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["small", "half", "full"],
|
||||
default="full",
|
||||
help="Configuration mode: small (~60 total), half (~300 total) or full (~600 total) (default: half)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--limit",
|
||||
type=int,
|
||||
help="Limit number of configurations per type (for testing)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
print(f"Generating {args.mode} model configurations...")
|
||||
|
||||
|
||||
print("Generating 2D model configurations...")
|
||||
configs_2d = generate_2d_configs(mode=args.mode)
|
||||
if args.limit:
|
||||
configs_2d = configs_2d[:args.limit]
|
||||
configs_2d = configs_2d[: args.limit]
|
||||
save_configs_to_csv(configs_2d, args.output_2d, "2D")
|
||||
|
||||
|
||||
print("Generating 3D model configurations...")
|
||||
configs_3d = generate_3d_configs(mode=args.mode)
|
||||
if args.limit:
|
||||
configs_3d = configs_3d[:args.limit]
|
||||
configs_3d = configs_3d[: args.limit]
|
||||
save_configs_to_csv(configs_3d, args.output_3d, "3D")
|
||||
|
||||
print(f"\nTotal configurations: {len(configs_2d)} 2D + {len(configs_3d)} 3D = {len(configs_2d) + len(configs_3d)}")
|
||||
|
||||
print(
|
||||
f"\nTotal configurations: {len(configs_2d)} 2D + {len(configs_3d)} 3D = {len(configs_2d) + len(configs_3d)}"
|
||||
)
|
||||
print("\nTo use these configurations:")
|
||||
print(" Update generate_test_dataset.sh to read from these CSV files")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -18,301 +18,428 @@ import csv
|
||||
import re
|
||||
import os
|
||||
|
||||
|
||||
def parse_miopen_command(command_line):
|
||||
"""
|
||||
Parse MIOpen driver command line into parameter dictionary
|
||||
|
||||
|
||||
Example input:
|
||||
./bin/MIOpenDriver conv -n 4 -c 3 -H 224 -W 224 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g 1 -F 1 -t 1
|
||||
|
||||
|
||||
Returns dict with parsed parameters or None if parsing fails
|
||||
"""
|
||||
if not command_line.strip().startswith('./bin/MIOpenDriver conv'):
|
||||
if not command_line.strip().startswith("./bin/MIOpenDriver conv"):
|
||||
return None
|
||||
|
||||
|
||||
# Extract parameters using regex
|
||||
params = {}
|
||||
|
||||
|
||||
# Parameter mapping: flag -> description
|
||||
# Support both short (-D) and long (--in_d) parameter formats
|
||||
param_patterns = {
|
||||
'n': r'-n\s+(\d+)', # batch size
|
||||
'c': r'-c\s+(\d+)', # input channels
|
||||
'k': r'-k\s+(\d+)', # output channels
|
||||
'H': r'-H\s+(\d+)', # input height
|
||||
'W': r'-W\s+(\d+)', # input width
|
||||
'D': r'(?:-D|--in_d)\s+(\d+)', # input depth (3D only) - supports both -D and --in_d
|
||||
'y': r'-y\s+(\d+)', # kernel height
|
||||
'x': r'-x\s+(\d+)', # kernel width
|
||||
'z': r'(?:-z|--fil_d)\s+(\d+)', # kernel depth (3D only) - supports both -z and --fil_d
|
||||
'u': r'-u\s+(\d+)', # stride height
|
||||
'v': r'-v\s+(\d+)', # stride width
|
||||
'w': r'(?:-w|--conv_stride_d)\s+(\d+)', # stride depth (3D only) - supports both -w and --conv_stride_d
|
||||
'p': r'-p\s+(\d+)', # pad height
|
||||
'q': r'-q\s+(\d+)', # pad width
|
||||
's': r'(?:-s|--pad_d)\s+(\d+)', # pad depth (3D only) - supports both -s and --pad_d
|
||||
'l': r'-l\s+(\d+)', # dilation height
|
||||
'j': r'-j\s+(\d+)', # dilation width
|
||||
'r': r'(?:-r|--dilation_d)\s+(\d+)', # dilation depth (3D only) - supports both -r and --dilation_d
|
||||
'g': r'-g\s+(\d+)', # groups
|
||||
'F': r'-F\s+(\d+)', # direction (1=fwd, 2=bwd_weight, 4=bwd_data)
|
||||
"n": r"-n\s+(\d+)", # batch size
|
||||
"c": r"-c\s+(\d+)", # input channels
|
||||
"k": r"-k\s+(\d+)", # output channels
|
||||
"H": r"-H\s+(\d+)", # input height
|
||||
"W": r"-W\s+(\d+)", # input width
|
||||
"D": r"(?:-D|--in_d)\s+(\d+)", # input depth (3D only) - supports both -D and --in_d
|
||||
"y": r"-y\s+(\d+)", # kernel height
|
||||
"x": r"-x\s+(\d+)", # kernel width
|
||||
"z": r"(?:-z|--fil_d)\s+(\d+)", # kernel depth (3D only) - supports both -z and --fil_d
|
||||
"u": r"-u\s+(\d+)", # stride height
|
||||
"v": r"-v\s+(\d+)", # stride width
|
||||
"w": r"(?:-w|--conv_stride_d)\s+(\d+)", # stride depth (3D only) - supports both -w and --conv_stride_d
|
||||
"p": r"-p\s+(\d+)", # pad height
|
||||
"q": r"-q\s+(\d+)", # pad width
|
||||
"s": r"(?:-s|--pad_d)\s+(\d+)", # pad depth (3D only) - supports both -s and --pad_d
|
||||
"l": r"-l\s+(\d+)", # dilation height
|
||||
"j": r"-j\s+(\d+)", # dilation width
|
||||
"r": r"(?:-r|--dilation_d)\s+(\d+)", # dilation depth (3D only) - supports both -r and --dilation_d
|
||||
"g": r"-g\s+(\d+)", # groups
|
||||
"F": r"-F\s+(\d+)", # direction (1=fwd, 2=bwd_weight, 4=bwd_data)
|
||||
}
|
||||
|
||||
|
||||
for param, pattern in param_patterns.items():
|
||||
match = re.search(pattern, command_line)
|
||||
if match:
|
||||
params[param] = int(match.group(1))
|
||||
|
||||
|
||||
return params if params else None
|
||||
|
||||
|
||||
def miopen_to_conv_param(miopen_params):
|
||||
"""
|
||||
Convert MIOpen parameters to CK ConvParam format
|
||||
|
||||
|
||||
Returns dictionary in CSV format or None if conversion fails
|
||||
"""
|
||||
if not miopen_params:
|
||||
return None
|
||||
|
||||
|
||||
# Determine if 2D or 3D convolution
|
||||
is_3d = 'D' in miopen_params or 'z' in miopen_params or 'w' in miopen_params or 'r' in miopen_params or 's' in miopen_params
|
||||
|
||||
is_3d = (
|
||||
"D" in miopen_params
|
||||
or "z" in miopen_params
|
||||
or "w" in miopen_params
|
||||
or "r" in miopen_params
|
||||
or "s" in miopen_params
|
||||
)
|
||||
|
||||
# Extract basic parameters with defaults
|
||||
ndim = 3 if is_3d else 2
|
||||
groups = miopen_params.get('g', 1)
|
||||
batch_size = miopen_params.get('n', 1)
|
||||
groups = miopen_params.get("g", 1)
|
||||
batch_size = miopen_params.get("n", 1)
|
||||
# MIOpen uses total channels (C*G), CK uses channels per group
|
||||
out_channels_total = miopen_params.get('k', 64)
|
||||
in_channels_total = miopen_params.get('c', 3)
|
||||
out_channels_total = miopen_params.get("k", 64)
|
||||
in_channels_total = miopen_params.get("c", 3)
|
||||
out_channels = out_channels_total // groups # CK format: channels per group
|
||||
in_channels = in_channels_total // groups # CK format: channels per group
|
||||
|
||||
in_channels = in_channels_total // groups # CK format: channels per group
|
||||
|
||||
if is_3d:
|
||||
# 3D convolution
|
||||
kernel_d = miopen_params.get('z', 3)
|
||||
kernel_h = miopen_params.get('y', 3)
|
||||
kernel_w = miopen_params.get('x', 3)
|
||||
|
||||
input_d = miopen_params.get('D', 16)
|
||||
input_h = miopen_params.get('H', 32)
|
||||
input_w = miopen_params.get('W', 32)
|
||||
|
||||
stride_d = miopen_params.get('w', 1)
|
||||
stride_h = miopen_params.get('u', 1)
|
||||
stride_w = miopen_params.get('v', 1)
|
||||
|
||||
dilation_d = miopen_params.get('r', 1)
|
||||
dilation_h = miopen_params.get('l', 1)
|
||||
dilation_w = miopen_params.get('j', 1)
|
||||
|
||||
pad_d = miopen_params.get('s', 0)
|
||||
pad_h = miopen_params.get('p', 0)
|
||||
pad_w = miopen_params.get('q', 0)
|
||||
|
||||
kernel_d = miopen_params.get("z", 3)
|
||||
kernel_h = miopen_params.get("y", 3)
|
||||
kernel_w = miopen_params.get("x", 3)
|
||||
|
||||
input_d = miopen_params.get("D", 16)
|
||||
input_h = miopen_params.get("H", 32)
|
||||
input_w = miopen_params.get("W", 32)
|
||||
|
||||
stride_d = miopen_params.get("w", 1)
|
||||
stride_h = miopen_params.get("u", 1)
|
||||
stride_w = miopen_params.get("v", 1)
|
||||
|
||||
dilation_d = miopen_params.get("r", 1)
|
||||
dilation_h = miopen_params.get("l", 1)
|
||||
dilation_w = miopen_params.get("j", 1)
|
||||
|
||||
pad_d = miopen_params.get("s", 0)
|
||||
pad_h = miopen_params.get("p", 0)
|
||||
pad_w = miopen_params.get("q", 0)
|
||||
|
||||
# Calculate output dimensions
|
||||
output_d = (input_d + 2 * pad_d - dilation_d * (kernel_d - 1) - 1) // stride_d + 1
|
||||
output_h = (input_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) // stride_h + 1
|
||||
output_w = (input_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) // stride_w + 1
|
||||
|
||||
output_d = (
|
||||
input_d + 2 * pad_d - dilation_d * (kernel_d - 1) - 1
|
||||
) // stride_d + 1
|
||||
output_h = (
|
||||
input_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1
|
||||
) // stride_h + 1
|
||||
output_w = (
|
||||
input_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1
|
||||
) // stride_w + 1
|
||||
|
||||
# Skip invalid configurations
|
||||
if output_d <= 0 or output_h <= 0 or output_w <= 0:
|
||||
return None
|
||||
|
||||
direction = miopen_params.get('F', 1) # 1=fwd, 2=bwd_weight, 4=bwd_data
|
||||
direction_name = {1: 'fwd', 2: 'bwd_weight', 4: 'bwd_data'}.get(direction, 'fwd')
|
||||
|
||||
|
||||
direction = miopen_params.get("F", 1) # 1=fwd, 2=bwd_weight, 4=bwd_data
|
||||
direction_name = {1: "fwd", 2: "bwd_weight", 4: "bwd_data"}.get(
|
||||
direction, "fwd"
|
||||
)
|
||||
|
||||
return {
|
||||
'NDim': ndim,
|
||||
'Groups': groups,
|
||||
'BatchSize': batch_size,
|
||||
'OutChannels': out_channels,
|
||||
'InChannels': in_channels,
|
||||
'KernelD': kernel_d, 'KernelH': kernel_h, 'KernelW': kernel_w,
|
||||
'InputD': input_d, 'InputH': input_h, 'InputW': input_w,
|
||||
'OutputD': output_d, 'OutputH': output_h, 'OutputW': output_w,
|
||||
'StrideD': stride_d, 'StrideH': stride_h, 'StrideW': stride_w,
|
||||
'DilationD': dilation_d, 'DilationH': dilation_h, 'DilationW': dilation_w,
|
||||
'LeftPadD': pad_d, 'LeftPadH': pad_h, 'LeftPadW': pad_w,
|
||||
'RightPadD': pad_d, 'RightPadH': pad_h, 'RightPadW': pad_w,
|
||||
'TestName': f'MIOpen_3D_{direction_name}'
|
||||
"NDim": ndim,
|
||||
"Groups": groups,
|
||||
"BatchSize": batch_size,
|
||||
"OutChannels": out_channels,
|
||||
"InChannels": in_channels,
|
||||
"KernelD": kernel_d,
|
||||
"KernelH": kernel_h,
|
||||
"KernelW": kernel_w,
|
||||
"InputD": input_d,
|
||||
"InputH": input_h,
|
||||
"InputW": input_w,
|
||||
"OutputD": output_d,
|
||||
"OutputH": output_h,
|
||||
"OutputW": output_w,
|
||||
"StrideD": stride_d,
|
||||
"StrideH": stride_h,
|
||||
"StrideW": stride_w,
|
||||
"DilationD": dilation_d,
|
||||
"DilationH": dilation_h,
|
||||
"DilationW": dilation_w,
|
||||
"LeftPadD": pad_d,
|
||||
"LeftPadH": pad_h,
|
||||
"LeftPadW": pad_w,
|
||||
"RightPadD": pad_d,
|
||||
"RightPadH": pad_h,
|
||||
"RightPadW": pad_w,
|
||||
"TestName": f"MIOpen_3D_{direction_name}",
|
||||
}
|
||||
|
||||
|
||||
else:
|
||||
# 2D convolution
|
||||
kernel_h = miopen_params.get('y', 3)
|
||||
kernel_w = miopen_params.get('x', 3)
|
||||
|
||||
input_h = miopen_params.get('H', 32)
|
||||
input_w = miopen_params.get('W', 32)
|
||||
|
||||
stride_h = miopen_params.get('u', 1)
|
||||
stride_w = miopen_params.get('v', 1)
|
||||
|
||||
dilation_h = miopen_params.get('l', 1)
|
||||
dilation_w = miopen_params.get('j', 1)
|
||||
|
||||
pad_h = miopen_params.get('p', 0)
|
||||
pad_w = miopen_params.get('q', 0)
|
||||
|
||||
kernel_h = miopen_params.get("y", 3)
|
||||
kernel_w = miopen_params.get("x", 3)
|
||||
|
||||
input_h = miopen_params.get("H", 32)
|
||||
input_w = miopen_params.get("W", 32)
|
||||
|
||||
stride_h = miopen_params.get("u", 1)
|
||||
stride_w = miopen_params.get("v", 1)
|
||||
|
||||
dilation_h = miopen_params.get("l", 1)
|
||||
dilation_w = miopen_params.get("j", 1)
|
||||
|
||||
pad_h = miopen_params.get("p", 0)
|
||||
pad_w = miopen_params.get("q", 0)
|
||||
|
||||
# Calculate output dimensions
|
||||
output_h = (input_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) // stride_h + 1
|
||||
output_w = (input_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) // stride_w + 1
|
||||
|
||||
output_h = (
|
||||
input_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1
|
||||
) // stride_h + 1
|
||||
output_w = (
|
||||
input_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1
|
||||
) // stride_w + 1
|
||||
|
||||
# Skip invalid configurations
|
||||
if output_h <= 0 or output_w <= 0:
|
||||
return None
|
||||
|
||||
direction = miopen_params.get('F', 1)
|
||||
direction_name = {1: 'fwd', 2: 'bwd_weight', 4: 'bwd_data'}.get(direction, 'fwd')
|
||||
|
||||
|
||||
direction = miopen_params.get("F", 1)
|
||||
direction_name = {1: "fwd", 2: "bwd_weight", 4: "bwd_data"}.get(
|
||||
direction, "fwd"
|
||||
)
|
||||
|
||||
return {
|
||||
'NDim': ndim,
|
||||
'Groups': groups,
|
||||
'BatchSize': batch_size,
|
||||
'OutChannels': out_channels,
|
||||
'InChannels': in_channels,
|
||||
'KernelH': kernel_h, 'KernelW': kernel_w,
|
||||
'InputH': input_h, 'InputW': input_w,
|
||||
'OutputH': output_h, 'OutputW': output_w,
|
||||
'StrideH': stride_h, 'StrideW': stride_w,
|
||||
'DilationH': dilation_h, 'DilationW': dilation_w,
|
||||
'LeftPadH': pad_h, 'LeftPadW': pad_w,
|
||||
'RightPadH': pad_h, 'RightPadW': pad_w,
|
||||
'TestName': f'MIOpen_2D_{direction_name}'
|
||||
"NDim": ndim,
|
||||
"Groups": groups,
|
||||
"BatchSize": batch_size,
|
||||
"OutChannels": out_channels,
|
||||
"InChannels": in_channels,
|
||||
"KernelH": kernel_h,
|
||||
"KernelW": kernel_w,
|
||||
"InputH": input_h,
|
||||
"InputW": input_w,
|
||||
"OutputH": output_h,
|
||||
"OutputW": output_w,
|
||||
"StrideH": stride_h,
|
||||
"StrideW": stride_w,
|
||||
"DilationH": dilation_h,
|
||||
"DilationW": dilation_w,
|
||||
"LeftPadH": pad_h,
|
||||
"LeftPadW": pad_w,
|
||||
"RightPadH": pad_h,
|
||||
"RightPadW": pad_w,
|
||||
"TestName": f"MIOpen_2D_{direction_name}",
|
||||
}
|
||||
|
||||
|
||||
def write_csv_cases(test_cases, output_file, ndim):
|
||||
"""Write test cases to CSV file"""
|
||||
if not test_cases:
|
||||
print(f"No {ndim}D test cases to write")
|
||||
return
|
||||
|
||||
|
||||
print(f"Writing {len(test_cases)} {ndim}D test cases to {output_file}")
|
||||
|
||||
|
||||
# Define CSV headers based on dimension
|
||||
if ndim == 2:
|
||||
headers = ['NDim', 'Groups', 'BatchSize', 'OutChannels', 'InChannels',
|
||||
'KernelH', 'KernelW', 'InputH', 'InputW', 'OutputH', 'OutputW',
|
||||
'StrideH', 'StrideW', 'DilationH', 'DilationW',
|
||||
'LeftPadH', 'LeftPadW', 'RightPadH', 'RightPadW', 'TestName']
|
||||
headers = [
|
||||
"NDim",
|
||||
"Groups",
|
||||
"BatchSize",
|
||||
"OutChannels",
|
||||
"InChannels",
|
||||
"KernelH",
|
||||
"KernelW",
|
||||
"InputH",
|
||||
"InputW",
|
||||
"OutputH",
|
||||
"OutputW",
|
||||
"StrideH",
|
||||
"StrideW",
|
||||
"DilationH",
|
||||
"DilationW",
|
||||
"LeftPadH",
|
||||
"LeftPadW",
|
||||
"RightPadH",
|
||||
"RightPadW",
|
||||
"TestName",
|
||||
]
|
||||
else: # 3D
|
||||
headers = ['NDim', 'Groups', 'BatchSize', 'OutChannels', 'InChannels',
|
||||
'KernelD', 'KernelH', 'KernelW', 'InputD', 'InputH', 'InputW',
|
||||
'OutputD', 'OutputH', 'OutputW', 'StrideD', 'StrideH', 'StrideW',
|
||||
'DilationD', 'DilationH', 'DilationW',
|
||||
'LeftPadD', 'LeftPadH', 'LeftPadW', 'RightPadD', 'RightPadH', 'RightPadW', 'TestName']
|
||||
|
||||
with open(output_file, 'w', newline='') as csvfile:
|
||||
headers = [
|
||||
"NDim",
|
||||
"Groups",
|
||||
"BatchSize",
|
||||
"OutChannels",
|
||||
"InChannels",
|
||||
"KernelD",
|
||||
"KernelH",
|
||||
"KernelW",
|
||||
"InputD",
|
||||
"InputH",
|
||||
"InputW",
|
||||
"OutputD",
|
||||
"OutputH",
|
||||
"OutputW",
|
||||
"StrideD",
|
||||
"StrideH",
|
||||
"StrideW",
|
||||
"DilationD",
|
||||
"DilationH",
|
||||
"DilationW",
|
||||
"LeftPadD",
|
||||
"LeftPadH",
|
||||
"LeftPadW",
|
||||
"RightPadD",
|
||||
"RightPadH",
|
||||
"RightPadW",
|
||||
"TestName",
|
||||
]
|
||||
|
||||
with open(output_file, "w", newline="") as csvfile:
|
||||
# Write header comment
|
||||
csvfile.write(f"# {ndim}D Convolution Test Cases from MIOpen Commands\n")
|
||||
csvfile.write(f"# Generated {len(test_cases)} test cases\n")
|
||||
|
||||
|
||||
writer = csv.DictWriter(csvfile, fieldnames=headers)
|
||||
writer.writeheader()
|
||||
|
||||
|
||||
for test_case in test_cases:
|
||||
# Only write fields that exist in headers
|
||||
filtered_case = {k: v for k, v in test_case.items() if k in headers}
|
||||
writer.writerow(filtered_case)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Convert MIOpen commands to CSV test cases')
|
||||
|
||||
parser.add_argument('--input', type=str, required=True,
|
||||
help='Input file with MIOpen driver commands')
|
||||
parser.add_argument('--output', type=str,
|
||||
help='Output CSV file (for mixed 2D/3D cases)')
|
||||
parser.add_argument('--output-2d', type=str, default='miopen_conv_2d.csv',
|
||||
help='Output CSV file for 2D cases')
|
||||
parser.add_argument('--output-3d', type=str, default='miopen_conv_3d.csv',
|
||||
help='Output CSV file for 3D cases')
|
||||
parser.add_argument('--filter-duplicates', action='store_true',
|
||||
help='Remove duplicate test cases')
|
||||
parser.add_argument('--model-name', type=str, default='MIOpen',
|
||||
help='Model name to use in test case names (default: MIOpen)')
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert MIOpen commands to CSV test cases"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Input file with MIOpen driver commands",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", type=str, help="Output CSV file (for mixed 2D/3D cases)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-2d",
|
||||
type=str,
|
||||
default="miopen_conv_2d.csv",
|
||||
help="Output CSV file for 2D cases",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-3d",
|
||||
type=str,
|
||||
default="miopen_conv_3d.csv",
|
||||
help="Output CSV file for 3D cases",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filter-duplicates", action="store_true", help="Remove duplicate test cases"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="MIOpen",
|
||||
help="Model name to use in test case names (default: MIOpen)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if not os.path.exists(args.input):
|
||||
print(f"ERROR: Input file not found: {args.input}")
|
||||
return 1
|
||||
|
||||
|
||||
print(f"Parsing MIOpen commands from {args.input}...")
|
||||
|
||||
|
||||
test_cases_2d = []
|
||||
test_cases_3d = []
|
||||
total_lines = 0
|
||||
parsed_lines = 0
|
||||
|
||||
with open(args.input, 'r') as f:
|
||||
|
||||
with open(args.input, "r") as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
total_lines += 1
|
||||
line = line.strip()
|
||||
|
||||
|
||||
# Skip empty lines and non-MIOpen commands
|
||||
# Handle both direct commands and logged commands with MIOpen prefix
|
||||
if not line:
|
||||
continue
|
||||
|
||||
|
||||
# Extract the actual MIOpenDriver command from logged format
|
||||
if 'MIOpenDriver conv' in line:
|
||||
if "MIOpenDriver conv" in line:
|
||||
# Extract command after finding MIOpenDriver
|
||||
command_start = line.find('./bin/MIOpenDriver conv')
|
||||
command_start = line.find("./bin/MIOpenDriver conv")
|
||||
if command_start != -1:
|
||||
line = line[command_start:]
|
||||
else:
|
||||
# Handle cases where path might be different - create standard format
|
||||
driver_start = line.find('MIOpenDriver conv')
|
||||
driver_start = line.find("MIOpenDriver conv")
|
||||
if driver_start != -1:
|
||||
line = './bin/' + line[driver_start:]
|
||||
line = "./bin/" + line[driver_start:]
|
||||
else:
|
||||
continue
|
||||
elif not line.startswith('./bin/MIOpenDriver conv'):
|
||||
elif not line.startswith("./bin/MIOpenDriver conv"):
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
# Parse MIOpen command
|
||||
miopen_params = parse_miopen_command(line)
|
||||
if not miopen_params:
|
||||
continue
|
||||
|
||||
|
||||
# Convert to ConvParam format
|
||||
conv_param = miopen_to_conv_param(miopen_params)
|
||||
if not conv_param:
|
||||
continue
|
||||
|
||||
|
||||
# Add model name to test name
|
||||
conv_param['TestName'] = f"{args.model_name}_{conv_param['NDim']}D_fwd"
|
||||
|
||||
conv_param["TestName"] = f"{args.model_name}_{conv_param['NDim']}D_fwd"
|
||||
|
||||
# Separate 2D and 3D cases
|
||||
if conv_param['NDim'] == 2:
|
||||
if conv_param["NDim"] == 2:
|
||||
test_cases_2d.append(conv_param)
|
||||
else:
|
||||
test_cases_3d.append(conv_param)
|
||||
|
||||
|
||||
parsed_lines += 1
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"WARNING: Failed to parse line {line_num}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
print(f"Processed {total_lines} lines, parsed {parsed_lines} commands")
|
||||
print(f"Found {len(test_cases_2d)} 2D cases, {len(test_cases_3d)} 3D cases")
|
||||
|
||||
|
||||
# Remove duplicates if requested
|
||||
if args.filter_duplicates:
|
||||
# Simple duplicate removal based on key parameters
|
||||
def make_key(case):
|
||||
if case['NDim'] == 2:
|
||||
return (case['Groups'], case['BatchSize'], case['OutChannels'], case['InChannels'],
|
||||
case['KernelH'], case['KernelW'], case['InputH'], case['InputW'],
|
||||
case['StrideH'], case['StrideW'])
|
||||
if case["NDim"] == 2:
|
||||
return (
|
||||
case["Groups"],
|
||||
case["BatchSize"],
|
||||
case["OutChannels"],
|
||||
case["InChannels"],
|
||||
case["KernelH"],
|
||||
case["KernelW"],
|
||||
case["InputH"],
|
||||
case["InputW"],
|
||||
case["StrideH"],
|
||||
case["StrideW"],
|
||||
)
|
||||
else:
|
||||
return (case['Groups'], case['BatchSize'], case['OutChannels'], case['InChannels'],
|
||||
case['KernelD'], case['KernelH'], case['KernelW'],
|
||||
case['InputD'], case['InputH'], case['InputW'],
|
||||
case['StrideD'], case['StrideH'], case['StrideW'])
|
||||
|
||||
return (
|
||||
case["Groups"],
|
||||
case["BatchSize"],
|
||||
case["OutChannels"],
|
||||
case["InChannels"],
|
||||
case["KernelD"],
|
||||
case["KernelH"],
|
||||
case["KernelW"],
|
||||
case["InputD"],
|
||||
case["InputH"],
|
||||
case["InputW"],
|
||||
case["StrideD"],
|
||||
case["StrideH"],
|
||||
case["StrideW"],
|
||||
)
|
||||
|
||||
seen_2d = set()
|
||||
unique_2d = []
|
||||
for case in test_cases_2d:
|
||||
@@ -320,7 +447,7 @@ def main():
|
||||
if key not in seen_2d:
|
||||
seen_2d.add(key)
|
||||
unique_2d.append(case)
|
||||
|
||||
|
||||
seen_3d = set()
|
||||
unique_3d = []
|
||||
for case in test_cases_3d:
|
||||
@@ -328,11 +455,13 @@ def main():
|
||||
if key not in seen_3d:
|
||||
seen_3d.add(key)
|
||||
unique_3d.append(case)
|
||||
|
||||
print(f"After deduplication: {len(unique_2d)} 2D cases, {len(unique_3d)} 3D cases")
|
||||
|
||||
print(
|
||||
f"After deduplication: {len(unique_2d)} 2D cases, {len(unique_3d)} 3D cases"
|
||||
)
|
||||
test_cases_2d = unique_2d
|
||||
test_cases_3d = unique_3d
|
||||
|
||||
|
||||
# Write output files
|
||||
if args.output:
|
||||
# Write mixed cases to single file
|
||||
@@ -340,14 +469,36 @@ def main():
|
||||
if all_cases:
|
||||
print(f"Writing {len(all_cases)} total cases to {args.output}")
|
||||
# Use 2D headers for mixed file, extend as needed
|
||||
mixed_headers = ['NDim', 'Groups', 'BatchSize', 'OutChannels', 'InChannels',
|
||||
'KernelH', 'KernelW', 'InputH', 'InputW', 'OutputH', 'OutputW',
|
||||
'StrideH', 'StrideW', 'DilationH', 'DilationW',
|
||||
'LeftPadH', 'LeftPadW', 'RightPadH', 'RightPadW', 'TestName']
|
||||
|
||||
with open(args.output, 'w', newline='') as csvfile:
|
||||
csvfile.write(f"# Mixed 2D/3D Convolution Test Cases from MIOpen Commands\n")
|
||||
writer = csv.DictWriter(csvfile, fieldnames=mixed_headers, extrasaction='ignore')
|
||||
mixed_headers = [
|
||||
"NDim",
|
||||
"Groups",
|
||||
"BatchSize",
|
||||
"OutChannels",
|
||||
"InChannels",
|
||||
"KernelH",
|
||||
"KernelW",
|
||||
"InputH",
|
||||
"InputW",
|
||||
"OutputH",
|
||||
"OutputW",
|
||||
"StrideH",
|
||||
"StrideW",
|
||||
"DilationH",
|
||||
"DilationW",
|
||||
"LeftPadH",
|
||||
"LeftPadW",
|
||||
"RightPadH",
|
||||
"RightPadW",
|
||||
"TestName",
|
||||
]
|
||||
|
||||
with open(args.output, "w", newline="") as csvfile:
|
||||
csvfile.write(
|
||||
"# Mixed 2D/3D Convolution Test Cases from MIOpen Commands\n"
|
||||
)
|
||||
writer = csv.DictWriter(
|
||||
csvfile, fieldnames=mixed_headers, extrasaction="ignore"
|
||||
)
|
||||
writer.writeheader()
|
||||
for case in all_cases:
|
||||
writer.writerow(case)
|
||||
@@ -355,12 +506,13 @@ def main():
|
||||
# Write separate files for 2D and 3D
|
||||
if test_cases_2d:
|
||||
write_csv_cases(test_cases_2d, args.output_2d, 2)
|
||||
|
||||
|
||||
if test_cases_3d:
|
||||
write_csv_cases(test_cases_3d, args.output_3d, 3)
|
||||
|
||||
|
||||
print("Conversion completed!")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
|
||||
@@ -7,13 +7,12 @@ PyTorch Model Runner with MIOpen Command Logging using torchvision models
|
||||
|
||||
Usage:
|
||||
MIOPEN_ENABLE_LOGGING_CMD=1 python3 run_model_with_miopen.py --model resnet18 2> miopen_commands.txt
|
||||
|
||||
|
||||
Available 2D models: alexnet, vgg11, vgg16, resnet18, resnet50, mobilenet_v2, etc.
|
||||
Available 3D models: r3d_18, mc3_18, r2plus1d_18
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
import torchvision.models.video as video_models
|
||||
import argparse
|
||||
@@ -21,94 +20,145 @@ import os
|
||||
|
||||
# Define available models
|
||||
MODELS_2D = [
|
||||
'alexnet', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn',
|
||||
'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
|
||||
'resnext50_32x4d', 'resnext101_32x8d', 'resnext101_64x4d',
|
||||
'wide_resnet50_2', 'wide_resnet101_2',
|
||||
'densenet121', 'densenet161', 'densenet169', 'densenet201',
|
||||
'inception_v3', 'googlenet',
|
||||
'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
|
||||
'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small',
|
||||
'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3',
|
||||
'squeezenet1_0', 'squeezenet1_1'
|
||||
"alexnet",
|
||||
"vgg11",
|
||||
"vgg11_bn",
|
||||
"vgg13",
|
||||
"vgg13_bn",
|
||||
"vgg16",
|
||||
"vgg16_bn",
|
||||
"vgg19",
|
||||
"vgg19_bn",
|
||||
"resnet18",
|
||||
"resnet34",
|
||||
"resnet50",
|
||||
"resnet101",
|
||||
"resnet152",
|
||||
"resnext50_32x4d",
|
||||
"resnext101_32x8d",
|
||||
"resnext101_64x4d",
|
||||
"wide_resnet50_2",
|
||||
"wide_resnet101_2",
|
||||
"densenet121",
|
||||
"densenet161",
|
||||
"densenet169",
|
||||
"densenet201",
|
||||
"inception_v3",
|
||||
"googlenet",
|
||||
"shufflenet_v2_x0_5",
|
||||
"shufflenet_v2_x1_0",
|
||||
"shufflenet_v2_x1_5",
|
||||
"shufflenet_v2_x2_0",
|
||||
"mobilenet_v2",
|
||||
"mobilenet_v3_large",
|
||||
"mobilenet_v3_small",
|
||||
"mnasnet0_5",
|
||||
"mnasnet0_75",
|
||||
"mnasnet1_0",
|
||||
"mnasnet1_3",
|
||||
"squeezenet1_0",
|
||||
"squeezenet1_1",
|
||||
]
|
||||
|
||||
MODELS_3D = [
|
||||
'r3d_18', 'mc3_18', 'r2plus1d_18'
|
||||
]
|
||||
MODELS_3D = ["r3d_18", "mc3_18", "r2plus1d_18"]
|
||||
|
||||
ALL_MODELS = MODELS_2D + MODELS_3D
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='PyTorch Model Runner with MIOpen Command Logging')
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="PyTorch Model Runner with MIOpen Command Logging"
|
||||
)
|
||||
|
||||
# Model selection
|
||||
parser.add_argument('--model', choices=ALL_MODELS, default='resnet18',
|
||||
help='Model to run')
|
||||
|
||||
parser.add_argument(
|
||||
"--model", choices=ALL_MODELS, default="resnet18", help="Model to run"
|
||||
)
|
||||
|
||||
# Input tensor dimensions
|
||||
parser.add_argument('--batch-size', type=int, default=4,
|
||||
help='Batch size')
|
||||
parser.add_argument('--channels', type=int, default=3,
|
||||
help='Input channels (e.g., 3 for RGB, 1 for grayscale)')
|
||||
parser.add_argument('--height', type=int, default=224,
|
||||
help='Input height')
|
||||
parser.add_argument('--width', type=int, default=224,
|
||||
help='Input width')
|
||||
parser.add_argument('--input-size', type=int,
|
||||
help='Input size (sets both height and width to same value)')
|
||||
parser.add_argument('--temporal-size', type=int, default=16,
|
||||
help='Temporal dimension for 3D models')
|
||||
|
||||
parser.add_argument("--batch-size", type=int, default=4, help="Batch size")
|
||||
parser.add_argument(
|
||||
"--channels",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Input channels (e.g., 3 for RGB, 1 for grayscale)",
|
||||
)
|
||||
parser.add_argument("--height", type=int, default=224, help="Input height")
|
||||
parser.add_argument("--width", type=int, default=224, help="Input width")
|
||||
parser.add_argument(
|
||||
"--input-size",
|
||||
type=int,
|
||||
help="Input size (sets both height and width to same value)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temporal-size", type=int, default=16, help="Temporal dimension for 3D models"
|
||||
)
|
||||
|
||||
# Device and precision
|
||||
parser.add_argument('--device', choices=['cuda', 'cpu', 'auto'], default='auto',
|
||||
help='Device to run on')
|
||||
parser.add_argument('--precision', choices=['fp32', 'fp16', 'bf16'], default='fp32',
|
||||
help='Floating point precision')
|
||||
|
||||
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
choices=["cuda", "cpu", "auto"],
|
||||
default="auto",
|
||||
help="Device to run on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
choices=["fp32", "fp16", "bf16"],
|
||||
default="fp32",
|
||||
help="Floating point precision",
|
||||
)
|
||||
|
||||
# Output control
|
||||
parser.add_argument('--quiet', action='store_true',
|
||||
help='Suppress output except errors')
|
||||
parser.add_argument('--verbose', action='store_true',
|
||||
help='Verbose output')
|
||||
|
||||
parser.add_argument(
|
||||
"--quiet", action="store_true", help="Suppress output except errors"
|
||||
)
|
||||
parser.add_argument("--verbose", action="store_true", help="Verbose output")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Handle input-size override
|
||||
if args.input_size:
|
||||
args.height = args.input_size
|
||||
args.width = args.input_size
|
||||
|
||||
|
||||
# Check MIOpen logging
|
||||
if not os.environ.get('MIOPEN_ENABLE_LOGGING_CMD') and not args.quiet:
|
||||
if not os.environ.get("MIOPEN_ENABLE_LOGGING_CMD") and not args.quiet:
|
||||
print("WARNING: Set MIOPEN_ENABLE_LOGGING_CMD=1 to capture commands")
|
||||
|
||||
|
||||
# Device selection
|
||||
if args.device == 'auto':
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
if args.device == "auto":
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
device = torch.device(args.device)
|
||||
|
||||
|
||||
# Check if actually running on GPU
|
||||
if device.type == 'cpu':
|
||||
if device.type == "cpu":
|
||||
import sys
|
||||
print(f"WARNING: Running on CPU, MIOpen commands will not be generated!", file=sys.stderr)
|
||||
|
||||
print(
|
||||
"WARNING: Running on CPU, MIOpen commands will not be generated!",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(f"CUDA/ROCm available: {torch.cuda.is_available()}", file=sys.stderr)
|
||||
if torch.cuda.is_available():
|
||||
print(f"GPU device count: {torch.cuda.device_count()}", file=sys.stderr)
|
||||
print(f"GPU name: {torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else 'N/A'}", file=sys.stderr)
|
||||
print(
|
||||
f"GPU name: {torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else 'N/A'}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
# Continue anyway for testing purposes
|
||||
|
||||
|
||||
if not args.quiet:
|
||||
print(f"Using device: {device}")
|
||||
|
||||
|
||||
# Create model using torchvision
|
||||
if args.model in MODELS_3D:
|
||||
# 3D Video models
|
||||
model = getattr(video_models, args.model)(weights=None)
|
||||
# 3D input: (batch, channels, temporal, height, width)
|
||||
input_tensor = torch.randn(args.batch_size, args.channels, args.temporal_size, args.height, args.width)
|
||||
input_tensor = torch.randn(
|
||||
args.batch_size, args.channels, args.temporal_size, args.height, args.width
|
||||
)
|
||||
if not args.quiet:
|
||||
print(f"3D model: {args.model}")
|
||||
print(f"Input shape: {input_tensor.shape} (B, C, T, H, W)")
|
||||
@@ -116,34 +166,37 @@ def main():
|
||||
# 2D Image models
|
||||
model = getattr(models, args.model)(weights=None)
|
||||
# 2D input: (batch, channels, height, width)
|
||||
input_tensor = torch.randn(args.batch_size, args.channels, args.height, args.width)
|
||||
input_tensor = torch.randn(
|
||||
args.batch_size, args.channels, args.height, args.width
|
||||
)
|
||||
if not args.quiet:
|
||||
print(f"2D model: {args.model}")
|
||||
print(f"Input shape: {input_tensor.shape} (B, C, H, W)")
|
||||
|
||||
|
||||
# Set precision
|
||||
if args.precision == 'fp16':
|
||||
if args.precision == "fp16":
|
||||
model = model.half()
|
||||
input_tensor = input_tensor.half()
|
||||
elif args.precision == 'bf16':
|
||||
elif args.precision == "bf16":
|
||||
model = model.bfloat16()
|
||||
input_tensor = input_tensor.bfloat16()
|
||||
|
||||
|
||||
model = model.to(device)
|
||||
input_tensor = input_tensor.to(device)
|
||||
|
||||
|
||||
if not args.quiet:
|
||||
print(f"Running {args.model} model...")
|
||||
|
||||
|
||||
# Run inference
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
output = model(input_tensor)
|
||||
if not args.quiet:
|
||||
print(f"Output shape: {output.shape}")
|
||||
|
||||
|
||||
if not args.quiet:
|
||||
print("Done! MIOpen commands logged to stderr")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -170,11 +170,11 @@ warp_tile_supported_combinations = {
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"fp8_bf8_fp16": [
|
||||
"fp8_bf8_fp16": [
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"bf8_fp8_fp16": [
|
||||
"bf8_fp8_fp16": [
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
|
||||
@@ -107,32 +107,32 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
"fp16_fp16_fp16": [
|
||||
[16, 16, 16],
|
||||
],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Supported warp tile combinations for different GPU architectures and data types
|
||||
WARP_SUPPORTED_COMBINATIONS = {
|
||||
"gfx90a": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx942": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx950": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1],
|
||||
],
|
||||
"gfx1201": [
|
||||
[2, 4, 1],
|
||||
[1, 8, 1],
|
||||
[8, 1, 1],
|
||||
[2, 4, 1],
|
||||
[1, 8, 1],
|
||||
[8, 1, 1],
|
||||
[4, 2, 1],
|
||||
],
|
||||
],
|
||||
}
|
||||
|
||||
# Unsupported trait combinations
|
||||
@@ -186,14 +186,14 @@ def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) ->
|
||||
|
||||
|
||||
def validate_warp_configuration(
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_m: int,
|
||||
warp_n: int,
|
||||
warp_k: int,
|
||||
gpu_name: str = None,
|
||||
) -> bool:
|
||||
"""Validate warp configuration."""
|
||||
if gpu_name is None:
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
|
||||
current_combination = [warp_m, warp_n, warp_k]
|
||||
|
||||
@@ -205,11 +205,8 @@ def validate_warp_configuration(
|
||||
|
||||
# Check if current combination is in the allowed list
|
||||
if current_combination not in allowed_combinations:
|
||||
error_msg = (
|
||||
f"Invalid warp tile combination: {current_combination} not in allowed list. "
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user