Update pre-commit to fixed versions, run remod for ck_tile (#2895)

* Fix ruff linter errors

* Fix remod dos2unix command

* Clang format

* Ignore utility in remod

* Run remod

* Specify clang-format version in pre-commit

* Specify ruff version

* Include PoolKernelArgs in reference_pool

* Add calculate_total_elements to reference batched contraction

* Fix calculate_total_elements declaration

* Refactor remod pre-commit hook

* Fix Aquant tests

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

[ROCm/composable_kernel commit: d40b50b9d5]
This commit is contained in:
Johannes Graner
2025-10-17 00:29:17 +02:00
committed by GitHub
parent 1d9320c8f3
commit 8af66c65d0
77 changed files with 21671 additions and 9858 deletions

View File

@@ -6,6 +6,7 @@ import subprocess
import sys
from typing import Iterable, Optional, Mapping
def gha_set_output(vars: Mapping[str, str | Path]):
"""Sets values in a step's output parameters.
@@ -25,6 +26,7 @@ def gha_set_output(vars: Mapping[str, str | Path]):
with open(step_output_file, "a") as f:
f.writelines(f"{k}={str(v)}" + "\n" for k, v in vars.items())
def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]:
"""Returns the paths of modified files relative to the base reference."""
try:
@@ -42,11 +44,13 @@ def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]:
file=sys.stderr,
)
return None
GITHUB_WORKFLOWS_CI_PATTERNS = [
"therock*",
]
def is_path_workflow_file_related_to_ci(path: str) -> bool:
return any(
fnmatch.fnmatch(path, ".github/workflows/" + pattern)
@@ -56,11 +60,13 @@ def is_path_workflow_file_related_to_ci(path: str) -> bool:
for pattern in GITHUB_WORKFLOWS_CI_PATTERNS
)
def check_for_workflow_file_related_to_ci(paths: Optional[Iterable[str]]) -> bool:
if paths is None:
return False
return any(is_path_workflow_file_related_to_ci(p) for p in paths)
# Paths matching any of these patterns are considered to have no influence over
# build or test workflows so any related jobs can be skipped if all paths
# modified by a commit/PR match a pattern in this list.
@@ -70,23 +76,26 @@ SKIPPABLE_PATH_PATTERNS = [
"*.md",
"*.pre-commit-config.*",
"*LICENSE",
'Jenkinsfile',
'.github/ISSUE_TEMPLATE/*',
'.github/CODEOWNERS',
'.github/*.md',
'.github/dependabot.yml',
"Jenkinsfile",
".github/ISSUE_TEMPLATE/*",
".github/CODEOWNERS",
".github/*.md",
".github/dependabot.yml",
]
def is_path_skippable(path: str) -> bool:
"""Determines if a given relative path to a file matches any skippable patterns."""
return any(fnmatch.fnmatch(path, pattern) for pattern in SKIPPABLE_PATH_PATTERNS)
def check_for_non_skippable_path(paths: Optional[Iterable[str]]) -> bool:
"""Returns true if at least one path is not in the skippable set."""
if paths is None:
return False
return any(not is_path_skippable(p) for p in paths)
def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool:
"""Returns true if CI workflows should run given a list of modified paths."""
@@ -118,16 +127,16 @@ def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool:
)
return False
def main(args):
base_ref = args.get("base_ref")
modified_paths = get_modified_paths(base_ref)
print("modified_paths (max 200):", modified_paths[:200])
enable_jobs = should_ci_run_given_modified_paths(modified_paths)
output = {
'enable_therock_ci': json.dumps(enable_jobs)
}
output = {"enable_therock_ci": json.dumps(enable_jobs)}
gha_set_output(output)
if __name__ == "__main__":
args = {}
args["base_ref"] = os.environ.get("BASE_REF", "HEAD^1")

View File

@@ -1,11 +1,25 @@
repos:
- repo: local
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.3
hooks:
- id: clang-format
name: clang-format
entry: clang-format-18 -i --style=file
language: system
types_or: [c++, inc]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.0
hooks:
- id: ruff-check
args: [ --fix ]
exclude: |
(?x)^(
docs/conf.py
)$
- id: ruff-format
exclude: |
(?x)^(
docs/conf.py
)$
- repo: local
hooks:
# - id: copyright-year-checker
# name: copyright-year-checker
# entry: script/check_copyright_year.sh
@@ -18,21 +32,9 @@ repos:
language: script
types_or: [c++, text]
verbose: true
- id: ruff-check
name: Ruff Linter
entry: ruff check --fix
language: python
types: [python]
additional_dependencies: [ruff]
- id: ruff-format
name: Ruff Formatter
entry: ruff format
language: python
types: [python]
additional_dependencies: [ruff]
- id: run-remod-if-ck-tile-changed
name: Run remod.py if ck_tile files changed
entry: script/remod_for_ck_tile.sh
language: script
always_run: true
files: '^(include|example)/ck_tile/.*$'
pass_filenames: false

View File

@@ -2,4 +2,4 @@
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
GEN_DIR = "" # in Cmake, have to generate files in same folder
GEN_DIR = "" # in Cmake, have to generate files in same folder

View File

@@ -3,38 +3,35 @@
# generate kernel instances to speed up compilation
FWD_DTYPE_MAP = {
"fp32" : "FmhaFwdFp32",
"fp16" : "FmhaFwdFp16",
"bf16" : "FmhaFwdBf16",
"fp8" : "FmhaFwdFp8",
"fp32": "FmhaFwdFp32",
"fp16": "FmhaFwdFp16",
"bf16": "FmhaFwdBf16",
"fp8": "FmhaFwdFp8",
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16",
"fp8fp32": "FmhaFwdFp8Fp32"
"fp8fp32": "FmhaFwdFp8Fp32",
}
BWD_DTYPE_MAP = {
"fp32": "FmhaBwdFp32",
"fp16": "FmhaBwdFp16",
"bf16": "FmhaBwdBf16"
}
BWD_DTYPE_MAP = {"fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"}
MASK_IMPL = {
"generic" : "ck_tile::GenericAttentionMask",
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
"generic": "ck_tile::GenericAttentionMask",
"simplified": "ck_tile::SimplifiedGenericAttentionMask",
}
_MASK_SIMPLIFIED_MAP = {
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
"s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
}
_MASK_MAP = {
"no" : "FmhaMasks::NoMask",
"causal" : "FmhaMasks::CausalMask",
"generic" : "FmhaMasks::GenericMask"
"no": "FmhaMasks::NoMask",
"causal": "FmhaMasks::CausalMask",
"generic": "FmhaMasks::GenericMask",
}
def get_mask_map(mask : str):
def get_mask_map(mask: str):
if mask == "generic":
return _MASK_MAP
elif mask == "simplified":
@@ -43,18 +40,20 @@ def get_mask_map(mask : str):
assert False
return None
_MASK_CHECK_MAP = {
"no" : "t.mask_type == mask_enum::no_mask",
"causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic" : "t.mask_type == mask_enum::window_generic",
"no": "t.mask_type == mask_enum::no_mask",
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic": "t.mask_type == mask_enum::window_generic",
}
_MASK_SIMPLIFIED_CHECK_MAP = {
"s_no" : "t.mask_type == mask_enum::no_mask",
"s_mask" : "t.mask_type != mask_enum::no_mask",
"s_no": "t.mask_type == mask_enum::no_mask",
"s_mask": "t.mask_type != mask_enum::no_mask",
}
def get_mask_check_map(mask : str):
def get_mask_check_map(mask: str):
if mask == "generic":
return _MASK_CHECK_MAP
elif mask == "simplified":
@@ -63,76 +62,71 @@ def get_mask_check_map(mask : str):
assert False
return None
BIAS_MAP = {
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
"no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
"bias": "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
"alibi": "ck_tile::BlockAttentionBiasEnum::ALIBI",
}
# TODO: this is ugly
BIAS_CHECK_MAP = {
"no" : "bias_enum::no_bias",
"bias" : "bias_enum::elementwise_bias",
"alibi" : "bias_enum::alibi"
"no": "bias_enum::no_bias",
"bias": "bias_enum::elementwise_bias",
"alibi": "bias_enum::alibi",
}
DROPOUT_MAP = {
"no" : "ck_tile::BlockDropoutBwd<false, true, false>",
"dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
"dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
"dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
"dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
"no": "ck_tile::BlockDropoutBwd<false, true, false>",
"dropout_wg32": "ck_tile::BlockDropoutBwd<true, true, false>",
"dropout_wg32_storerandval": "ck_tile::BlockDropoutBwd<true, true, true >",
"dropout_wg16": "ck_tile::BlockDropoutBwd<true, false, false>",
"dropout_wg16_storerandval": "ck_tile::BlockDropoutBwd<true, false, true >",
}
DROPOUT_CHECK_MAP = {
"no" : "t.has_dropout == false",
"dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
"dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
"no": "t.has_dropout == false",
"dropout_wg32": "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg32_storerandval": "t.has_dropout == true && t.is_store_randval == true",
"dropout_wg16": "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg16_storerandval": "t.has_dropout == true && t.is_store_randval == true",
}
ROPE_MAP = {
"no" : "ck_tile::RotaryEmbeddingEnum::NONE",
"inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
"half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
"no": "ck_tile::RotaryEmbeddingEnum::NONE",
"inter": "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
"half": "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED",
}
ROPE_CHECK_MAP = {
"no" : "rope_enum::none",
"inter" : "rope_enum::interleaved",
"half" : "rope_enum::half_rotated"
"no": "rope_enum::none",
"inter": "rope_enum::interleaved",
"half": "rope_enum::half_rotated",
}
MODE_MAP = {
"batch" : "false",
"group" : "true"
}
MODE_MAP = {"batch": "false", "group": "true"}
LAYOUT_MAP = {
"row" : "true",
"col" : "false"
}
LAYOUT_MAP = {"row": "true", "col": "false"}
PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs" : "ck_tile::BlockFmhaPipelineQSKSVS",
"qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
"qr": "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
}
PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
"qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
"qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
"qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
}
BOOL_MAP = {
"t" : "true",
"f" : "false",
True : "true",
False : "false",
"t": "true",
"f": "false",
True: "true",
False: "false",
}

View File

@@ -9,28 +9,26 @@ import itertools
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
MODE_MAP,
LAYOUT_MAP,
BIAS_CHECK_MAP,
get_mask_check_map,
get_mask_map,
BIAS_MAP,
FWD_DTYPE_MAP,
BOOL_MAP,
PIPELINE_ENUM_MAP,
)
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
"bf16": 16,
"fp8" : 8,
"bf8" : 8
}
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
K0_MAX_SUBMAX_MAP = {
32 : 32,
64 : 64,
96 : 128,
128: 128,
256: 256
}
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
"qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
"qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
@@ -40,7 +38,7 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
#include "fmha_fwd.hpp"
"""
FMHA_FWD_KERNEL_BODY="""
FMHA_FWD_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
@@ -116,8 +114,8 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
}}
"""
FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp"
FMHA_FWD_API="""
FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp"
FMHA_FWD_API = """
#include <cstdio>
namespace {{
@@ -167,173 +165,223 @@ float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a,
}}
"""
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
bool_expr: str = None
def __str__(self):
if self.bool_expr is None:
return 'true'
return "true"
else:
return f'{self.bool_expr}'
return f"{self.bool_expr}"
def __and__(self, other):
return CppConstraint(f'({str(self)}) && ({str(other)})')
return CppConstraint(f"({str(self)}) && ({str(other)})")
@dataclass
class FmhaFwdApiTrait:
pipeline_tag : str
pipeline_tag: str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
logits : str
mask : str
bias : str #
lse : str #
dropout : str
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
constraint : CppConstraint
hdim: str
dtype: str # data type
mode: str # value from MODE_MAP
bm0: int # tile size along q seqlen (block size)
bn0: int # tile size along qk seqlen
bk0: int # tile size along qk gemm unroll
bn1: int # tile size along v head_dim
bk1: int # tile size along kv gemm unroll
bk0max: int
vlayout: str
logits: str
mask: str
bias: str #
lse: str #
dropout: str
squant: str #
spad: str
skpad: str
dpad: str
dvpad: str
constraint: CppConstraint
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
)
@property
def scheck(self) -> str:
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr']:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
if self.mode == "group":
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
if self.pipeline_tag == "qr_async":
if self.spad == "t":
return "true" # always support
else:
return "true"
elif self.pipeline_tag in ["qr"]:
if self.spad == "t":
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"a.seqlen_q % {self.bm0} == 0"
else:
assert False
@property
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0'
else: assert False
if self.mode == "group":
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
if self.pipeline_tag == "qr_async":
if self.skpad == "t":
return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
else:
return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
elif self.pipeline_tag in ["qr", "qr_fp8"]:
if self.skpad == "t":
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"a.seqlen_k % {self.bn0} == 0"
else:
assert False
@property
def dcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
if self.pipeline_tag == "qr_async":
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
if self.dpad == "t":
return f"a.hdim_q % {vec} == 0"
else:
assert False
elif self.pipeline_tag in ["qr"]:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0'
else: assert False
if self.dpad == "t":
return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"a.hdim_q % {bk0submax} == 0"
else:
assert False
@property
def dvcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
if self.pipeline_tag == "qr_async":
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr']:
if self.dvpad == "t":
return f"a.hdim_v % {vec} == 0"
else:
assert False
elif self.pipeline_tag in ["qr"]:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0'
else: assert False
if self.dvpad == "t":
return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"a.hdim_v % {bk0submax} == 0"
else:
assert False
@dataclass
class FmhaFwdPipeline:
tag : str
tag: str
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
F_vlayout: str # row/col
F_spad: str # true/false
F_skpad: str #
F_dpad: str #
F_dvpad: str #
F_logits: str # t/f
F_bias: str # true/false
F_lse: str #
F_dropout: str #
F_squant: str #
F_mask: str # value from MASK_MAP
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
n = ""
if self.F_spad == "t":
n += "s"
if self.F_skpad == "t":
n += "sk"
if self.F_dpad == "t":
n += "d"
if self.F_dvpad == "t":
n += "dv"
if n != "":
n = "p" + n
return n
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_logits == 't' : n += '_logits'
else: n += '_nlogits'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
else: n += '_nbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else: n += '_nmask'
n = f"{self.tag}_v{self.F_vlayout[0]}"
if pn != "":
n += f"_{pn}"
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
else: n += '_nmask'
n += "_npad"
if self.F_lse == 't' : n += '_lse'
else: n += '_nlse'
if self.F_logits == "t":
n += "_logits"
else:
n += "_nlogits"
if self.F_dropout == 't' : n += '_dropout'
else: n += '_ndropout'
if self.F_bias != "no":
n += f"_{self.F_bias}"
else:
n += "_nbias"
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
if self.F_mask[0:2] == "s_":
if self.F_mask == "s_mask":
n += "_mask"
else:
n += "_nmask"
else:
if self.F_mask != "no":
n += f"_m{self.F_mask[0]}"
else:
n += "_nmask"
if self.F_lse == "t":
n += "_lse"
else:
n += "_nlse"
if self.F_dropout == "t":
n += "_dropout"
else:
n += "_ndropout"
if self.F_squant == "t":
n += "_squant"
else:
n += "_nsquant"
return n
class FmhaFwdApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
@@ -344,118 +392,152 @@ class FmhaFwdApiPool:
@property
def api(self) -> str:
per_dtypes=str()
per_dtypes = str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
per_hdim_case = str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
inners=str()
traits = self.pool[dtype][hdim]
inners = str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if_k = "if" if k == 0 else "else if"
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
F_if=if_k,
F_mode=MODE_MAP[trait.mode],
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
F_logits=BOOL_MAP[trait.logits],
F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse],
F_dropout=BOOL_MAP[trait.dropout],
F_squant=BOOL_MAP[trait.squant],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0,
F_bn0=trait.bn0,
F_bk0=trait.bk0,
F_bn1=trait.bn1,
F_bk1=trait.bk1,
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
per_dtypes += " (void)t ; (void)s ; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)
@dataclass
class FmhaFwdTileSize:
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm0 : int # gemm0 warp size along m
F_wn0 : int # gemm0 warp size along n
F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
F_bm0: int # tile size along q seqlen (block size)
F_bn0: int # tile size along k seqlen
F_bk0: int # tile size along qk gemm unroll
F_bn1: int # tile size along v head_dim
F_bk1: int # tile size along kv gemm unroll
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0: int # number of warps for gemm0 along q seqlen
F_rn0: int # number of warps for gemm0 along k seqlen
F_rk0: int # number of warps for gemm0 along head dim q (not used)
F_rm1: int # number of warps for gemm1 along q seqlen
F_rn1: int # number of warps for gemm1 along head dim v
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
F_wm0: int # gemm0 warp size along m
F_wn0: int # gemm0 warp size along n
F_wk0: int # gemm0 warp size along k
F_wm1: int # gemm1 warp size along m
F_wn1: int # gemm1 warp size along n
F_wk1: int # gemm1 warp size along k
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
return (
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
)
@dataclass
class FmhaFwdKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
mask_impl : str
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
F_mode: str # value from MODE_MAP
F_tile: FmhaFwdTileSize
F_pipeline: FmhaFwdPipeline
mask_impl: str
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0max = self.F_tile.F_bk0max,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag])
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
F_idx=self.F_idx,
F_hdim=self.F_hdim,
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
F_bm0=self.F_tile.F_bm0,
F_bn0=self.F_tile.F_bn0,
F_bk0=self.F_tile.F_bk0,
F_bn1=self.F_tile.F_bn1,
F_bk1=self.F_tile.F_bk1,
F_bk0max=self.F_tile.F_bk0max,
F_rm0=self.F_tile.F_rm0,
F_rn0=self.F_tile.F_rn0,
F_rk0=self.F_tile.F_rk0,
F_rm1=self.F_tile.F_rm1,
F_rn1=self.F_tile.F_rn1,
F_rk1=self.F_tile.F_rk1,
F_wm0=self.F_tile.F_wm0,
F_wn0=self.F_tile.F_wn0,
F_wk0=self.F_tile.F_wk0,
F_wm1=self.F_tile.F_wm1,
F_wn1=self.F_tile.F_wn1,
F_wk1=self.F_tile.F_wk1,
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
F_logits=BOOL_MAP[self.F_pipeline.F_logits],
F_bias=BIAS_MAP[self.F_pipeline.F_bias],
F_lse=BOOL_MAP[self.F_pipeline.F_lse],
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
F_occupancy=self.F_tile.F_occupancy,
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode=MODE_MAP[self.F_mode],
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
return (
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
)
@property
def filename(self) -> str:
@@ -463,35 +545,59 @@ class FmhaFwdKernel:
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
logits=self.F_pipeline.F_logits,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
squant=self.F_pipeline.F_squant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
logits=self.F_pipeline.F_logits,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
squant=self.F_pipeline.F_squant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
)
class KernelComponentFactory:
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype == "fp16" or dtype == "bf16":
return {
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
128: [
FmhaFwdTileSize(
128,
128,
32,
128,
32,
128,
4,
1,
1,
4,
1,
1,
32,
32,
16,
32,
32,
16,
-1,
)
],
}
else:
return None
@@ -502,28 +608,94 @@ class KernelComponentFactory:
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
squant = "t" if dtype == "fp8" else "f"
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
if dtype in ["fp16", "bf16"]:
for logits, mask, bias, lse, dropout in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
):
pipelines.append(
FmhaFwdPipeline(
"qr_async",
"row",
"t",
"f",
"t",
"t",
logits,
bias,
lse,
dropout,
squant,
mask,
)
)
pipelines.append(
FmhaFwdPipeline(
"qr_async",
"row",
"t",
"t",
"t",
"t",
logits,
bias,
lse,
dropout,
squant,
mask,
)
)
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
else:
assert False
return pipelines
class CustomFactory(KernelComponentFactory):
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
if dtype == 'fp16' or dtype == 'bf16':
if dtype == "fp16" or dtype == "bf16":
if 128 in result.keys():
result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')))
result[128].insert(
0,
FmhaFwdTileSize(
64,
128,
64,
128,
64,
128,
4,
1,
1,
4,
1,
1,
16,
16,
16,
16,
16,
16,
-1,
CppConstraint(
"get_num_blocks(128) < num_cus * min_cu_util_rate"
),
),
)
return result
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
def get_fwd_blobs(
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
@@ -532,30 +704,41 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
for dtype in FWD_DTYPE_MAP.keys():
d = CustomFactory.get_hdim_tile_size_dict(dtype)
if d == None:
if d is None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)):
for tile, pipeline in itertools.product(
tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)
):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
if hdim == 192 and tile.F_bn1 == 128:
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
if (
pipeline.F_bias != "no"
or pipeline.F_lse == "t"
or pipeline.F_dropout == "t"
):
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
if not (
(pipeline.F_logits == "t" and pipeline.F_bias == "no")
or pipeline.F_logits == "f"
):
continue
k = FmhaFwdKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != '':
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
@@ -563,48 +746,48 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias']
cond &= pipeline.F_squant == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'batch'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# Aiter(mha_batch_prefill) integration
elif receipt == 200:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
cond = dtype == "fp32"
if not cond:
continue
@@ -613,20 +796,28 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
def write_blobs(
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
with file_path.open('a') as f:
def list_blobs(
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -5,23 +5,27 @@
import copy
from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
FWD_DTYPE_MAP,
BOOL_MAP,
ROPE_MAP,
LAYOUT_MAP,
ROPE_CHECK_MAP,
)
from codegen.ops.fmha_fwd import (
FmhaFwdApiTrait,
DTYPE_BITS,
FMHA_FWD_KERNEL_HEADER,
FMHA_FWD_API_PER_DTYPE,
FMHA_FWD_API_PER_HDIM_CASE,
)
FMHA_FWD_APPENDKV_KERNEL_BODY="""
FMHA_FWD_APPENDKV_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
@@ -66,8 +70,8 @@ float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fw
}}
"""
FMHA_FWD_APPENDKV_API_FILENAME="fmha_fwd_appendkv_api.cpp"
FMHA_FWD_APPENDKV_API="""
FMHA_FWD_APPENDKV_API_FILENAME = "fmha_fwd_appendkv_api.cpp"
FMHA_FWD_APPENDKV_API = """
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
@@ -75,7 +79,7 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co
}}
"""
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
FMHA_FWD_APPENDKV_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
@@ -83,81 +87,101 @@ FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {
}}
"""
@dataclass
class FmhaFwdAppendKVApiTrait:
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
bs : int # tile size along q seqlen
bsk : int # tile size along k seqlen
bd : int # tile size along qk gemm unroll
bdv : int # tile size along kv gemm unroll
vlayout : str
spad : str
skpad : str
dpad : str
dvpad : str
rope : str # key from ROPE_MAP
pagedkv : str
hdim: str
dtype: str # data type
bs: int # tile size along q seqlen
bsk: int # tile size along k seqlen
bd: int # tile size along qk gemm unroll
bdv: int # tile size along kv gemm unroll
vlayout: str
spad: str
skpad: str
dpad: str
dvpad: str
rope: str # key from ROPE_MAP
pagedkv: str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\
f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}'
return (
f"{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-"
+ f"{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}"
)
@property
def scheck(self) -> str:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/'
else : return f'a.seqlen_q % {self.bs} == 0'
if self.spad == "t":
return f"true /*a.seqlen_q % {self.bs} != 0*/"
else:
return f"a.seqlen_q % {self.bs} == 0"
@property
def skcheck(self) -> str:
# we do not check all the values in a.seqlen_k_ptr
return 'true'
return "true"
@property
def dcheck(self) -> str:
if self.dpad == 't': return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {self.bd} == 0'
if self.dpad == "t":
return f"true /*a.hdim_q % {self.bd} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"a.hdim_q % {self.bd} == 0"
@property
def dvcheck(self) -> str:
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {self.bdv} == 0'
if self.dvpad == "t":
return f"true /*a.hdim_v % {self.bdv} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"a.hdim_v % {self.bdv} == 0"
@dataclass
class FmhaFwdAppendKVPipeline:
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_rope : str # key from ROPE_MAP
F_pagedkv : str # t/f
F_vlayout: str # row/col
F_spad: str # true/false
F_skpad: str #
F_dpad: str #
F_dvpad: str #
F_rope: str # key from ROPE_MAP
F_pagedkv: str # t/f
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
n = ""
if self.F_spad == "t":
n += "s"
if self.F_skpad == "t":
n += "sk"
if self.F_dpad == "t":
n += "d"
if self.F_dvpad == "t":
n += "dv"
if n != "":
n = "p" + n
return n
pn = pad_name()
n = f'v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_rope != 'no': n += f'_{self.F_rope}'
if self.F_pagedkv == 't': n += '_pagedkv'
n = f"v{self.F_vlayout[0]}"
if pn != "":
n += f"_{pn}"
if self.F_rope != "no":
n += f"_{self.F_rope}"
if self.F_pagedkv == "t":
n += "_pagedkv"
return n
class FmhaFwdAppendKVApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
@@ -168,74 +192,104 @@ class FmhaFwdAppendKVApiPool:
@property
def api(self) -> str:
per_dtypes=str()
per_dtypes = str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
per_hdim_case = str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
inners=str()
traits = self.pool[dtype][hdim]
inners = str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if_k = "if" if k == 0 else "else if"
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(
F_if=if_k,
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope],
F_bs=trait.bs,
F_bsk=trait.bsk,
F_bd=trait.bd,
F_bdv=trait.bdv,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)
per_dtypes += " (void)t ; (void)s ; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(
F_dispatch=per_dtypes
)
@dataclass
class FmhaFwdAppendKVTileSize:
F_bs : int # tile size along q seqlen
F_bsk : int # tile size along k seqlen
F_bd : int # tile size along qk gemm unroll
F_bdv : int # tile size along kv gemm unroll
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_bs: int # tile size along q seqlen
F_bsk: int # tile size along k seqlen
F_bd: int # tile size along qk gemm unroll
F_bdv: int # tile size along kv gemm unroll
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property
def name(self) -> str:
return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" + (
"" if self.F_occupancy == -1 else f"_o{self.F_occupancy}"
)
@dataclass
class FmhaFwdAppendKVKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_tile : FmhaFwdAppendKVTileSize
F_pipeline : FmhaFwdAppendKVPipeline
mask_impl : str
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
F_tile: FmhaFwdAppendKVTileSize
F_pipeline: FmhaFwdAppendKVPipeline
mask_impl: str
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_APPENDKV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bs = self.F_tile.F_bs,
F_bsk = self.F_tile.F_bsk,
F_bd = self.F_tile.F_bd,
F_bdv = self.F_tile.F_bdv,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_rope = ROPE_MAP[self.F_pipeline.F_rope],
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy = self.F_tile.F_occupancy)
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_KERNEL_BODY.format(
F_idx=self.F_idx,
F_hdim=self.F_hdim,
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
F_bs=self.F_tile.F_bs,
F_bsk=self.F_tile.F_bsk,
F_bd=self.F_tile.F_bd,
F_bdv=self.F_tile.F_bdv,
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
F_rope=ROPE_MAP[self.F_pipeline.F_rope],
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy=self.F_tile.F_occupancy,
)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
return (
f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_"
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
)
@property
def filename(self) -> str:
@@ -243,40 +297,45 @@ class FmhaFwdAppendKVKernel:
def api_trait(self) -> FmhaFwdAppendKVApiTrait:
return FmhaFwdAppendKVApiTrait(
hdim=str(self.F_hdim),
dtype=self.F_dtype,
bs=self.F_tile.F_bs,
bsk=self.F_tile.F_bsk,
bd=self.F_tile.F_bd,
bdv=self.F_tile.F_bdv,
vlayout=self.F_pipeline.F_vlayout,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
rope=self.F_pipeline.F_rope,
pagedkv=self.F_pipeline.F_pagedkv)
hdim=str(self.F_hdim),
dtype=self.F_dtype,
bs=self.F_tile.F_bs,
bsk=self.F_tile.F_bsk,
bd=self.F_tile.F_bd,
bdv=self.F_tile.F_bdv,
vlayout=self.F_pipeline.F_vlayout,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
rope=self.F_pipeline.F_rope,
pagedkv=self.F_pipeline.F_pagedkv,
)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
if dtype == "fp16" or dtype == "bf16":
return {
'32' : FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
"32": FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
"64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
"128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
"256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
elif dtype == "fp8" or dtype == "bf8":
return {
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1)
"64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
"128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
"256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
}
else:
return None
def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
def get_fwd_appendkv_blobs(
kernel_filter: Optional[str], receipt, mask_impl, optdim_list
) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
@@ -284,25 +343,50 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
if dtype in ["fp16", "bf16"]:
# NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
# applying rotary embedding, so I just use 't' in inter/half pipelines
for vlayout in ['row', 'col']:
for vlayout in ["row", "col"]:
for pagedkv in ["t", "f"]:
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv))
pipelines.append(
FmhaFwdAppendKVPipeline(
vlayout, "f", "t", "f", "f", "no", pagedkv
)
)
pipelines.append(
FmhaFwdAppendKVPipeline(
vlayout, "t", "t", "t", "t", "no", pagedkv
)
)
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv))
pipelines.append(
FmhaFwdAppendKVPipeline(
vlayout, "f", "t", "t", "f", "inter", pagedkv
)
)
pipelines.append(
FmhaFwdAppendKVPipeline(
vlayout, "t", "t", "t", "t", "inter", pagedkv
)
)
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv))
elif dtype in ['fp8', 'bf8']:
pipelines.append(
FmhaFwdAppendKVPipeline(
vlayout, "f", "t", "t", "f", "half", pagedkv
)
)
pipelines.append(
FmhaFwdAppendKVPipeline(
vlayout, "t", "t", "t", "t", "half", pagedkv
)
)
elif dtype in ["fp8", "bf8"]:
# rope/paged-kv is not supported
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
pipelines.append(
FmhaFwdAppendKVPipeline("col", "t", "t", "t", "t", "no", "f")
)
elif dtype in ["fp8fp16", "fp8bf16"]:
# TODO
None
else:
@@ -314,19 +398,21 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
if d == None:
if d is None:
continue
for hdim_str in d.keys():
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
k = FmhaFwdAppendKVKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != '':
k = FmhaFwdAppendKVKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
@@ -334,20 +420,20 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op
continue
# 2 - Flash attention integration
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
cond = dtype == "fp32"
if not cond:
continue
@@ -356,21 +442,33 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op
return (api_pool, gen)
def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
def write_fwd_appendkv_api(api_pool: FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None:
api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list)
def write_blobs(
output_dir: Path, kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> None:
api_pool, kernels = get_fwd_appendkv_blobs(
kernel_filter, receipt, mask_impl, optdim_list
)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
write_fwd_appendkv_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None:
with file_path.open('a') as f:
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list)
def list_blobs(
file_path: Path, kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_appendkv_blobs(
kernel_filter, receipt, mask_impl, optdim_list
)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")

File diff suppressed because it is too large Load Diff

View File

@@ -9,28 +9,26 @@ import itertools
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
LAYOUT_MAP,
BIAS_CHECK_MAP,
get_mask_check_map,
MODE_MAP,
get_mask_map,
BIAS_MAP,
FWD_DTYPE_MAP,
BOOL_MAP,
PIPELINE_ENUM_MAP,
)
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
"bf16": 16,
"fp8" : 8,
"bf8" : 8
}
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
K0_MAX_SUBMAX_MAP = {
32 : 32,
64 : 64,
96 : 128,
128: 128,
256: 256
}
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
FMHA_FWD_PAGEDKV_PIPELINE_MAP = {
"qr_pagedkv" : "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS"
"qr_pagedkv": "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS"
}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
@@ -40,7 +38,7 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
#include "fmha_fwd.hpp"
"""
FMHA_FWD_KERNEL_BODY="""
FMHA_FWD_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
@@ -115,8 +113,8 @@ float fmha_fwd_pagedkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd
}}
"""
FMHA_FWD_API_FILENAME="fmha_fwd_pagedkv_api.cpp"
FMHA_FWD_API="""
FMHA_FWD_API_FILENAME = "fmha_fwd_pagedkv_api.cpp"
FMHA_FWD_API = """
float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
@@ -124,164 +122,215 @@ float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, con
}}
"""
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
return fmha_fwd_pagedkv_<trait_>(s, a);
}}
"""
@dataclass
class FmhaFwdApiTrait:
pipeline_tag : str
pipeline_tag: str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
logits : str
mask : str
bias : str #
lse : str #
pagedkv : str
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
skip : str
hdim: str
dtype: str # data type
mode: str # value from MODE_MAP
bm0: int # tile size along q seqlen (block size)
bn0: int # tile size along qk seqlen
bk0: int # tile size along qk gemm unroll
bn1: int # tile size along v head_dim
bk1: int # tile size along kv gemm unroll
bk0max: int
vlayout: str
logits: str
mask: str
bias: str #
lse: str #
pagedkv: str
squant: str #
spad: str
skpad: str
dpad: str
dvpad: str
skip: str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
)
@property
def scheck(self) -> str:
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
if self.mode == "group":
return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true
if self.pipeline_tag == "qr_async":
if self.spad == "t":
return "true" # always support
else:
return "true"
elif self.pipeline_tag in ["qr_pagedkv", "qs"]:
if self.spad == "t":
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"a.seqlen_q % {self.bm0} == 0"
else:
assert False
@property
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0'
else: assert False
if self.mode == "group":
return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true
if self.pipeline_tag == "qr_async":
if self.skpad == "t":
return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
else:
return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
elif self.pipeline_tag in ["qr_pagedkv", "qs"]:
if self.skpad == "t":
return f"true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0"
else:
assert False
@property
def dcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
if self.pipeline_tag == "qr_async":
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
if self.dpad == "t":
return f"a.hdim_q % {vec} == 0"
else:
assert False
elif self.pipeline_tag in ["qr_pagedkv", "qs"]:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0'
else: assert False
if self.dpad == "t":
return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"a.hdim_q % {bk0submax} == 0"
else:
assert False
@property
def dvcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
if self.pipeline_tag == "qr_async":
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
if self.dvpad == "t":
return f"a.hdim_v % {vec} == 0"
else:
assert False
elif self.pipeline_tag in ["qr_pagedkv", "qs"]:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0'
else: assert False
if self.dvpad == "t":
return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"a.hdim_v % {bk0submax} == 0"
else:
assert False
@dataclass
class FmhaFwdPipeline:
tag : str
tag: str
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_pagedkv : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_skip : str # true/false
F_vlayout: str # row/col
F_spad: str # true/false
F_skpad: str #
F_dpad: str #
F_dvpad: str #
F_logits: str # t/f
F_bias: str # true/false
F_lse: str #
F_pagedkv: str #
F_squant: str #
F_mask: str # value from MASK_MAP
F_skip: str # true/false
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
n = ""
if self.F_spad == "t":
n += "s"
if self.F_skpad == "t":
n += "sk"
if self.F_dpad == "t":
n += "d"
if self.F_dvpad == "t":
n += "dv"
if n != "":
n = "p" + n
return n
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_logits == 't' : n += '_logits'
else: n += '_nlogits'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
else: n += '_nbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else: n += '_nmask'
n = f"{self.tag}_v{self.F_vlayout[0]}"
if pn != "":
n += f"_{pn}"
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
else: n += '_nmask'
n += "_npad"
if self.F_lse == 't' : n += '_lse'
else: n += '_nlse'
if self.F_logits == "t":
n += "_logits"
else:
n += "_nlogits"
if self.F_skip == 't' : n += '_skip'
else: n += '_nskip'
if self.F_bias != "no":
n += f"_{self.F_bias}"
else:
n += "_nbias"
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
if self.F_mask[0:2] == "s_":
if self.F_mask == "s_mask":
n += "_mask"
else:
n += "_nmask"
else:
if self.F_mask != "no":
n += f"_m{self.F_mask[0]}"
else:
n += "_nmask"
if self.F_pagedkv == 't' : n += '_pagedkv'
else: n += '_npagedkv'
if self.F_lse == "t":
n += "_lse"
else:
n += "_nlse"
if self.F_skip == "t":
n += "_skip"
else:
n += "_nskip"
if self.F_squant == "t":
n += "_squant"
else:
n += "_nsquant"
if self.F_pagedkv == "t":
n += "_pagedkv"
else:
n += "_npagedkv"
return n
class FmhaFwdApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
@@ -292,117 +341,152 @@ class FmhaFwdApiPool:
@property
def api(self) -> str:
per_dtypes=str()
per_dtypes = str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
per_hdim_case = str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
inners=str()
traits = self.pool[dtype][hdim]
inners = str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if_k = "if" if k == 0 else "else if"
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
F_if=if_k,
F_mode=MODE_MAP[trait.mode],
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
F_logits=BOOL_MAP[trait.logits],
F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_skip=BOOL_MAP[trait.skip],
F_squant=BOOL_MAP[trait.squant],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0,
F_bn0=trait.bn0,
F_bk0=trait.bk0,
F_bn1=trait.bn1,
F_bk1=trait.bk1,
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
per_dtypes += " (void)t ; (void)s ; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)
@dataclass
class FmhaFwdTileSize:
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm0 : int # gemm0 warp size along m
F_wn0 : int # gemm0 warp size along n
F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_bm0: int # tile size along q seqlen (block size)
F_bn0: int # tile size along k seqlen
F_bk0: int # tile size along qk gemm unroll
F_bn1: int # tile size along v head_dim
F_bk1: int # tile size along kv gemm unroll
F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0: int # number of warps for gemm0 along q seqlen
F_rn0: int # number of warps for gemm0 along k seqlen
F_rk0: int # number of warps for gemm0 along head dim q (not used)
F_rm1: int # number of warps for gemm1 along q seqlen
F_rn1: int # number of warps for gemm1 along head dim v
F_rk1: int # number of warps for gemm1 along k seqlen (not used)
F_wm0: int # gemm0 warp size along m
F_wn0: int # gemm0 warp size along n
F_wk0: int # gemm0 warp size along k
F_wm1: int # gemm1 warp size along m
F_wn1: int # gemm1 warp size along n
F_wk1: int # gemm1 warp size along k
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
return (
f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
+ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
+ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
+ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
)
@dataclass
class FmhaFwdKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
mask_impl : str
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
F_mode: str # value from MODE_MAP
F_tile: FmhaFwdTileSize
F_pipeline: FmhaFwdPipeline
mask_impl: str
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0max = self.F_tile.F_bk0max,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag])
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
F_idx=self.F_idx,
F_hdim=self.F_hdim,
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
F_bm0=self.F_tile.F_bm0,
F_bn0=self.F_tile.F_bn0,
F_bk0=self.F_tile.F_bk0,
F_bn1=self.F_tile.F_bn1,
F_bk1=self.F_tile.F_bk1,
F_bk0max=self.F_tile.F_bk0max,
F_rm0=self.F_tile.F_rm0,
F_rn0=self.F_tile.F_rn0,
F_rk0=self.F_tile.F_rk0,
F_rm1=self.F_tile.F_rm1,
F_rn1=self.F_tile.F_rn1,
F_rk1=self.F_tile.F_rk1,
F_wm0=self.F_tile.F_wm0,
F_wn0=self.F_tile.F_wn0,
F_wk0=self.F_tile.F_wk0,
F_wm1=self.F_tile.F_wm1,
F_wn1=self.F_tile.F_wn1,
F_wk1=self.F_tile.F_wk1,
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
F_logits=BOOL_MAP[self.F_pipeline.F_logits],
F_bias=BIAS_MAP[self.F_pipeline.F_bias],
F_lse=BOOL_MAP[self.F_pipeline.F_lse],
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
F_occupancy=self.F_tile.F_occupancy,
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode=MODE_MAP[self.F_mode],
F_pipeline=FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag],
)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
return (
f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
)
@property
def filename(self) -> str:
@@ -410,51 +494,64 @@ class FmhaFwdKernel:
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
logits=self.F_pipeline.F_logits,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
pagedkv=self.F_pipeline.F_pagedkv,
squant=self.F_pipeline.F_squant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip)
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
logits=self.F_pipeline.F_logits,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
pagedkv=self.F_pipeline.F_pagedkv,
squant=self.F_pipeline.F_squant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip,
)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
def get_fmha_fwd_tile_dict_from_dtype(dtype: str) -> Optional[dict]:
if dtype == "fp16" or dtype == "bf16":
return {
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
"128": FmhaFwdTileSize(
128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1
),
# '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
elif dtype == "fp8" or dtype == "bf8":
return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
"64": FmhaFwdTileSize(
128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1
),
"128": FmhaFwdTileSize(
128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1
),
"256": FmhaFwdTileSize(
128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1
),
}
else:
return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
def get_fwd_blobs(
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
@@ -462,18 +559,90 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
squant = "t" if dtype == "fp8" else "f"
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]):
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
elif dtype in ['fp8', 'bf8']:
if dtype in ["fp16", "bf16"]:
for logits, mask, bias, pagedkv, skip in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t"],
["f"],
):
pipelines.append(
FmhaFwdPipeline(
"qr_pagedkv",
"row",
"t",
"f",
"f",
"f",
logits,
bias,
"f",
pagedkv,
squant,
mask,
skip,
)
)
pipelines.append(
FmhaFwdPipeline(
"qr_pagedkv",
"row",
"t",
"t",
"f",
"f",
logits,
bias,
"f",
pagedkv,
squant,
mask,
skip,
)
)
elif dtype in ["fp8", "bf8"]:
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
for logits, mask, bias in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(
FmhaFwdPipeline(
"qr_pagedkv",
"row",
"f",
"f",
"f",
"f",
logits,
bias,
"f",
"t",
squant,
mask,
"f",
)
)
pipelines.append(
FmhaFwdPipeline(
"qr_pagedkv",
"row",
"t",
"t",
"f",
"f",
logits,
bias,
"f",
"t",
squant,
mask,
"f",
)
)
elif dtype in ["fp8fp16", "fp8bf16"]:
# TODO
None
else:
@@ -485,9 +654,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
if d == None:
if d is None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
hdim = int(hdim_str)
@@ -495,24 +664,29 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
# if pipeline.F_pagedkv == 'f':
# continue
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
if hdim == 192 and tile.F_bn1 == 128:
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' :
if pipeline.F_bias != "no" or pipeline.F_lse == "t":
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
if not (
(pipeline.F_logits == "t" and pipeline.F_bias == "no")
or pipeline.F_logits == "f"
):
continue
k = FmhaFwdKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != '':
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
@@ -520,49 +694,49 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_skip == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_skip == "f"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias']
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_skip == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_skip == "f"
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'batch'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 200:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# aiter::mha_fwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
cond = dtype == "fp32"
if not cond:
continue
@@ -571,20 +745,28 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
def write_blobs(
output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
with file_path.open('a') as f:
def list_blobs(
file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")

View File

@@ -6,30 +6,45 @@ import argparse
from enum import IntEnum
from pathlib import Path
import pkgutil
import sys
from typing import List, Optional
import codegen.ops
from codegen.cmake_config import *
from codegen.cmake_config import GEN_DIR
class HandlerId(IntEnum):
LIST_BLOBS = 0
WRITE_BLOBS = 1
# inspect all modules under 'codegen.ops' and register API handlers
ops = []
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
full_module_name = "%s.%s" % (codegen.ops.__name__, module_name)
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
unwanted_prefix = 'fmha_'
unwanted_prefix = "fmha_"
handlers = dict(
[(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__,
(op.list_blobs, op.write_blobs)) for op in ops]
[
(
op.__name__[len(unwanted_prefix) :]
if op.__name__.startswith(unwanted_prefix)
else op.__name__,
(op.list_blobs, op.write_blobs),
)
for op in ops
]
)
assert 0 < len(handlers)
def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None:
def write_blobs(
output_dir: Optional[str],
api_list: List[str],
filters_list: List[str],
optdim_list: List[int],
receipt,
mask_impl,
) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
@@ -41,8 +56,16 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list :
handler = handlers[api][HandlerId.WRITE_BLOBS]
handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl)
# list all the files that will be generated
def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None:
def list_blobs(
output_file: Optional[str],
api_list: List[str],
filters_list: List[str],
optdim_list: List[int],
receipt,
mask_impl,
) -> None:
assert output_file is not None
file_path = Path(output_file)
@@ -53,6 +76,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], filters_list :
handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, optdim_list, mask_impl)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="generate",
@@ -60,32 +84,29 @@ if __name__ == "__main__":
)
parser.add_argument(
"-d",
"--direction", # we keep 'direction' option for backward compatibility
"--direction", # we keep 'direction' option for backward compatibility
"-a",
"--api",
default='fwd',
default="fwd",
required=False,
help="supply API(s) to generate (default: fwd). separated by comma."
help="supply API(s) to generate (default: fwd). separated by comma.",
)
parser.add_argument(
"-o",
"--output_dir",
required=False,
help="write all the blobs into a directory"
help="write all the blobs into a directory",
)
parser.add_argument(
"-l",
"--list_blobs",
required=False,
help="list all the kernels to a file"
"-l", "--list_blobs", required=False, help="list all the kernels to a file"
)
# TODO: if using filter, must apply same value to output_dir and list_blobs
parser.add_argument(
"-f",
"--filter",
default='',
default="",
required=False,
help="filter out kernels that need to generate, using fnmatch module"
help="filter out kernels that need to generate, using fnmatch module",
)
parser.add_argument(
@@ -93,7 +114,7 @@ if __name__ == "__main__":
"--mask",
default="simplified",
required=False,
help="mask implementation, simplified/generic"
help="mask implementation, simplified/generic",
)
parser.add_argument(
@@ -101,32 +122,46 @@ if __name__ == "__main__":
"--receipt",
default=0,
required=False,
help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
" 1: generate more instance to cover all hdim\n" + \
" 2: Only generate instance for Flash attention integration\n" + \
" 4: Only generate instance for PyTorch integration\n" + \
" 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \
" 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \
" 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \
" 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \
" 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration"
help="codegen receipt. 0: generate only 8xhdim coverage\n"
+ " 1: generate more instance to cover all hdim\n"
+ " 2: Only generate instance for Flash attention integration\n"
+ " 4: Only generate instance for PyTorch integration\n"
+ " 100-199: Only generate instance for Aiter(mha_fwd) integration\n"
+ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n"
+ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n"
+ " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n"
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration",
)
parser.add_argument(
"--optdim",
default='-1',
default="-1",
required=False,
help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \
"eg. --optdim=32,64,128,256"
help="only optimize the hdim in the list. separated by comma. -1 is the default choice"
+ "eg. --optdim=32,64,128,256",
)
args = parser.parse_args()
api_list = args.direction.split(',')
filter_list = args.filter.split(',')
filter_list.extend([''] * (len(api_list) - len(filter_list)))
optdim_list = [int(hdim) for hdim in args.optdim.split(',')]
api_list = args.direction.split(",")
filter_list = args.filter.split(",")
filter_list.extend([""] * (len(api_list) - len(filter_list)))
optdim_list = [int(hdim) for hdim in args.optdim.split(",")]
if args.list_blobs is not None:
list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask)
list_blobs(
args.list_blobs,
api_list,
filter_list,
optdim_list,
int(args.receipt),
mask_impl=args.mask,
)
else:
write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask)
write_blobs(
args.output_dir,
api_list,
filter_list,
optdim_list,
int(args.receipt),
mask_impl=args.mask,
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -2,7 +2,7 @@
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include "ck_tile/ops/pool.hpp"
#include "ck_tile/ops/pooling.hpp"
#include "ck_tile/host/reference/reference_pool.hpp"
#include <cstring>

View File

@@ -1,21 +1,19 @@
import pathlib
from pathlib import Path
import subprocess
import os
import copy
all_files = []
for p in sorted(Path("./").rglob("*")):
if p.suffix in ['.hpp', '.cpp']:
if p.suffix in [".hpp", ".cpp"]:
all_files.append(pathlib.PurePath(p))
# formatting
for x in all_files:
subprocess.Popen(f'dos2unix {str(x)}', shell=True)
cmd = f'clang-format-18 -style=file -i {str(x)}'
#for xp in x.parents:
#print(get_file_base(x))
subprocess.Popen(f"dos2unix -n {str(x)}", shell=True)
cmd = f"clang-format-18 -style=file -i {str(x)}"
# for xp in x.parents:
# print(get_file_base(x))
subprocess.Popen(cmd, shell=True)
#print(all_files)
# print(all_files)

View File

@@ -18,6 +18,7 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/reference/reference_batched_contraction.hpp"
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
#include "ck_tile/host/reference/reference_batched_dropout_randval.hpp"
#include "ck_tile/host/reference/reference_batched_elementwise.hpp"
@@ -36,6 +37,7 @@
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/reference/reference_permute.hpp"
#include "ck_tile/host/reference/reference_pool.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"

View File

@@ -4,6 +4,8 @@
#pragma once
#include <cstdlib>
#include <functional>
#include <numeric>
#include <thread>
#include "ck_tile/core.hpp"
@@ -155,6 +157,10 @@ void calculate_reference_multi_dimensional(
b_idx.reserve(B_dims.size());
e_idx.reserve(E_dims.size());
auto calculate_total_elements = [](const std::vector<ck_tile::index_t>& dims) {
return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<ck_tile::index_t>());
};
for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
{
ck_tile::index_t temp = g_flat;

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
#include <thread>
namespace ck_tile {

View File

@@ -5,5 +5,9 @@
#include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp"
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -9,9 +9,9 @@
#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"

View File

@@ -10,6 +10,7 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {

View File

@@ -1,11 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
#include "ck_tile/ops/pooling/pipeline/pool_default_policy.hpp"
#include "ck_tile/ops/pooling/pipeline/pool_problem.hpp"
#include "ck_tile/ops/pooling/pipeline/pool_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -5,39 +5,43 @@ import subprocess
import os
import copy
NS = 'ck_tile'
OPS = 'ops'
REF = 'ref'
OPS_COMMON = 'common' #common header will be duplicated into ops/* other module
NS = "ck_tile"
OPS = "ops"
OPS_COMMON = "common" # common header will be duplicated into ops/* other module
IGNORED_DIRS = ["utility", "ref"]
HEADER_COMMON = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n
"""
# aa/bb/cc/file.hpp -> (aa, bb, cc, file.hpp)
def get_module(f, level = 0):
def get_module(f, level=0):
all_parts = f.parts
return str(all_parts[level])
all_files = []
for p in sorted(Path("./").rglob("*")):
if p.suffix == '.hpp':
if p.suffix == ".hpp":
all_files.append(pathlib.PurePath(p))
class submodule_t:
def __init__(self):
self.m = dict()
def push(self, f):
if len(f.parents) != 1: # ignore ./xxx.hpp
if len(f.parents) != 1: # ignore ./xxx.hpp
mod = get_module(f)
# ref is supposed to include one header on demand
if mod == REF:
# Should only be included by demand
if mod in IGNORED_DIRS:
return
if mod == OPS:
if mod not in self.m.keys():
self.m[mod] = dict()
mod2 = get_module(f, 1)
if Path(mod2).suffix != '.hpp':
if Path(mod2).suffix != ".hpp":
# ignore ops/xxx.hpp
if mod2 not in self.m[mod].keys():
self.m[mod][mod2] = list()
@@ -52,14 +56,15 @@ class submodule_t:
# print(hpath)
if os.path.exists(str(hpath)):
os.remove(str(hpath))
with hpath.open('w') as f:
with hpath.open("w") as f:
f.write(HEADER_COMMON)
f.write('#pragma once\n')
f.write('\n')
f.write("#pragma once\n")
f.write("\n")
for individual_header in include_list:
header_path = NS + '/' + str(individual_header)
f.write(f'#include \"{header_path}\"\n')
header_path = NS + "/" + str(individual_header)
f.write(f'#include "{header_path}"\n')
# f.write('\n') # otherwise clang-format will complain
# print(self.m)
# restructure common
for k, v in self.m.items():
@@ -73,21 +78,21 @@ class submodule_t:
for k, v in self.m.items():
if k == OPS:
for km, kv in v.items():
gen_header(Path(k) / (f'{km}.hpp'), kv)
gen_header(Path(k) / (f"{km}.hpp"), kv)
else:
gen_header(Path(f'{k}.hpp'), v)
gen_header(Path(f"{k}.hpp"), v)
submodule = submodule_t()
# formatting
for x in all_files:
subprocess.Popen(f'dos2unix {str(x)}', shell=True)
cmd = f'clang-format-18 -style=file -i {str(x)}'
#for xp in x.parents:
#print(get_file_base(x))
subprocess.Popen(f"dos2unix -n {str(x)}", shell=True)
cmd = f"clang-format-18 -style=file -i {str(x)}"
# for xp in x.parents:
# print(get_file_base(x))
subprocess.Popen(cmd, shell=True)
submodule.push(x)
submodule.gen()
#print(all_files)
# print(all_files)

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_ALLOCATORS_H_
@@ -32,10 +32,10 @@ RAPIDJSON_NAMESPACE_BEGIN
/*! \class rapidjson::Allocator
\brief Concept for allocating, resizing and freeing memory block.
Note that Malloc() and Realloc() are non-static but Free() is static.
So if an allocator need to support Free(), it needs to put its pointer in
So if an allocator need to support Free(), it needs to put its pointer in
the header of memory block.
\code
@@ -49,7 +49,8 @@ concept Allocator {
// Resize a memory block.
// \param originalPtr The pointer to current memory block. Null pointer is permitted.
// \param originalSize The current size in bytes. (Design issue: since some allocator may not book-keep this, explicitly pass to it can save memory.)
// \param originalSize The current size in bytes. (Design issue: since some allocator may not
book-keep this, explicitly pass to it can save memory.)
// \param newSize the new size in bytes.
void* Realloc(void* originalPtr, size_t originalSize, size_t newSize);
@@ -60,7 +61,6 @@ concept Allocator {
\endcode
*/
/*! \def RAPIDJSON_ALLOCATOR_DEFAULT_CHUNK_CAPACITY
\ingroup RAPIDJSON_CONFIG
\brief User-defined kDefaultChunkCapacity definition.
@@ -72,7 +72,6 @@ concept Allocator {
#define RAPIDJSON_ALLOCATOR_DEFAULT_CHUNK_CAPACITY (64 * 1024)
#endif
///////////////////////////////////////////////////////////////////////////////
// CrtAllocator
@@ -80,38 +79,38 @@ concept Allocator {
/*! This class is just wrapper for standard C library memory routines.
\note implements Allocator concept
*/
class CrtAllocator {
public:
class CrtAllocator
{
public:
static const bool kNeedFree = true;
void* Malloc(size_t size) {
if (size) // behavior of malloc(0) is implementation defined.
void* Malloc(size_t size)
{
if(size) // behavior of malloc(0) is implementation defined.
return RAPIDJSON_MALLOC(size);
else
return NULL; // standardize to returning NULL.
}
void* Realloc(void* originalPtr, size_t originalSize, size_t newSize) {
void* Realloc(void* originalPtr, size_t originalSize, size_t newSize)
{
(void)originalSize;
if (newSize == 0) {
if(newSize == 0)
{
RAPIDJSON_FREE(originalPtr);
return NULL;
}
return RAPIDJSON_REALLOC(originalPtr, newSize);
}
static void Free(void *ptr) RAPIDJSON_NOEXCEPT { RAPIDJSON_FREE(ptr); }
static void Free(void* ptr) RAPIDJSON_NOEXCEPT { RAPIDJSON_FREE(ptr); }
bool operator==(const CrtAllocator&) const RAPIDJSON_NOEXCEPT {
return true;
}
bool operator!=(const CrtAllocator&) const RAPIDJSON_NOEXCEPT {
return false;
}
bool operator==(const CrtAllocator&) const RAPIDJSON_NOEXCEPT { return true; }
bool operator!=(const CrtAllocator&) const RAPIDJSON_NOEXCEPT { return false; }
};
///////////////////////////////////////////////////////////////////////////////
// MemoryPoolAllocator
//! Default memory allocator used by the parser and DOM.
/*! This allocator allocate memory blocks from pre-allocated memory chunks.
/*! This allocator allocate memory blocks from pre-allocated memory chunks.
It does not free memory blocks. And Realloc() only allocate new memory.
@@ -127,69 +126,82 @@ public:
\note implements Allocator concept
*/
template <typename BaseAllocator = CrtAllocator>
class MemoryPoolAllocator {
class MemoryPoolAllocator
{
//! Chunk header for perpending to each chunk.
/*! Chunks are stored as a singly linked list.
*/
struct ChunkHeader {
size_t capacity; //!< Capacity of the chunk in bytes (excluding the header itself).
size_t size; //!< Current size of allocated memory in bytes.
ChunkHeader *next; //!< Next chunk in the linked list.
*/
struct ChunkHeader
{
size_t capacity; //!< Capacity of the chunk in bytes (excluding the header itself).
size_t size; //!< Current size of allocated memory in bytes.
ChunkHeader* next; //!< Next chunk in the linked list.
};
struct SharedData {
ChunkHeader *chunkHead; //!< Head of the chunk linked-list. Only the head chunk serves allocation.
struct SharedData
{
ChunkHeader*
chunkHead; //!< Head of the chunk linked-list. Only the head chunk serves allocation.
BaseAllocator* ownBaseAllocator; //!< base allocator created by this object.
size_t refcount;
bool ownBuffer;
};
static const size_t SIZEOF_SHARED_DATA = RAPIDJSON_ALIGN(sizeof(SharedData));
static const size_t SIZEOF_SHARED_DATA = RAPIDJSON_ALIGN(sizeof(SharedData));
static const size_t SIZEOF_CHUNK_HEADER = RAPIDJSON_ALIGN(sizeof(ChunkHeader));
static inline ChunkHeader *GetChunkHead(SharedData *shared)
static inline ChunkHeader* GetChunkHead(SharedData* shared)
{
return reinterpret_cast<ChunkHeader*>(reinterpret_cast<uint8_t*>(shared) + SIZEOF_SHARED_DATA);
return reinterpret_cast<ChunkHeader*>(reinterpret_cast<uint8_t*>(shared) +
SIZEOF_SHARED_DATA);
}
static inline uint8_t *GetChunkBuffer(SharedData *shared)
static inline uint8_t* GetChunkBuffer(SharedData* shared)
{
return reinterpret_cast<uint8_t*>(shared->chunkHead) + SIZEOF_CHUNK_HEADER;
}
static const size_t kDefaultChunkCapacity = RAPIDJSON_ALLOCATOR_DEFAULT_CHUNK_CAPACITY; //!< Default chunk capacity.
static const size_t kDefaultChunkCapacity =
RAPIDJSON_ALLOCATOR_DEFAULT_CHUNK_CAPACITY; //!< Default chunk capacity.
public:
static const bool kNeedFree = false; //!< Tell users that no need to call Free() with this allocator. (concept Allocator)
static const bool kRefCounted = true; //!< Tell users that this allocator is reference counted on copy
public:
static const bool kNeedFree =
false; //!< Tell users that no need to call Free() with this allocator. (concept Allocator)
static const bool kRefCounted =
true; //!< Tell users that this allocator is reference counted on copy
//! Constructor with chunkSize.
/*! \param chunkSize The size of memory chunk. The default is kDefaultChunkSize.
\param baseAllocator The allocator for allocating memory chunks.
*/
explicit
MemoryPoolAllocator(size_t chunkSize = kDefaultChunkCapacity, BaseAllocator* baseAllocator = 0) :
chunk_capacity_(chunkSize),
baseAllocator_(baseAllocator ? baseAllocator : RAPIDJSON_NEW(BaseAllocator)()),
shared_(static_cast<SharedData*>(baseAllocator_ ? baseAllocator_->Malloc(SIZEOF_SHARED_DATA + SIZEOF_CHUNK_HEADER) : 0))
explicit MemoryPoolAllocator(size_t chunkSize = kDefaultChunkCapacity,
BaseAllocator* baseAllocator = 0)
: chunk_capacity_(chunkSize),
baseAllocator_(baseAllocator ? baseAllocator : RAPIDJSON_NEW(BaseAllocator)()),
shared_(static_cast<SharedData*>(
baseAllocator_ ? baseAllocator_->Malloc(SIZEOF_SHARED_DATA + SIZEOF_CHUNK_HEADER)
: 0))
{
RAPIDJSON_ASSERT(baseAllocator_ != 0);
RAPIDJSON_ASSERT(shared_ != 0);
if (baseAllocator) {
if(baseAllocator)
{
shared_->ownBaseAllocator = 0;
}
else {
else
{
shared_->ownBaseAllocator = baseAllocator_;
}
shared_->chunkHead = GetChunkHead(shared_);
shared_->chunkHead = GetChunkHead(shared_);
shared_->chunkHead->capacity = 0;
shared_->chunkHead->size = 0;
shared_->chunkHead->next = 0;
shared_->ownBuffer = true;
shared_->refcount = 1;
shared_->chunkHead->size = 0;
shared_->chunkHead->next = 0;
shared_->ownBuffer = true;
shared_->refcount = 1;
}
//! Constructor with user-supplied buffer.
/*! The user buffer will be used firstly. When it is full, memory pool allocates new chunk with chunk size.
/*! The user buffer will be used firstly. When it is full, memory pool allocates new chunk with
chunk size.
The user buffer will not be deallocated when this allocator is destructed.
@@ -198,25 +210,28 @@ public:
\param chunkSize The size of memory chunk. The default is kDefaultChunkSize.
\param baseAllocator The allocator for allocating memory chunks.
*/
MemoryPoolAllocator(void *buffer, size_t size, size_t chunkSize = kDefaultChunkCapacity, BaseAllocator* baseAllocator = 0) :
chunk_capacity_(chunkSize),
baseAllocator_(baseAllocator),
shared_(static_cast<SharedData*>(AlignBuffer(buffer, size)))
MemoryPoolAllocator(void* buffer,
size_t size,
size_t chunkSize = kDefaultChunkCapacity,
BaseAllocator* baseAllocator = 0)
: chunk_capacity_(chunkSize),
baseAllocator_(baseAllocator),
shared_(static_cast<SharedData*>(AlignBuffer(buffer, size)))
{
RAPIDJSON_ASSERT(size >= SIZEOF_SHARED_DATA + SIZEOF_CHUNK_HEADER);
shared_->chunkHead = GetChunkHead(shared_);
shared_->chunkHead = GetChunkHead(shared_);
shared_->chunkHead->capacity = size - SIZEOF_SHARED_DATA - SIZEOF_CHUNK_HEADER;
shared_->chunkHead->size = 0;
shared_->chunkHead->next = 0;
shared_->ownBaseAllocator = 0;
shared_->ownBuffer = false;
shared_->refcount = 1;
shared_->chunkHead->size = 0;
shared_->chunkHead->next = 0;
shared_->ownBaseAllocator = 0;
shared_->ownBuffer = false;
shared_->refcount = 1;
}
MemoryPoolAllocator(const MemoryPoolAllocator& rhs) RAPIDJSON_NOEXCEPT :
chunk_capacity_(rhs.chunk_capacity_),
baseAllocator_(rhs.baseAllocator_),
shared_(rhs.shared_)
MemoryPoolAllocator(const MemoryPoolAllocator& rhs) RAPIDJSON_NOEXCEPT
: chunk_capacity_(rhs.chunk_capacity_),
baseAllocator_(rhs.baseAllocator_),
shared_(rhs.shared_)
{
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
++shared_->refcount;
@@ -226,17 +241,17 @@ public:
RAPIDJSON_NOEXCEPT_ASSERT(rhs.shared_->refcount > 0);
++rhs.shared_->refcount;
this->~MemoryPoolAllocator();
baseAllocator_ = rhs.baseAllocator_;
baseAllocator_ = rhs.baseAllocator_;
chunk_capacity_ = rhs.chunk_capacity_;
shared_ = rhs.shared_;
shared_ = rhs.shared_;
return *this;
}
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
MemoryPoolAllocator(MemoryPoolAllocator&& rhs) RAPIDJSON_NOEXCEPT :
chunk_capacity_(rhs.chunk_capacity_),
baseAllocator_(rhs.baseAllocator_),
shared_(rhs.shared_)
MemoryPoolAllocator(MemoryPoolAllocator&& rhs) RAPIDJSON_NOEXCEPT
: chunk_capacity_(rhs.chunk_capacity_),
baseAllocator_(rhs.baseAllocator_),
shared_(rhs.shared_)
{
RAPIDJSON_NOEXCEPT_ASSERT(rhs.shared_->refcount > 0);
rhs.shared_ = 0;
@@ -245,40 +260,47 @@ public:
{
RAPIDJSON_NOEXCEPT_ASSERT(rhs.shared_->refcount > 0);
this->~MemoryPoolAllocator();
baseAllocator_ = rhs.baseAllocator_;
baseAllocator_ = rhs.baseAllocator_;
chunk_capacity_ = rhs.chunk_capacity_;
shared_ = rhs.shared_;
rhs.shared_ = 0;
shared_ = rhs.shared_;
rhs.shared_ = 0;
return *this;
}
#endif
//! Destructor.
/*! This deallocates all memory chunks, excluding the user-supplied buffer.
*/
~MemoryPoolAllocator() RAPIDJSON_NOEXCEPT {
if (!shared_) {
*/
~MemoryPoolAllocator() RAPIDJSON_NOEXCEPT
{
if(!shared_)
{
// do nothing if moved
return;
}
if (shared_->refcount > 1) {
if(shared_->refcount > 1)
{
--shared_->refcount;
return;
}
Clear();
BaseAllocator *a = shared_->ownBaseAllocator;
if (shared_->ownBuffer) {
BaseAllocator* a = shared_->ownBaseAllocator;
if(shared_->ownBuffer)
{
baseAllocator_->Free(shared_);
}
RAPIDJSON_DELETE(a);
}
//! Deallocates all memory chunks, excluding the first/user one.
void Clear() RAPIDJSON_NOEXCEPT {
void Clear() RAPIDJSON_NOEXCEPT
{
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
for (;;) {
for(;;)
{
ChunkHeader* c = shared_->chunkHead;
if (!c->next) {
if(!c->next)
{
break;
}
shared_->chunkHead = c->next;
@@ -289,78 +311,86 @@ public:
//! Computes the total capacity of allocated memory chunks.
/*! \return total capacity in bytes.
*/
size_t Capacity() const RAPIDJSON_NOEXCEPT {
*/
size_t Capacity() const RAPIDJSON_NOEXCEPT
{
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
size_t capacity = 0;
for (ChunkHeader* c = shared_->chunkHead; c != 0; c = c->next)
for(ChunkHeader* c = shared_->chunkHead; c != 0; c = c->next)
capacity += c->capacity;
return capacity;
}
//! Computes the memory blocks allocated.
/*! \return total used bytes.
*/
size_t Size() const RAPIDJSON_NOEXCEPT {
*/
size_t Size() const RAPIDJSON_NOEXCEPT
{
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
size_t size = 0;
for (ChunkHeader* c = shared_->chunkHead; c != 0; c = c->next)
for(ChunkHeader* c = shared_->chunkHead; c != 0; c = c->next)
size += c->size;
return size;
}
//! Whether the allocator is shared.
/*! \return true or false.
*/
bool Shared() const RAPIDJSON_NOEXCEPT {
*/
bool Shared() const RAPIDJSON_NOEXCEPT
{
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
return shared_->refcount > 1;
}
//! Allocates a memory block. (concept Allocator)
void* Malloc(size_t size) {
void* Malloc(size_t size)
{
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
if (!size)
if(!size)
return NULL;
size = RAPIDJSON_ALIGN(size);
if (RAPIDJSON_UNLIKELY(shared_->chunkHead->size + size > shared_->chunkHead->capacity))
if (!AddChunk(chunk_capacity_ > size ? chunk_capacity_ : size))
if(RAPIDJSON_UNLIKELY(shared_->chunkHead->size + size > shared_->chunkHead->capacity))
if(!AddChunk(chunk_capacity_ > size ? chunk_capacity_ : size))
return NULL;
void *buffer = GetChunkBuffer(shared_) + shared_->chunkHead->size;
void* buffer = GetChunkBuffer(shared_) + shared_->chunkHead->size;
shared_->chunkHead->size += size;
return buffer;
}
//! Resizes a memory block (concept Allocator)
void* Realloc(void* originalPtr, size_t originalSize, size_t newSize) {
if (originalPtr == 0)
void* Realloc(void* originalPtr, size_t originalSize, size_t newSize)
{
if(originalPtr == 0)
return Malloc(newSize);
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
if (newSize == 0)
if(newSize == 0)
return NULL;
originalSize = RAPIDJSON_ALIGN(originalSize);
newSize = RAPIDJSON_ALIGN(newSize);
newSize = RAPIDJSON_ALIGN(newSize);
// Do not shrink if new size is smaller than original
if (originalSize >= newSize)
if(originalSize >= newSize)
return originalPtr;
// Simply expand it if it is the last allocation and there is sufficient space
if (originalPtr == GetChunkBuffer(shared_) + shared_->chunkHead->size - originalSize) {
if(originalPtr == GetChunkBuffer(shared_) + shared_->chunkHead->size - originalSize)
{
size_t increment = static_cast<size_t>(newSize - originalSize);
if (shared_->chunkHead->size + increment <= shared_->chunkHead->capacity) {
if(shared_->chunkHead->size + increment <= shared_->chunkHead->capacity)
{
shared_->chunkHead->size += increment;
return originalPtr;
}
}
// Realloc process: allocate and copy memory, do not free original buffer.
if (void* newBuffer = Malloc(newSize)) {
if (originalSize)
if(void* newBuffer = Malloc(newSize))
{
if(originalSize)
std::memcpy(newBuffer, originalPtr, originalSize);
return newBuffer;
}
@@ -369,31 +399,36 @@ public:
}
//! Frees a memory block (concept Allocator)
static void Free(void *ptr) RAPIDJSON_NOEXCEPT { (void)ptr; } // Do nothing
static void Free(void* ptr) RAPIDJSON_NOEXCEPT { (void)ptr; } // Do nothing
//! Compare (equality) with another MemoryPoolAllocator
bool operator==(const MemoryPoolAllocator& rhs) const RAPIDJSON_NOEXCEPT {
bool operator==(const MemoryPoolAllocator& rhs) const RAPIDJSON_NOEXCEPT
{
RAPIDJSON_NOEXCEPT_ASSERT(shared_->refcount > 0);
RAPIDJSON_NOEXCEPT_ASSERT(rhs.shared_->refcount > 0);
return shared_ == rhs.shared_;
}
//! Compare (inequality) with another MemoryPoolAllocator
bool operator!=(const MemoryPoolAllocator& rhs) const RAPIDJSON_NOEXCEPT {
bool operator!=(const MemoryPoolAllocator& rhs) const RAPIDJSON_NOEXCEPT
{
return !operator==(rhs);
}
private:
private:
//! Creates a new chunk.
/*! \param capacity Capacity of the chunk in bytes.
\return true if success.
*/
bool AddChunk(size_t capacity) {
if (!baseAllocator_)
bool AddChunk(size_t capacity)
{
if(!baseAllocator_)
shared_->ownBaseAllocator = baseAllocator_ = RAPIDJSON_NEW(BaseAllocator)();
if (ChunkHeader* chunk = static_cast<ChunkHeader*>(baseAllocator_->Malloc(SIZEOF_CHUNK_HEADER + capacity))) {
chunk->capacity = capacity;
chunk->size = 0;
chunk->next = shared_->chunkHead;
if(ChunkHeader* chunk =
static_cast<ChunkHeader*>(baseAllocator_->Malloc(SIZEOF_CHUNK_HEADER + capacity)))
{
chunk->capacity = capacity;
chunk->size = 0;
chunk->next = shared_->chunkHead;
shared_->chunkHead = chunk;
return true;
}
@@ -401,12 +436,13 @@ private:
return false;
}
static inline void* AlignBuffer(void* buf, size_t &size)
static inline void* AlignBuffer(void* buf, size_t& size)
{
RAPIDJSON_NOEXCEPT_ASSERT(buf != 0);
const uintptr_t mask = sizeof(void*) - 1;
const uintptr_t ubuf = reinterpret_cast<uintptr_t>(buf);
if (RAPIDJSON_UNLIKELY(ubuf & mask)) {
if(RAPIDJSON_UNLIKELY(ubuf & mask))
{
const uintptr_t abuf = (ubuf + mask) & ~mask;
RAPIDJSON_ASSERT(size >= abuf - ubuf);
buf = reinterpret_cast<void*>(abuf);
@@ -415,37 +451,38 @@ private:
return buf;
}
size_t chunk_capacity_; //!< The minimum capacity of chunk when they are allocated.
BaseAllocator* baseAllocator_; //!< base allocator for allocating memory chunks.
SharedData *shared_; //!< The shared data of the allocator
size_t chunk_capacity_; //!< The minimum capacity of chunk when they are allocated.
BaseAllocator* baseAllocator_; //!< base allocator for allocating memory chunks.
SharedData* shared_; //!< The shared data of the allocator
};
namespace internal {
template<typename, typename = void>
struct IsRefCounted :
public FalseType
{ };
template<typename T>
struct IsRefCounted<T, typename internal::EnableIfCond<T::kRefCounted>::Type> :
public TrueType
{ };
}
template <typename, typename = void>
struct IsRefCounted : public FalseType
{
};
template <typename T>
struct IsRefCounted<T, typename internal::EnableIfCond<T::kRefCounted>::Type> : public TrueType
{
};
} // namespace internal
template<typename T, typename A>
template <typename T, typename A>
inline T* Realloc(A& a, T* old_p, size_t old_n, size_t new_n)
{
RAPIDJSON_NOEXCEPT_ASSERT(old_n <= (std::numeric_limits<size_t>::max)() / sizeof(T) && new_n <= (std::numeric_limits<size_t>::max)() / sizeof(T));
RAPIDJSON_NOEXCEPT_ASSERT(old_n <= (std::numeric_limits<size_t>::max)() / sizeof(T) &&
new_n <= (std::numeric_limits<size_t>::max)() / sizeof(T));
return static_cast<T*>(a.Realloc(old_p, old_n * sizeof(T), new_n * sizeof(T)));
}
template<typename T, typename A>
inline T *Malloc(A& a, size_t n = 1)
template <typename T, typename A>
inline T* Malloc(A& a, size_t n = 1)
{
return Realloc<T, A>(a, NULL, 0, n);
}
template<typename T, typename A>
inline void Free(A& a, T *p, size_t n = 1)
template <typename T, typename A>
inline void Free(A& a, T* p, size_t n = 1)
{
static_cast<void>(Realloc<T, A>(a, p, n, 0));
}
@@ -456,8 +493,7 @@ RAPIDJSON_DIAG_OFF(effc++) // std::allocator can safely be inherited
#endif
template <typename T, typename BaseAllocator = CrtAllocator>
class StdAllocator :
public std::allocator<T>
class StdAllocator : public std::allocator<T>
{
typedef std::allocator<T> allocator_type;
#if RAPIDJSON_HAS_CXX11
@@ -466,113 +502,90 @@ class StdAllocator :
typedef allocator_type traits_type;
#endif
public:
public:
typedef BaseAllocator BaseAllocatorType;
StdAllocator() RAPIDJSON_NOEXCEPT :
allocator_type(),
baseAllocator_()
{ }
StdAllocator() RAPIDJSON_NOEXCEPT : allocator_type(), baseAllocator_() {}
StdAllocator(const StdAllocator& rhs) RAPIDJSON_NOEXCEPT :
allocator_type(rhs),
baseAllocator_(rhs.baseAllocator_)
{ }
StdAllocator(const StdAllocator& rhs) RAPIDJSON_NOEXCEPT : allocator_type(rhs),
baseAllocator_(rhs.baseAllocator_)
{
}
template<typename U>
StdAllocator(const StdAllocator<U, BaseAllocator>& rhs) RAPIDJSON_NOEXCEPT :
allocator_type(rhs),
baseAllocator_(rhs.baseAllocator_)
{ }
template <typename U>
StdAllocator(const StdAllocator<U, BaseAllocator>& rhs) RAPIDJSON_NOEXCEPT
: allocator_type(rhs),
baseAllocator_(rhs.baseAllocator_)
{
}
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
StdAllocator(StdAllocator&& rhs) RAPIDJSON_NOEXCEPT :
allocator_type(std::move(rhs)),
baseAllocator_(std::move(rhs.baseAllocator_))
{ }
StdAllocator(StdAllocator&& rhs) RAPIDJSON_NOEXCEPT
: allocator_type(std::move(rhs)),
baseAllocator_(std::move(rhs.baseAllocator_))
{
}
#endif
#if RAPIDJSON_HAS_CXX11
using propagate_on_container_move_assignment = std::true_type;
using propagate_on_container_swap = std::true_type;
using propagate_on_container_swap = std::true_type;
#endif
/* implicit */
StdAllocator(const BaseAllocator& baseAllocator) RAPIDJSON_NOEXCEPT :
allocator_type(),
baseAllocator_(baseAllocator)
{ }
StdAllocator(const BaseAllocator& baseAllocator) RAPIDJSON_NOEXCEPT
: allocator_type(),
baseAllocator_(baseAllocator)
{
}
~StdAllocator() RAPIDJSON_NOEXCEPT
{ }
~StdAllocator() RAPIDJSON_NOEXCEPT {}
template<typename U>
struct rebind {
template <typename U>
struct rebind
{
typedef StdAllocator<U, BaseAllocator> other;
};
typedef typename traits_type::size_type size_type;
typedef typename traits_type::difference_type difference_type;
typedef typename traits_type::size_type size_type;
typedef typename traits_type::difference_type difference_type;
typedef typename traits_type::value_type value_type;
typedef typename traits_type::pointer pointer;
typedef typename traits_type::const_pointer const_pointer;
typedef typename traits_type::value_type value_type;
typedef typename traits_type::pointer pointer;
typedef typename traits_type::const_pointer const_pointer;
#if RAPIDJSON_HAS_CXX11
typedef typename std::add_lvalue_reference<value_type>::type &reference;
typedef typename std::add_lvalue_reference<typename std::add_const<value_type>::type>::type &const_reference;
typedef typename std::add_lvalue_reference<value_type>::type& reference;
typedef typename std::add_lvalue_reference<typename std::add_const<value_type>::type>::type&
const_reference;
pointer address(reference r) const RAPIDJSON_NOEXCEPT
{
return std::addressof(r);
}
const_pointer address(const_reference r) const RAPIDJSON_NOEXCEPT
{
return std::addressof(r);
}
pointer address(reference r) const RAPIDJSON_NOEXCEPT { return std::addressof(r); }
const_pointer address(const_reference r) const RAPIDJSON_NOEXCEPT { return std::addressof(r); }
size_type max_size() const RAPIDJSON_NOEXCEPT
{
return traits_type::max_size(*this);
}
size_type max_size() const RAPIDJSON_NOEXCEPT { return traits_type::max_size(*this); }
template <typename ...Args>
template <typename... Args>
void construct(pointer p, Args&&... args)
{
traits_type::construct(*this, p, std::forward<Args>(args)...);
}
void destroy(pointer p)
{
traits_type::destroy(*this, p);
}
void destroy(pointer p) { traits_type::destroy(*this, p); }
#else // !RAPIDJSON_HAS_CXX11
typedef typename allocator_type::reference reference;
typedef typename allocator_type::reference reference;
typedef typename allocator_type::const_reference const_reference;
pointer address(reference r) const RAPIDJSON_NOEXCEPT
{
return allocator_type::address(r);
}
pointer address(reference r) const RAPIDJSON_NOEXCEPT { return allocator_type::address(r); }
const_pointer address(const_reference r) const RAPIDJSON_NOEXCEPT
{
return allocator_type::address(r);
}
size_type max_size() const RAPIDJSON_NOEXCEPT
{
return allocator_type::max_size();
}
size_type max_size() const RAPIDJSON_NOEXCEPT { return allocator_type::max_size(); }
void construct(pointer p, const_reference r)
{
allocator_type::construct(p, r);
}
void destroy(pointer p)
{
allocator_type::destroy(p);
}
void construct(pointer p, const_reference r) { allocator_type::construct(p, r); }
void destroy(pointer p) { allocator_type::destroy(p); }
#endif // !RAPIDJSON_HAS_CXX11
@@ -587,47 +600,35 @@ public:
RAPIDJSON_NAMESPACE::Free<U>(baseAllocator_, p, n);
}
pointer allocate(size_type n = 1, const void* = 0)
{
return allocate<value_type>(n);
}
void deallocate(pointer p, size_type n = 1)
{
deallocate<value_type>(p, n);
}
pointer allocate(size_type n = 1, const void* = 0) { return allocate<value_type>(n); }
void deallocate(pointer p, size_type n = 1) { deallocate<value_type>(p, n); }
#if RAPIDJSON_HAS_CXX11
using is_always_equal = std::is_empty<BaseAllocator>;
#endif
template<typename U>
template <typename U>
bool operator==(const StdAllocator<U, BaseAllocator>& rhs) const RAPIDJSON_NOEXCEPT
{
return baseAllocator_ == rhs.baseAllocator_;
}
template<typename U>
template <typename U>
bool operator!=(const StdAllocator<U, BaseAllocator>& rhs) const RAPIDJSON_NOEXCEPT
{
return !operator==(rhs);
}
//! rapidjson Allocator concept
static const bool kNeedFree = BaseAllocator::kNeedFree;
static const bool kNeedFree = BaseAllocator::kNeedFree;
static const bool kRefCounted = internal::IsRefCounted<BaseAllocator>::Value;
void* Malloc(size_t size)
{
return baseAllocator_.Malloc(size);
}
void* Malloc(size_t size) { return baseAllocator_.Malloc(size); }
void* Realloc(void* originalPtr, size_t originalSize, size_t newSize)
{
return baseAllocator_.Realloc(originalPtr, originalSize, newSize);
}
static void Free(void *ptr) RAPIDJSON_NOEXCEPT
{
BaseAllocator::Free(ptr);
}
static void Free(void* ptr) RAPIDJSON_NOEXCEPT { BaseAllocator::Free(ptr); }
private:
private:
template <typename, typename>
friend class StdAllocator; // access to StdAllocator<!T>.*
@@ -636,47 +637,45 @@ private:
#if !RAPIDJSON_HAS_CXX17 // std::allocator<void> deprecated in C++17
template <typename BaseAllocator>
class StdAllocator<void, BaseAllocator> :
public std::allocator<void>
class StdAllocator<void, BaseAllocator> : public std::allocator<void>
{
typedef std::allocator<void> allocator_type;
public:
public:
typedef BaseAllocator BaseAllocatorType;
StdAllocator() RAPIDJSON_NOEXCEPT :
allocator_type(),
baseAllocator_()
{ }
StdAllocator() RAPIDJSON_NOEXCEPT : allocator_type(), baseAllocator_() {}
StdAllocator(const StdAllocator& rhs) RAPIDJSON_NOEXCEPT :
allocator_type(rhs),
baseAllocator_(rhs.baseAllocator_)
{ }
StdAllocator(const StdAllocator& rhs) RAPIDJSON_NOEXCEPT : allocator_type(rhs),
baseAllocator_(rhs.baseAllocator_)
{
}
template<typename U>
StdAllocator(const StdAllocator<U, BaseAllocator>& rhs) RAPIDJSON_NOEXCEPT :
allocator_type(rhs),
baseAllocator_(rhs.baseAllocator_)
{ }
template <typename U>
StdAllocator(const StdAllocator<U, BaseAllocator>& rhs) RAPIDJSON_NOEXCEPT
: allocator_type(rhs),
baseAllocator_(rhs.baseAllocator_)
{
}
/* implicit */
StdAllocator(const BaseAllocator& baseAllocator) RAPIDJSON_NOEXCEPT :
allocator_type(),
baseAllocator_(baseAllocator)
{ }
StdAllocator(const BaseAllocator& baseAllocator) RAPIDJSON_NOEXCEPT
: allocator_type(),
baseAllocator_(baseAllocator)
{
}
~StdAllocator() RAPIDJSON_NOEXCEPT
{ }
~StdAllocator() RAPIDJSON_NOEXCEPT {}
template<typename U>
struct rebind {
template <typename U>
struct rebind
{
typedef StdAllocator<U, BaseAllocator> other;
};
typedef typename allocator_type::value_type value_type;
private:
private:
template <typename, typename>
friend class StdAllocator; // access to StdAllocator<!T>.*

View File

@@ -24,33 +24,39 @@ RAPIDJSON_DIAG_OFF(effc++)
#if defined(_MSC_VER) && _MSC_VER <= 1800
RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(4702) // unreachable code
RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated
RAPIDJSON_DIAG_OFF(4702) // unreachable code
RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated
#endif
RAPIDJSON_NAMESPACE_BEGIN
//! Cursor stream wrapper for counting line and column number if error exists.
/*!
\tparam InputStream Any stream that implements Stream Concept
*/
template <typename InputStream, typename Encoding = UTF8<> >
class CursorStreamWrapper : public GenericStreamWrapper<InputStream, Encoding> {
public:
template <typename InputStream, typename Encoding = UTF8<>>
class CursorStreamWrapper : public GenericStreamWrapper<InputStream, Encoding>
{
public:
typedef typename Encoding::Ch Ch;
CursorStreamWrapper(InputStream& is):
GenericStreamWrapper<InputStream, Encoding>(is), line_(1), col_(0) {}
CursorStreamWrapper(InputStream& is)
: GenericStreamWrapper<InputStream, Encoding>(is), line_(1), col_(0)
{
}
// counting line and column number
Ch Take() {
Ch Take()
{
Ch ch = this->is_.Take();
if(ch == '\n') {
line_ ++;
if(ch == '\n')
{
line_++;
col_ = 0;
} else {
col_ ++;
}
else
{
col_++;
}
return ch;
}
@@ -60,9 +66,9 @@ public:
//! Get the error column number, if error exists.
size_t GetColumn() const { return col_; }
private:
size_t line_; //!< Current Line
size_t col_; //!< Current Column
private:
size_t line_; //!< Current Line
size_t col_; //!< Current Column
};
#if defined(_MSC_VER) && _MSC_VER <= 1800

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_ENCODEDSTREAM_H_
@@ -32,30 +32,43 @@ RAPIDJSON_NAMESPACE_BEGIN
//! Input byte stream wrapper with a statically bound encoding.
/*!
\tparam Encoding The interpretation of encoding of the stream. Either UTF8, UTF16LE, UTF16BE, UTF32LE, UTF32BE.
\tparam InputByteStream Type of input byte stream. For example, FileReadStream.
\tparam Encoding The interpretation of encoding of the stream. Either UTF8, UTF16LE, UTF16BE,
UTF32LE, UTF32BE. \tparam InputByteStream Type of input byte stream. For example, FileReadStream.
*/
template <typename Encoding, typename InputByteStream>
class EncodedInputStream {
class EncodedInputStream
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
public:
public:
typedef typename Encoding::Ch Ch;
EncodedInputStream(InputByteStream& is) : is_(is) {
current_ = Encoding::TakeBOM(is_);
}
EncodedInputStream(InputByteStream& is) : is_(is) { current_ = Encoding::TakeBOM(is_); }
Ch Peek() const { return current_; }
Ch Take() { Ch c = current_; current_ = Encoding::Take(is_); return c; }
Ch Take()
{
Ch c = current_;
current_ = Encoding::Take(is_);
return c;
}
size_t Tell() const { return is_.Tell(); }
// Not implemented
void Put(Ch) { RAPIDJSON_ASSERT(false); }
void Flush() { RAPIDJSON_ASSERT(false); }
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
void Flush() { RAPIDJSON_ASSERT(false); }
Ch* PutBegin()
{
RAPIDJSON_ASSERT(false);
return 0;
}
size_t PutEnd(Ch*)
{
RAPIDJSON_ASSERT(false);
return 0;
}
private:
private:
EncodedInputStream(const EncodedInputStream&);
EncodedInputStream& operator=(const EncodedInputStream&);
@@ -65,14 +78,19 @@ private:
//! Specialized for UTF8 MemoryStream.
template <>
class EncodedInputStream<UTF8<>, MemoryStream> {
public:
class EncodedInputStream<UTF8<>, MemoryStream>
{
public:
typedef UTF8<>::Ch Ch;
EncodedInputStream(MemoryStream& is) : is_(is) {
if (static_cast<unsigned char>(is_.Peek()) == 0xEFu) is_.Take();
if (static_cast<unsigned char>(is_.Peek()) == 0xBBu) is_.Take();
if (static_cast<unsigned char>(is_.Peek()) == 0xBFu) is_.Take();
EncodedInputStream(MemoryStream& is) : is_(is)
{
if(static_cast<unsigned char>(is_.Peek()) == 0xEFu)
is_.Take();
if(static_cast<unsigned char>(is_.Peek()) == 0xBBu)
is_.Take();
if(static_cast<unsigned char>(is_.Peek()) == 0xBFu)
is_.Take();
}
Ch Peek() const { return is_.Peek(); }
Ch Take() { return is_.Take(); }
@@ -80,51 +98,76 @@ public:
// Not implemented
void Put(Ch) {}
void Flush() {}
void Flush() {}
Ch* PutBegin() { return 0; }
size_t PutEnd(Ch*) { return 0; }
MemoryStream& is_;
private:
private:
EncodedInputStream(const EncodedInputStream&);
EncodedInputStream& operator=(const EncodedInputStream&);
};
//! Output byte stream wrapper with statically bound encoding.
/*!
\tparam Encoding The interpretation of encoding of the stream. Either UTF8, UTF16LE, UTF16BE, UTF32LE, UTF32BE.
\tparam OutputByteStream Type of input byte stream. For example, FileWriteStream.
\tparam Encoding The interpretation of encoding of the stream. Either UTF8, UTF16LE, UTF16BE,
UTF32LE, UTF32BE. \tparam OutputByteStream Type of input byte stream. For example,
FileWriteStream.
*/
template <typename Encoding, typename OutputByteStream>
class EncodedOutputStream {
class EncodedOutputStream
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
public:
public:
typedef typename Encoding::Ch Ch;
EncodedOutputStream(OutputByteStream& os, bool putBOM = true) : os_(os) {
if (putBOM)
EncodedOutputStream(OutputByteStream& os, bool putBOM = true) : os_(os)
{
if(putBOM)
Encoding::PutBOM(os_);
}
void Put(Ch c) { Encoding::Put(os_, c); }
void Put(Ch c) { Encoding::Put(os_, c); }
void Flush() { os_.Flush(); }
// Not implemented
Ch Peek() const { RAPIDJSON_ASSERT(false); return 0;}
Ch Take() { RAPIDJSON_ASSERT(false); return 0;}
size_t Tell() const { RAPIDJSON_ASSERT(false); return 0; }
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
Ch Peek() const
{
RAPIDJSON_ASSERT(false);
return 0;
}
Ch Take()
{
RAPIDJSON_ASSERT(false);
return 0;
}
size_t Tell() const
{
RAPIDJSON_ASSERT(false);
return 0;
}
Ch* PutBegin()
{
RAPIDJSON_ASSERT(false);
return 0;
}
size_t PutEnd(Ch*)
{
RAPIDJSON_ASSERT(false);
return 0;
}
private:
private:
EncodedOutputStream(const EncodedOutputStream&);
EncodedOutputStream& operator=(const EncodedOutputStream&);
OutputByteStream& os_;
};
#define RAPIDJSON_ENCODINGS_FUNC(x) UTF8<Ch>::x, UTF16LE<Ch>::x, UTF16BE<Ch>::x, UTF32LE<Ch>::x, UTF32BE<Ch>::x
#define RAPIDJSON_ENCODINGS_FUNC(x) \
UTF8<Ch>::x, UTF16LE<Ch>::x, UTF16BE<Ch>::x, UTF32LE<Ch>::x, UTF32BE<Ch>::x
//! Input stream wrapper with dynamically bound encoding and automatic encoding detection.
/*!
@@ -132,9 +175,11 @@ private:
\tparam InputByteStream type of input byte stream to be wrapped.
*/
template <typename CharType, typename InputByteStream>
class AutoUTFInputStream {
class AutoUTFInputStream
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
public:
public:
typedef CharType Ch;
//! Constructor.
@@ -142,33 +187,49 @@ public:
\param is input stream to be wrapped.
\param type UTF encoding type if it is not detected from the stream.
*/
AutoUTFInputStream(InputByteStream& is, UTFType type = kUTF8) : is_(&is), type_(type), hasBOM_(false) {
RAPIDJSON_ASSERT(type >= kUTF8 && type <= kUTF32BE);
AutoUTFInputStream(InputByteStream& is, UTFType type = kUTF8)
: is_(&is), type_(type), hasBOM_(false)
{
RAPIDJSON_ASSERT(type >= kUTF8 && type <= kUTF32BE);
DetectType();
static const TakeFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Take) };
takeFunc_ = f[type_];
current_ = takeFunc_(*is_);
static const TakeFunc f[] = {RAPIDJSON_ENCODINGS_FUNC(Take)};
takeFunc_ = f[type_];
current_ = takeFunc_(*is_);
}
UTFType GetType() const { return type_; }
bool HasBOM() const { return hasBOM_; }
Ch Peek() const { return current_; }
Ch Take() { Ch c = current_; current_ = takeFunc_(*is_); return c; }
Ch Take()
{
Ch c = current_;
current_ = takeFunc_(*is_);
return c;
}
size_t Tell() const { return is_->Tell(); }
// Not implemented
void Put(Ch) { RAPIDJSON_ASSERT(false); }
void Flush() { RAPIDJSON_ASSERT(false); }
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
void Flush() { RAPIDJSON_ASSERT(false); }
Ch* PutBegin()
{
RAPIDJSON_ASSERT(false);
return 0;
}
size_t PutEnd(Ch*)
{
RAPIDJSON_ASSERT(false);
return 0;
}
private:
private:
AutoUTFInputStream(const AutoUTFInputStream&);
AutoUTFInputStream& operator=(const AutoUTFInputStream&);
// Detect encoding type with BOM or RFC 4627
void DetectType() {
void DetectType()
{
// BOM (Byte Order Mark):
// 00 00 FE FF UTF-32BE
// FF FE 00 00 UTF-32LE
@@ -176,17 +237,52 @@ private:
// FF FE UTF-16LE
// EF BB BF UTF-8
const unsigned char* c = reinterpret_cast<const unsigned char *>(is_->Peek4());
if (!c)
const unsigned char* c = reinterpret_cast<const unsigned char*>(is_->Peek4());
if(!c)
return;
unsigned bom = static_cast<unsigned>(c[0] | (c[1] << 8) | (c[2] << 16) | (c[3] << 24));
hasBOM_ = false;
if (bom == 0xFFFE0000) { type_ = kUTF32BE; hasBOM_ = true; is_->Take(); is_->Take(); is_->Take(); is_->Take(); }
else if (bom == 0x0000FEFF) { type_ = kUTF32LE; hasBOM_ = true; is_->Take(); is_->Take(); is_->Take(); is_->Take(); }
else if ((bom & 0xFFFF) == 0xFFFE) { type_ = kUTF16BE; hasBOM_ = true; is_->Take(); is_->Take(); }
else if ((bom & 0xFFFF) == 0xFEFF) { type_ = kUTF16LE; hasBOM_ = true; is_->Take(); is_->Take(); }
else if ((bom & 0xFFFFFF) == 0xBFBBEF) { type_ = kUTF8; hasBOM_ = true; is_->Take(); is_->Take(); is_->Take(); }
hasBOM_ = false;
if(bom == 0xFFFE0000)
{
type_ = kUTF32BE;
hasBOM_ = true;
is_->Take();
is_->Take();
is_->Take();
is_->Take();
}
else if(bom == 0x0000FEFF)
{
type_ = kUTF32LE;
hasBOM_ = true;
is_->Take();
is_->Take();
is_->Take();
is_->Take();
}
else if((bom & 0xFFFF) == 0xFFFE)
{
type_ = kUTF16BE;
hasBOM_ = true;
is_->Take();
is_->Take();
}
else if((bom & 0xFFFF) == 0xFEFF)
{
type_ = kUTF16LE;
hasBOM_ = true;
is_->Take();
is_->Take();
}
else if((bom & 0xFFFFFF) == 0xBFBBEF)
{
type_ = kUTF8;
hasBOM_ = true;
is_->Take();
is_->Take();
is_->Take();
}
// RFC 4627: Section 3
// "Since the first two characters of a JSON text will always be ASCII
@@ -199,21 +295,26 @@ private:
// xx 00 xx 00 UTF-16LE
// xx xx xx xx UTF-8
if (!hasBOM_) {
if(!hasBOM_)
{
int pattern = (c[0] ? 1 : 0) | (c[1] ? 2 : 0) | (c[2] ? 4 : 0) | (c[3] ? 8 : 0);
switch (pattern) {
switch(pattern)
{
case 0x08: type_ = kUTF32BE; break;
case 0x0A: type_ = kUTF16BE; break;
case 0x01: type_ = kUTF32LE; break;
case 0x05: type_ = kUTF16LE; break;
case 0x0F: type_ = kUTF8; break;
case 0x0F: type_ = kUTF8; break;
default: break; // Use type defined by user.
}
}
// Runtime check whether the size of character type is sufficient. It only perform checks with assertion.
if (type_ == kUTF16LE || type_ == kUTF16BE) RAPIDJSON_ASSERT(sizeof(Ch) >= 2);
if (type_ == kUTF32LE || type_ == kUTF32BE) RAPIDJSON_ASSERT(sizeof(Ch) >= 4);
// Runtime check whether the size of character type is sufficient. It only perform checks
// with assertion.
if(type_ == kUTF16LE || type_ == kUTF16BE)
RAPIDJSON_ASSERT(sizeof(Ch) >= 2);
if(type_ == kUTF32LE || type_ == kUTF32BE)
RAPIDJSON_ASSERT(sizeof(Ch) >= 4);
}
typedef Ch (*TakeFunc)(InputByteStream& is);
@@ -230,9 +331,11 @@ private:
\tparam OutputByteStream type of output byte stream to be wrapped.
*/
template <typename CharType, typename OutputByteStream>
class AutoUTFOutputStream {
class AutoUTFOutputStream
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
public:
public:
typedef CharType Ch;
//! Constructor.
@@ -241,39 +344,64 @@ public:
\param type UTF encoding type.
\param putBOM Whether to write BOM at the beginning of the stream.
*/
AutoUTFOutputStream(OutputByteStream& os, UTFType type, bool putBOM) : os_(&os), type_(type) {
AutoUTFOutputStream(OutputByteStream& os, UTFType type, bool putBOM) : os_(&os), type_(type)
{
RAPIDJSON_ASSERT(type >= kUTF8 && type <= kUTF32BE);
// Runtime check whether the size of character type is sufficient. It only perform checks with assertion.
if (type_ == kUTF16LE || type_ == kUTF16BE) RAPIDJSON_ASSERT(sizeof(Ch) >= 2);
if (type_ == kUTF32LE || type_ == kUTF32BE) RAPIDJSON_ASSERT(sizeof(Ch) >= 4);
// Runtime check whether the size of character type is sufficient. It only perform checks
// with assertion.
if(type_ == kUTF16LE || type_ == kUTF16BE)
RAPIDJSON_ASSERT(sizeof(Ch) >= 2);
if(type_ == kUTF32LE || type_ == kUTF32BE)
RAPIDJSON_ASSERT(sizeof(Ch) >= 4);
static const PutFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Put) };
putFunc_ = f[type_];
static const PutFunc f[] = {RAPIDJSON_ENCODINGS_FUNC(Put)};
putFunc_ = f[type_];
if (putBOM)
if(putBOM)
PutBOM();
}
UTFType GetType() const { return type_; }
void Put(Ch c) { putFunc_(*os_, c); }
void Flush() { os_->Flush(); }
void Flush() { os_->Flush(); }
// Not implemented
Ch Peek() const { RAPIDJSON_ASSERT(false); return 0;}
Ch Take() { RAPIDJSON_ASSERT(false); return 0;}
size_t Tell() const { RAPIDJSON_ASSERT(false); return 0; }
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
Ch Peek() const
{
RAPIDJSON_ASSERT(false);
return 0;
}
Ch Take()
{
RAPIDJSON_ASSERT(false);
return 0;
}
size_t Tell() const
{
RAPIDJSON_ASSERT(false);
return 0;
}
Ch* PutBegin()
{
RAPIDJSON_ASSERT(false);
return 0;
}
size_t PutEnd(Ch*)
{
RAPIDJSON_ASSERT(false);
return 0;
}
private:
private:
AutoUTFOutputStream(const AutoUTFOutputStream&);
AutoUTFOutputStream& operator=(const AutoUTFOutputStream&);
void PutBOM() {
void PutBOM()
{
typedef void (*PutBOMFunc)(OutputByteStream&);
static const PutBOMFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(PutBOM) };
static const PutBOMFunc f[] = {RAPIDJSON_ENCODINGS_FUNC(PutBOM)};
f[type_](*os_);
}

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_ENCODINGS_H_
@@ -20,7 +20,7 @@
#if defined(_MSC_VER) && !defined(__clang__)
RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(4244) // conversion from 'type1' to 'type2', possible loss of data
RAPIDJSON_DIAG_OFF(4702) // unreachable code
RAPIDJSON_DIAG_OFF(4702) // unreachable code
#elif defined(__GNUC__)
RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(effc++)
@@ -37,7 +37,8 @@ RAPIDJSON_NAMESPACE_BEGIN
\code
concept Encoding {
typename Ch; //! Type of character. A "character" is actually a code unit in unicode's definition.
typename Ch; //! Type of character. A "character" is actually a code unit in unicode's
definition.
enum { supportUnicode = 1 }; // or 0 if not supporting unicode
@@ -92,26 +93,34 @@ concept Encoding {
\tparam CharType Code unit for storing 8-bit UTF-8 data. Default is char.
\note implements Encoding concept
*/
template<typename CharType = char>
struct UTF8 {
template <typename CharType = char>
struct UTF8
{
typedef CharType Ch;
enum { supportUnicode = 1 };
enum
{
supportUnicode = 1
};
template<typename OutputStream>
static void Encode(OutputStream& os, unsigned codepoint) {
if (codepoint <= 0x7F)
template <typename OutputStream>
static void Encode(OutputStream& os, unsigned codepoint)
{
if(codepoint <= 0x7F)
os.Put(static_cast<Ch>(codepoint & 0xFF));
else if (codepoint <= 0x7FF) {
else if(codepoint <= 0x7FF)
{
os.Put(static_cast<Ch>(0xC0 | ((codepoint >> 6) & 0xFF)));
os.Put(static_cast<Ch>(0x80 | ((codepoint & 0x3F))));
}
else if (codepoint <= 0xFFFF) {
else if(codepoint <= 0xFFFF)
{
os.Put(static_cast<Ch>(0xE0 | ((codepoint >> 12) & 0xFF)));
os.Put(static_cast<Ch>(0x80 | ((codepoint >> 6) & 0x3F)));
os.Put(static_cast<Ch>(0x80 | (codepoint & 0x3F)));
}
else {
else
{
RAPIDJSON_ASSERT(codepoint <= 0x10FFFF);
os.Put(static_cast<Ch>(0xF0 | ((codepoint >> 18) & 0xFF)));
os.Put(static_cast<Ch>(0x80 | ((codepoint >> 12) & 0x3F)));
@@ -120,20 +129,24 @@ struct UTF8 {
}
}
template<typename OutputStream>
static void EncodeUnsafe(OutputStream& os, unsigned codepoint) {
if (codepoint <= 0x7F)
template <typename OutputStream>
static void EncodeUnsafe(OutputStream& os, unsigned codepoint)
{
if(codepoint <= 0x7F)
PutUnsafe(os, static_cast<Ch>(codepoint & 0xFF));
else if (codepoint <= 0x7FF) {
else if(codepoint <= 0x7FF)
{
PutUnsafe(os, static_cast<Ch>(0xC0 | ((codepoint >> 6) & 0xFF)));
PutUnsafe(os, static_cast<Ch>(0x80 | ((codepoint & 0x3F))));
}
else if (codepoint <= 0xFFFF) {
else if(codepoint <= 0xFFFF)
{
PutUnsafe(os, static_cast<Ch>(0xE0 | ((codepoint >> 12) & 0xFF)));
PutUnsafe(os, static_cast<Ch>(0x80 | ((codepoint >> 6) & 0x3F)));
PutUnsafe(os, static_cast<Ch>(0x80 | (codepoint & 0x3F)));
}
else {
else
{
RAPIDJSON_ASSERT(codepoint <= 0x10FFFF);
PutUnsafe(os, static_cast<Ch>(0xF0 | ((codepoint >> 18) & 0xFF)));
PutUnsafe(os, static_cast<Ch>(0x80 | ((codepoint >> 12) & 0x3F)));
@@ -143,31 +156,66 @@ struct UTF8 {
}
template <typename InputStream>
static bool Decode(InputStream& is, unsigned* codepoint) {
#define RAPIDJSON_COPY() c = is.Take(); *codepoint = (*codepoint << 6) | (static_cast<unsigned char>(c) & 0x3Fu)
static bool Decode(InputStream& is, unsigned* codepoint)
{
#define RAPIDJSON_COPY() \
c = is.Take(); \
*codepoint = (*codepoint << 6) | (static_cast<unsigned char>(c) & 0x3Fu)
#define RAPIDJSON_TRANS(mask) result &= ((GetRange(static_cast<unsigned char>(c)) & mask) != 0)
#define RAPIDJSON_TAIL() RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x70)
#define RAPIDJSON_TAIL() \
RAPIDJSON_COPY(); \
RAPIDJSON_TRANS(0x70)
typename InputStream::Ch c = is.Take();
if (!(c & 0x80)) {
if(!(c & 0x80))
{
*codepoint = static_cast<unsigned char>(c);
return true;
}
unsigned char type = GetRange(static_cast<unsigned char>(c));
if (type >= 32) {
if(type >= 32)
{
*codepoint = 0;
} else {
}
else
{
*codepoint = (0xFFu >> type) & static_cast<unsigned char>(c);
}
bool result = true;
switch (type) {
switch(type)
{
case 2: RAPIDJSON_TAIL(); return result;
case 3: RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
case 4: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x50); RAPIDJSON_TAIL(); return result;
case 5: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x10); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
case 6: RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
case 10: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x20); RAPIDJSON_TAIL(); return result;
case 11: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x60); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
case 3:
RAPIDJSON_TAIL();
RAPIDJSON_TAIL();
return result;
case 4:
RAPIDJSON_COPY();
RAPIDJSON_TRANS(0x50);
RAPIDJSON_TAIL();
return result;
case 5:
RAPIDJSON_COPY();
RAPIDJSON_TRANS(0x10);
RAPIDJSON_TAIL();
RAPIDJSON_TAIL();
return result;
case 6:
RAPIDJSON_TAIL();
RAPIDJSON_TAIL();
RAPIDJSON_TAIL();
return result;
case 10:
RAPIDJSON_COPY();
RAPIDJSON_TRANS(0x20);
RAPIDJSON_TAIL();
return result;
case 11:
RAPIDJSON_COPY();
RAPIDJSON_TRANS(0x60);
RAPIDJSON_TAIL();
RAPIDJSON_TAIL();
return result;
default: return false;
}
#undef RAPIDJSON_COPY
@@ -176,24 +224,55 @@ struct UTF8 {
}
template <typename InputStream, typename OutputStream>
static bool Validate(InputStream& is, OutputStream& os) {
#define RAPIDJSON_COPY() if (c != '\0') os.Put(c = is.Take())
static bool Validate(InputStream& is, OutputStream& os)
{
#define RAPIDJSON_COPY() \
if(c != '\0') \
os.Put(c = is.Take())
#define RAPIDJSON_TRANS(mask) result &= ((GetRange(static_cast<unsigned char>(c)) & mask) != 0)
#define RAPIDJSON_TAIL() RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x70)
#define RAPIDJSON_TAIL() \
RAPIDJSON_COPY(); \
RAPIDJSON_TRANS(0x70)
Ch c = static_cast<Ch>(-1);
RAPIDJSON_COPY();
if (!(c & 0x80))
if(!(c & 0x80))
return true;
bool result = true;
switch (GetRange(static_cast<unsigned char>(c))) {
switch(GetRange(static_cast<unsigned char>(c)))
{
case 2: RAPIDJSON_TAIL(); return result;
case 3: RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
case 4: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x50); RAPIDJSON_TAIL(); return result;
case 5: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x10); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
case 6: RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
case 10: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x20); RAPIDJSON_TAIL(); return result;
case 11: RAPIDJSON_COPY(); RAPIDJSON_TRANS(0x60); RAPIDJSON_TAIL(); RAPIDJSON_TAIL(); return result;
case 3:
RAPIDJSON_TAIL();
RAPIDJSON_TAIL();
return result;
case 4:
RAPIDJSON_COPY();
RAPIDJSON_TRANS(0x50);
RAPIDJSON_TAIL();
return result;
case 5:
RAPIDJSON_COPY();
RAPIDJSON_TRANS(0x10);
RAPIDJSON_TAIL();
RAPIDJSON_TAIL();
return result;
case 6:
RAPIDJSON_TAIL();
RAPIDJSON_TAIL();
RAPIDJSON_TAIL();
return result;
case 10:
RAPIDJSON_COPY();
RAPIDJSON_TRANS(0x20);
RAPIDJSON_TAIL();
return result;
case 11:
RAPIDJSON_COPY();
RAPIDJSON_TRANS(0x60);
RAPIDJSON_TAIL();
RAPIDJSON_TAIL();
return result;
default: return false;
}
#undef RAPIDJSON_COPY
@@ -201,45 +280,62 @@ struct UTF8 {
#undef RAPIDJSON_TAIL
}
static unsigned char GetRange(unsigned char c) {
static unsigned char GetRange(unsigned char c)
{
// Referring to DFA of http://bjoern.hoehrmann.de/utf-8/decoder/dfa/
// With new mapping 1 -> 0x10, 7 -> 0x20, 9 -> 0x40, such that AND operation can test multiple types.
// With new mapping 1 -> 0x10, 7 -> 0x20, 9 -> 0x40, such that AND operation can test
// multiple types.
static const unsigned char type[] = {
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,
0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,0x40,
0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,
0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,0x20,
8,8,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,
10,3,3,3,3,3,3,3,3,3,3,3,3,4,3,3, 11,6,6,6,5,8,8,8,8,8,8,8,8,8,8,8,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
0x10, 0x10, 0x10, 0x10, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,
0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 8, 8, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4,
3, 3, 11, 6, 6, 6, 5, 8, 8, 8, 8, 8, 8, 8,
8, 8, 8, 8,
};
return type[c];
}
template <typename InputByteStream>
static CharType TakeBOM(InputByteStream& is) {
static CharType TakeBOM(InputByteStream& is)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
typename InputByteStream::Ch c = Take(is);
if (static_cast<unsigned char>(c) != 0xEFu) return c;
if(static_cast<unsigned char>(c) != 0xEFu)
return c;
c = is.Take();
if (static_cast<unsigned char>(c) != 0xBBu) return c;
if(static_cast<unsigned char>(c) != 0xBBu)
return c;
c = is.Take();
if (static_cast<unsigned char>(c) != 0xBFu) return c;
if(static_cast<unsigned char>(c) != 0xBFu)
return c;
c = is.Take();
return c;
}
template <typename InputByteStream>
static Ch Take(InputByteStream& is) {
static Ch Take(InputByteStream& is)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
return static_cast<Ch>(is.Take());
}
template <typename OutputByteStream>
static void PutBOM(OutputByteStream& os) {
static void PutBOM(OutputByteStream& os)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
os.Put(static_cast<typename OutputByteStream::Ch>(0xEFu));
os.Put(static_cast<typename OutputByteStream::Ch>(0xBBu));
@@ -247,7 +343,8 @@ struct UTF8 {
}
template <typename OutputByteStream>
static void Put(OutputByteStream& os, Ch c) {
static void Put(OutputByteStream& os, Ch c)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
os.Put(static_cast<typename OutputByteStream::Ch>(c));
}
@@ -259,27 +356,35 @@ struct UTF8 {
//! UTF-16 encoding.
/*! http://en.wikipedia.org/wiki/UTF-16
http://tools.ietf.org/html/rfc2781
\tparam CharType Type for storing 16-bit UTF-16 data. Default is wchar_t. C++11 may use char16_t instead.
\note implements Encoding concept
\tparam CharType Type for storing 16-bit UTF-16 data. Default is wchar_t. C++11 may use char16_t
instead. \note implements Encoding concept
\note For in-memory access, no need to concern endianness. The code units and code points are represented by CPU's endianness.
For streaming, use UTF16LE and UTF16BE, which handle endianness.
\note For in-memory access, no need to concern endianness. The code units and code points are
represented by CPU's endianness. For streaming, use UTF16LE and UTF16BE, which handle endianness.
*/
template<typename CharType = wchar_t>
struct UTF16 {
template <typename CharType = wchar_t>
struct UTF16
{
typedef CharType Ch;
RAPIDJSON_STATIC_ASSERT(sizeof(Ch) >= 2);
enum { supportUnicode = 1 };
enum
{
supportUnicode = 1
};
template<typename OutputStream>
static void Encode(OutputStream& os, unsigned codepoint) {
template <typename OutputStream>
static void Encode(OutputStream& os, unsigned codepoint)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 2);
if (codepoint <= 0xFFFF) {
RAPIDJSON_ASSERT(codepoint < 0xD800 || codepoint > 0xDFFF); // Code point itself cannot be surrogate pair
if(codepoint <= 0xFFFF)
{
RAPIDJSON_ASSERT(codepoint < 0xD800 ||
codepoint > 0xDFFF); // Code point itself cannot be surrogate pair
os.Put(static_cast<typename OutputStream::Ch>(codepoint));
}
else {
else
{
RAPIDJSON_ASSERT(codepoint <= 0x10FFFF);
unsigned v = codepoint - 0x10000;
os.Put(static_cast<typename OutputStream::Ch>((v >> 10) | 0xD800));
@@ -287,15 +392,18 @@ struct UTF16 {
}
}
template<typename OutputStream>
static void EncodeUnsafe(OutputStream& os, unsigned codepoint) {
template <typename OutputStream>
static void EncodeUnsafe(OutputStream& os, unsigned codepoint)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 2);
if (codepoint <= 0xFFFF) {
RAPIDJSON_ASSERT(codepoint < 0xD800 || codepoint > 0xDFFF); // Code point itself cannot be surrogate pair
if(codepoint <= 0xFFFF)
{
RAPIDJSON_ASSERT(codepoint < 0xD800 ||
codepoint > 0xDFFF); // Code point itself cannot be surrogate pair
PutUnsafe(os, static_cast<typename OutputStream::Ch>(codepoint));
}
else {
else
{
RAPIDJSON_ASSERT(codepoint <= 0x10FFFF);
unsigned v = codepoint - 0x10000;
PutUnsafe(os, static_cast<typename OutputStream::Ch>((v >> 10) | 0xD800));
@@ -304,16 +412,19 @@ struct UTF16 {
}
template <typename InputStream>
static bool Decode(InputStream& is, unsigned* codepoint) {
static bool Decode(InputStream& is, unsigned* codepoint)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputStream::Ch) >= 2);
typename InputStream::Ch c = is.Take();
if (c < 0xD800 || c > 0xDFFF) {
if(c < 0xD800 || c > 0xDFFF)
{
*codepoint = static_cast<unsigned>(c);
return true;
}
else if (c <= 0xDBFF) {
else if(c <= 0xDBFF)
{
*codepoint = (static_cast<unsigned>(c) & 0x3FF) << 10;
c = is.Take();
c = is.Take();
*codepoint |= (static_cast<unsigned>(c) & 0x3FF);
*codepoint += 0x10000;
return c >= 0xDC00 && c <= 0xDFFF;
@@ -322,14 +433,16 @@ struct UTF16 {
}
template <typename InputStream, typename OutputStream>
static bool Validate(InputStream& is, OutputStream& os) {
static bool Validate(InputStream& is, OutputStream& os)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputStream::Ch) >= 2);
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 2);
typename InputStream::Ch c;
os.Put(static_cast<typename OutputStream::Ch>(c = is.Take()));
if (c < 0xD800 || c > 0xDFFF)
if(c < 0xD800 || c > 0xDFFF)
return true;
else if (c <= 0xDBFF) {
else if(c <= 0xDBFF)
{
os.Put(c = is.Take());
return c >= 0xDC00 && c <= 0xDFFF;
}
@@ -338,17 +451,20 @@ struct UTF16 {
};
//! UTF-16 little endian encoding.
template<typename CharType = wchar_t>
struct UTF16LE : UTF16<CharType> {
template <typename CharType = wchar_t>
struct UTF16LE : UTF16<CharType>
{
template <typename InputByteStream>
static CharType TakeBOM(InputByteStream& is) {
static CharType TakeBOM(InputByteStream& is)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
CharType c = Take(is);
return static_cast<uint16_t>(c) == 0xFEFFu ? Take(is) : c;
}
template <typename InputByteStream>
static CharType Take(InputByteStream& is) {
static CharType Take(InputByteStream& is)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
unsigned c = static_cast<uint8_t>(is.Take());
c |= static_cast<unsigned>(static_cast<uint8_t>(is.Take())) << 8;
@@ -356,14 +472,16 @@ struct UTF16LE : UTF16<CharType> {
}
template <typename OutputByteStream>
static void PutBOM(OutputByteStream& os) {
static void PutBOM(OutputByteStream& os)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
os.Put(static_cast<typename OutputByteStream::Ch>(0xFFu));
os.Put(static_cast<typename OutputByteStream::Ch>(0xFEu));
}
template <typename OutputByteStream>
static void Put(OutputByteStream& os, CharType c) {
static void Put(OutputByteStream& os, CharType c)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
os.Put(static_cast<typename OutputByteStream::Ch>(static_cast<unsigned>(c) & 0xFFu));
os.Put(static_cast<typename OutputByteStream::Ch>((static_cast<unsigned>(c) >> 8) & 0xFFu));
@@ -371,17 +489,20 @@ struct UTF16LE : UTF16<CharType> {
};
//! UTF-16 big endian encoding.
template<typename CharType = wchar_t>
struct UTF16BE : UTF16<CharType> {
template <typename CharType = wchar_t>
struct UTF16BE : UTF16<CharType>
{
template <typename InputByteStream>
static CharType TakeBOM(InputByteStream& is) {
static CharType TakeBOM(InputByteStream& is)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
CharType c = Take(is);
return static_cast<uint16_t>(c) == 0xFEFFu ? Take(is) : c;
}
template <typename InputByteStream>
static CharType Take(InputByteStream& is) {
static CharType Take(InputByteStream& is)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
unsigned c = static_cast<unsigned>(static_cast<uint8_t>(is.Take())) << 8;
c |= static_cast<unsigned>(static_cast<uint8_t>(is.Take()));
@@ -389,14 +510,16 @@ struct UTF16BE : UTF16<CharType> {
}
template <typename OutputByteStream>
static void PutBOM(OutputByteStream& os) {
static void PutBOM(OutputByteStream& os)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
os.Put(static_cast<typename OutputByteStream::Ch>(0xFEu));
os.Put(static_cast<typename OutputByteStream::Ch>(0xFFu));
}
template <typename OutputByteStream>
static void Put(OutputByteStream& os, CharType c) {
static void Put(OutputByteStream& os, CharType c)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
os.Put(static_cast<typename OutputByteStream::Ch>((static_cast<unsigned>(c) >> 8) & 0xFFu));
os.Put(static_cast<typename OutputByteStream::Ch>(static_cast<unsigned>(c) & 0xFFu));
@@ -406,45 +529,53 @@ struct UTF16BE : UTF16<CharType> {
///////////////////////////////////////////////////////////////////////////////
// UTF32
//! UTF-32 encoding.
//! UTF-32 encoding.
/*! http://en.wikipedia.org/wiki/UTF-32
\tparam CharType Type for storing 32-bit UTF-32 data. Default is unsigned. C++11 may use char32_t instead.
\note implements Encoding concept
\tparam CharType Type for storing 32-bit UTF-32 data. Default is unsigned. C++11 may use
char32_t instead. \note implements Encoding concept
\note For in-memory access, no need to concern endianness. The code units and code points are represented by CPU's endianness.
For streaming, use UTF32LE and UTF32BE, which handle endianness.
\note For in-memory access, no need to concern endianness. The code units and code points are
represented by CPU's endianness. For streaming, use UTF32LE and UTF32BE, which handle endianness.
*/
template<typename CharType = unsigned>
struct UTF32 {
template <typename CharType = unsigned>
struct UTF32
{
typedef CharType Ch;
RAPIDJSON_STATIC_ASSERT(sizeof(Ch) >= 4);
enum { supportUnicode = 1 };
enum
{
supportUnicode = 1
};
template<typename OutputStream>
static void Encode(OutputStream& os, unsigned codepoint) {
template <typename OutputStream>
static void Encode(OutputStream& os, unsigned codepoint)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 4);
RAPIDJSON_ASSERT(codepoint <= 0x10FFFF);
os.Put(codepoint);
}
template<typename OutputStream>
static void EncodeUnsafe(OutputStream& os, unsigned codepoint) {
template <typename OutputStream>
static void EncodeUnsafe(OutputStream& os, unsigned codepoint)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputStream::Ch) >= 4);
RAPIDJSON_ASSERT(codepoint <= 0x10FFFF);
PutUnsafe(os, codepoint);
}
template <typename InputStream>
static bool Decode(InputStream& is, unsigned* codepoint) {
static bool Decode(InputStream& is, unsigned* codepoint)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputStream::Ch) >= 4);
Ch c = is.Take();
Ch c = is.Take();
*codepoint = c;
return c <= 0x10FFFF;
}
template <typename InputStream, typename OutputStream>
static bool Validate(InputStream& is, OutputStream& os) {
static bool Validate(InputStream& is, OutputStream& os)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputStream::Ch) >= 4);
Ch c;
os.Put(c = is.Take());
@@ -453,17 +584,20 @@ struct UTF32 {
};
//! UTF-32 little endian enocoding.
template<typename CharType = unsigned>
struct UTF32LE : UTF32<CharType> {
template <typename CharType = unsigned>
struct UTF32LE : UTF32<CharType>
{
template <typename InputByteStream>
static CharType TakeBOM(InputByteStream& is) {
static CharType TakeBOM(InputByteStream& is)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
CharType c = Take(is);
return static_cast<uint32_t>(c) == 0x0000FEFFu ? Take(is) : c;
}
template <typename InputByteStream>
static CharType Take(InputByteStream& is) {
static CharType Take(InputByteStream& is)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
unsigned c = static_cast<uint8_t>(is.Take());
c |= static_cast<unsigned>(static_cast<uint8_t>(is.Take())) << 8;
@@ -473,7 +607,8 @@ struct UTF32LE : UTF32<CharType> {
}
template <typename OutputByteStream>
static void PutBOM(OutputByteStream& os) {
static void PutBOM(OutputByteStream& os)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
os.Put(static_cast<typename OutputByteStream::Ch>(0xFFu));
os.Put(static_cast<typename OutputByteStream::Ch>(0xFEu));
@@ -482,7 +617,8 @@ struct UTF32LE : UTF32<CharType> {
}
template <typename OutputByteStream>
static void Put(OutputByteStream& os, CharType c) {
static void Put(OutputByteStream& os, CharType c)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
os.Put(static_cast<typename OutputByteStream::Ch>(c & 0xFFu));
os.Put(static_cast<typename OutputByteStream::Ch>((c >> 8) & 0xFFu));
@@ -492,17 +628,20 @@ struct UTF32LE : UTF32<CharType> {
};
//! UTF-32 big endian encoding.
template<typename CharType = unsigned>
struct UTF32BE : UTF32<CharType> {
template <typename CharType = unsigned>
struct UTF32BE : UTF32<CharType>
{
template <typename InputByteStream>
static CharType TakeBOM(InputByteStream& is) {
static CharType TakeBOM(InputByteStream& is)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
CharType c = Take(is);
return static_cast<uint32_t>(c) == 0x0000FEFFu ? Take(is) : c;
return static_cast<uint32_t>(c) == 0x0000FEFFu ? Take(is) : c;
}
template <typename InputByteStream>
static CharType Take(InputByteStream& is) {
static CharType Take(InputByteStream& is)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
unsigned c = static_cast<unsigned>(static_cast<uint8_t>(is.Take())) << 24;
c |= static_cast<unsigned>(static_cast<uint8_t>(is.Take())) << 16;
@@ -512,7 +651,8 @@ struct UTF32BE : UTF32<CharType> {
}
template <typename OutputByteStream>
static void PutBOM(OutputByteStream& os) {
static void PutBOM(OutputByteStream& os)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
os.Put(static_cast<typename OutputByteStream::Ch>(0x00u));
os.Put(static_cast<typename OutputByteStream::Ch>(0x00u));
@@ -521,7 +661,8 @@ struct UTF32BE : UTF32<CharType> {
}
template <typename OutputByteStream>
static void Put(OutputByteStream& os, CharType c) {
static void Put(OutputByteStream& os, CharType c)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
os.Put(static_cast<typename OutputByteStream::Ch>((c >> 24) & 0xFFu));
os.Put(static_cast<typename OutputByteStream::Ch>((c >> 16) & 0xFFu));
@@ -538,59 +679,71 @@ struct UTF32BE : UTF32<CharType> {
\tparam CharType Code unit for storing 7-bit ASCII data. Default is char.
\note implements Encoding concept
*/
template<typename CharType = char>
struct ASCII {
template <typename CharType = char>
struct ASCII
{
typedef CharType Ch;
enum { supportUnicode = 0 };
enum
{
supportUnicode = 0
};
template<typename OutputStream>
static void Encode(OutputStream& os, unsigned codepoint) {
template <typename OutputStream>
static void Encode(OutputStream& os, unsigned codepoint)
{
RAPIDJSON_ASSERT(codepoint <= 0x7F);
os.Put(static_cast<Ch>(codepoint & 0xFF));
}
template<typename OutputStream>
static void EncodeUnsafe(OutputStream& os, unsigned codepoint) {
template <typename OutputStream>
static void EncodeUnsafe(OutputStream& os, unsigned codepoint)
{
RAPIDJSON_ASSERT(codepoint <= 0x7F);
PutUnsafe(os, static_cast<Ch>(codepoint & 0xFF));
}
template <typename InputStream>
static bool Decode(InputStream& is, unsigned* codepoint) {
uint8_t c = static_cast<uint8_t>(is.Take());
static bool Decode(InputStream& is, unsigned* codepoint)
{
uint8_t c = static_cast<uint8_t>(is.Take());
*codepoint = c;
return c <= 0X7F;
}
template <typename InputStream, typename OutputStream>
static bool Validate(InputStream& is, OutputStream& os) {
static bool Validate(InputStream& is, OutputStream& os)
{
uint8_t c = static_cast<uint8_t>(is.Take());
os.Put(static_cast<typename OutputStream::Ch>(c));
return c <= 0x7F;
}
template <typename InputByteStream>
static CharType TakeBOM(InputByteStream& is) {
static CharType TakeBOM(InputByteStream& is)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
uint8_t c = static_cast<uint8_t>(Take(is));
return static_cast<Ch>(c);
}
template <typename InputByteStream>
static Ch Take(InputByteStream& is) {
static Ch Take(InputByteStream& is)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename InputByteStream::Ch) == 1);
return static_cast<Ch>(is.Take());
}
template <typename OutputByteStream>
static void PutBOM(OutputByteStream& os) {
static void PutBOM(OutputByteStream& os)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
(void)os;
}
template <typename OutputByteStream>
static void Put(OutputByteStream& os, Ch c) {
static void Put(OutputByteStream& os, Ch c)
{
RAPIDJSON_STATIC_ASSERT(sizeof(typename OutputByteStream::Ch) == 1);
os.Put(static_cast<typename OutputByteStream::Ch>(c));
}
@@ -600,50 +753,61 @@ struct ASCII {
// AutoUTF
//! Runtime-specified UTF encoding type of a stream.
enum UTFType {
kUTF8 = 0, //!< UTF-8.
kUTF16LE = 1, //!< UTF-16 little endian.
kUTF16BE = 2, //!< UTF-16 big endian.
kUTF32LE = 3, //!< UTF-32 little endian.
kUTF32BE = 4 //!< UTF-32 big endian.
enum UTFType
{
kUTF8 = 0, //!< UTF-8.
kUTF16LE = 1, //!< UTF-16 little endian.
kUTF16BE = 2, //!< UTF-16 big endian.
kUTF32LE = 3, //!< UTF-32 little endian.
kUTF32BE = 4 //!< UTF-32 big endian.
};
//! Dynamically select encoding according to stream's runtime-specified UTF encoding type.
/*! \note This class can be used with AutoUTFInputtStream and AutoUTFOutputStream, which provides GetType().
*/
template<typename CharType>
struct AutoUTF {
/*! \note This class can be used with AutoUTFInputtStream and AutoUTFOutputStream, which provides
* GetType().
*/
template <typename CharType>
struct AutoUTF
{
typedef CharType Ch;
enum { supportUnicode = 1 };
enum
{
supportUnicode = 1
};
#define RAPIDJSON_ENCODINGS_FUNC(x) UTF8<Ch>::x, UTF16LE<Ch>::x, UTF16BE<Ch>::x, UTF32LE<Ch>::x, UTF32BE<Ch>::x
#define RAPIDJSON_ENCODINGS_FUNC(x) \
UTF8<Ch>::x, UTF16LE<Ch>::x, UTF16BE<Ch>::x, UTF32LE<Ch>::x, UTF32BE<Ch>::x
template<typename OutputStream>
static RAPIDJSON_FORCEINLINE void Encode(OutputStream& os, unsigned codepoint) {
template <typename OutputStream>
static RAPIDJSON_FORCEINLINE void Encode(OutputStream& os, unsigned codepoint)
{
typedef void (*EncodeFunc)(OutputStream&, unsigned);
static const EncodeFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Encode) };
static const EncodeFunc f[] = {RAPIDJSON_ENCODINGS_FUNC(Encode)};
(*f[os.GetType()])(os, codepoint);
}
template<typename OutputStream>
static RAPIDJSON_FORCEINLINE void EncodeUnsafe(OutputStream& os, unsigned codepoint) {
template <typename OutputStream>
static RAPIDJSON_FORCEINLINE void EncodeUnsafe(OutputStream& os, unsigned codepoint)
{
typedef void (*EncodeFunc)(OutputStream&, unsigned);
static const EncodeFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(EncodeUnsafe) };
static const EncodeFunc f[] = {RAPIDJSON_ENCODINGS_FUNC(EncodeUnsafe)};
(*f[os.GetType()])(os, codepoint);
}
template <typename InputStream>
static RAPIDJSON_FORCEINLINE bool Decode(InputStream& is, unsigned* codepoint) {
static RAPIDJSON_FORCEINLINE bool Decode(InputStream& is, unsigned* codepoint)
{
typedef bool (*DecodeFunc)(InputStream&, unsigned*);
static const DecodeFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Decode) };
static const DecodeFunc f[] = {RAPIDJSON_ENCODINGS_FUNC(Decode)};
return (*f[is.GetType()])(is, codepoint);
}
template <typename InputStream, typename OutputStream>
static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os) {
static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os)
{
typedef bool (*ValidateFunc)(InputStream&, OutputStream&);
static const ValidateFunc f[] = { RAPIDJSON_ENCODINGS_FUNC(Validate) };
static const ValidateFunc f[] = {RAPIDJSON_ENCODINGS_FUNC(Validate)};
return (*f[is.GetType()])(is, os);
}
@@ -654,56 +818,67 @@ struct AutoUTF {
// Transcoder
//! Encoding conversion.
template<typename SourceEncoding, typename TargetEncoding>
struct Transcoder {
//! Take one Unicode codepoint from source encoding, convert it to target encoding and put it to the output stream.
template<typename InputStream, typename OutputStream>
static RAPIDJSON_FORCEINLINE bool Transcode(InputStream& is, OutputStream& os) {
template <typename SourceEncoding, typename TargetEncoding>
struct Transcoder
{
//! Take one Unicode codepoint from source encoding, convert it to target encoding and put it to
//! the output stream.
template <typename InputStream, typename OutputStream>
static RAPIDJSON_FORCEINLINE bool Transcode(InputStream& is, OutputStream& os)
{
unsigned codepoint;
if (!SourceEncoding::Decode(is, &codepoint))
if(!SourceEncoding::Decode(is, &codepoint))
return false;
TargetEncoding::Encode(os, codepoint);
return true;
}
template<typename InputStream, typename OutputStream>
static RAPIDJSON_FORCEINLINE bool TranscodeUnsafe(InputStream& is, OutputStream& os) {
template <typename InputStream, typename OutputStream>
static RAPIDJSON_FORCEINLINE bool TranscodeUnsafe(InputStream& is, OutputStream& os)
{
unsigned codepoint;
if (!SourceEncoding::Decode(is, &codepoint))
if(!SourceEncoding::Decode(is, &codepoint))
return false;
TargetEncoding::EncodeUnsafe(os, codepoint);
return true;
}
//! Validate one Unicode codepoint from an encoded stream.
template<typename InputStream, typename OutputStream>
static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os) {
return Transcode(is, os); // Since source/target encoding is different, must transcode.
template <typename InputStream, typename OutputStream>
static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os)
{
return Transcode(is, os); // Since source/target encoding is different, must transcode.
}
};
// Forward declaration.
template<typename Stream>
template <typename Stream>
inline void PutUnsafe(Stream& stream, typename Stream::Ch c);
//! Specialization of Transcoder with same source and target encoding.
template<typename Encoding>
struct Transcoder<Encoding, Encoding> {
template<typename InputStream, typename OutputStream>
static RAPIDJSON_FORCEINLINE bool Transcode(InputStream& is, OutputStream& os) {
os.Put(is.Take()); // Just copy one code unit. This semantic is different from primary template class.
template <typename Encoding>
struct Transcoder<Encoding, Encoding>
{
template <typename InputStream, typename OutputStream>
static RAPIDJSON_FORCEINLINE bool Transcode(InputStream& is, OutputStream& os)
{
os.Put(is.Take()); // Just copy one code unit. This semantic is different from primary
// template class.
return true;
}
template<typename InputStream, typename OutputStream>
static RAPIDJSON_FORCEINLINE bool TranscodeUnsafe(InputStream& is, OutputStream& os) {
PutUnsafe(os, is.Take()); // Just copy one code unit. This semantic is different from primary template class.
template <typename InputStream, typename OutputStream>
static RAPIDJSON_FORCEINLINE bool TranscodeUnsafe(InputStream& is, OutputStream& os)
{
PutUnsafe(os, is.Take()); // Just copy one code unit. This semantic is different from
// primary template class.
return true;
}
template<typename InputStream, typename OutputStream>
static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os) {
return Encoding::Validate(is, os); // source/target encoding are the same
template <typename InputStream, typename OutputStream>
static RAPIDJSON_FORCEINLINE bool Validate(InputStream& is, OutputStream& os)
{
return Encoding::Validate(is, os); // source/target encoding are the same
}
};

View File

@@ -19,8 +19,8 @@
#ifdef __clang__
RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(switch-enum)
RAPIDJSON_DIAG_OFF(covered-switch-default)
RAPIDJSON_DIAG_OFF(switch - enum)
RAPIDJSON_DIAG_OFF(covered - switch - default)
#endif
RAPIDJSON_NAMESPACE_BEGIN
@@ -33,35 +33,51 @@ RAPIDJSON_NAMESPACE_BEGIN
\note User can make a copy of this function for localization.
Using switch-case is safer for future modification of error codes.
*/
inline const RAPIDJSON_ERROR_CHARTYPE* GetParseError_En(ParseErrorCode parseErrorCode) {
switch (parseErrorCode) {
case kParseErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
inline const RAPIDJSON_ERROR_CHARTYPE* GetParseError_En(ParseErrorCode parseErrorCode)
{
switch(parseErrorCode)
{
case kParseErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
case kParseErrorDocumentEmpty: return RAPIDJSON_ERROR_STRING("The document is empty.");
case kParseErrorDocumentRootNotSingular: return RAPIDJSON_ERROR_STRING("The document root must not be followed by other values.");
case kParseErrorDocumentEmpty: return RAPIDJSON_ERROR_STRING("The document is empty.");
case kParseErrorDocumentRootNotSingular:
return RAPIDJSON_ERROR_STRING("The document root must not be followed by other values.");
case kParseErrorValueInvalid: return RAPIDJSON_ERROR_STRING("Invalid value.");
case kParseErrorValueInvalid: return RAPIDJSON_ERROR_STRING("Invalid value.");
case kParseErrorObjectMissName: return RAPIDJSON_ERROR_STRING("Missing a name for object member.");
case kParseErrorObjectMissColon: return RAPIDJSON_ERROR_STRING("Missing a colon after a name of object member.");
case kParseErrorObjectMissCommaOrCurlyBracket: return RAPIDJSON_ERROR_STRING("Missing a comma or '}' after an object member.");
case kParseErrorObjectMissName:
return RAPIDJSON_ERROR_STRING("Missing a name for object member.");
case kParseErrorObjectMissColon:
return RAPIDJSON_ERROR_STRING("Missing a colon after a name of object member.");
case kParseErrorObjectMissCommaOrCurlyBracket:
return RAPIDJSON_ERROR_STRING("Missing a comma or '}' after an object member.");
case kParseErrorArrayMissCommaOrSquareBracket: return RAPIDJSON_ERROR_STRING("Missing a comma or ']' after an array element.");
case kParseErrorArrayMissCommaOrSquareBracket:
return RAPIDJSON_ERROR_STRING("Missing a comma or ']' after an array element.");
case kParseErrorStringUnicodeEscapeInvalidHex: return RAPIDJSON_ERROR_STRING("Incorrect hex digit after \\u escape in string.");
case kParseErrorStringUnicodeSurrogateInvalid: return RAPIDJSON_ERROR_STRING("The surrogate pair in string is invalid.");
case kParseErrorStringEscapeInvalid: return RAPIDJSON_ERROR_STRING("Invalid escape character in string.");
case kParseErrorStringMissQuotationMark: return RAPIDJSON_ERROR_STRING("Missing a closing quotation mark in string.");
case kParseErrorStringInvalidEncoding: return RAPIDJSON_ERROR_STRING("Invalid encoding in string.");
case kParseErrorStringUnicodeEscapeInvalidHex:
return RAPIDJSON_ERROR_STRING("Incorrect hex digit after \\u escape in string.");
case kParseErrorStringUnicodeSurrogateInvalid:
return RAPIDJSON_ERROR_STRING("The surrogate pair in string is invalid.");
case kParseErrorStringEscapeInvalid:
return RAPIDJSON_ERROR_STRING("Invalid escape character in string.");
case kParseErrorStringMissQuotationMark:
return RAPIDJSON_ERROR_STRING("Missing a closing quotation mark in string.");
case kParseErrorStringInvalidEncoding:
return RAPIDJSON_ERROR_STRING("Invalid encoding in string.");
case kParseErrorNumberTooBig: return RAPIDJSON_ERROR_STRING("Number too big to be stored in double.");
case kParseErrorNumberMissFraction: return RAPIDJSON_ERROR_STRING("Miss fraction part in number.");
case kParseErrorNumberMissExponent: return RAPIDJSON_ERROR_STRING("Miss exponent in number.");
case kParseErrorNumberTooBig:
return RAPIDJSON_ERROR_STRING("Number too big to be stored in double.");
case kParseErrorNumberMissFraction:
return RAPIDJSON_ERROR_STRING("Miss fraction part in number.");
case kParseErrorNumberMissExponent: return RAPIDJSON_ERROR_STRING("Miss exponent in number.");
case kParseErrorTermination: return RAPIDJSON_ERROR_STRING("Terminate parsing due to Handler error.");
case kParseErrorUnspecificSyntaxError: return RAPIDJSON_ERROR_STRING("Unspecific syntax error.");
case kParseErrorTermination:
return RAPIDJSON_ERROR_STRING("Terminate parsing due to Handler error.");
case kParseErrorUnspecificSyntaxError:
return RAPIDJSON_ERROR_STRING("Unspecific syntax error.");
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
}
}
@@ -73,46 +89,102 @@ inline const RAPIDJSON_ERROR_CHARTYPE* GetParseError_En(ParseErrorCode parseErro
\note User can make a copy of this function for localization.
Using switch-case is safer for future modification of error codes.
*/
inline const RAPIDJSON_ERROR_CHARTYPE* GetValidateError_En(ValidateErrorCode validateErrorCode) {
switch (validateErrorCode) {
case kValidateErrors: return RAPIDJSON_ERROR_STRING("One or more validation errors have occurred");
case kValidateErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
inline const RAPIDJSON_ERROR_CHARTYPE* GetValidateError_En(ValidateErrorCode validateErrorCode)
{
switch(validateErrorCode)
{
case kValidateErrors:
return RAPIDJSON_ERROR_STRING("One or more validation errors have occurred");
case kValidateErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
case kValidateErrorMultipleOf: return RAPIDJSON_ERROR_STRING("Number '%actual' is not a multiple of the 'multipleOf' value '%expected'.");
case kValidateErrorMaximum: return RAPIDJSON_ERROR_STRING("Number '%actual' is greater than the 'maximum' value '%expected'.");
case kValidateErrorExclusiveMaximum: return RAPIDJSON_ERROR_STRING("Number '%actual' is greater than or equal to the 'exclusiveMaximum' value '%expected'.");
case kValidateErrorMinimum: return RAPIDJSON_ERROR_STRING("Number '%actual' is less than the 'minimum' value '%expected'.");
case kValidateErrorExclusiveMinimum: return RAPIDJSON_ERROR_STRING("Number '%actual' is less than or equal to the 'exclusiveMinimum' value '%expected'.");
case kValidateErrorMultipleOf:
return RAPIDJSON_ERROR_STRING(
"Number '%actual' is not a multiple of the 'multipleOf' value '%expected'.");
case kValidateErrorMaximum:
return RAPIDJSON_ERROR_STRING(
"Number '%actual' is greater than the 'maximum' value '%expected'.");
case kValidateErrorExclusiveMaximum:
return RAPIDJSON_ERROR_STRING("Number '%actual' is greater than or equal to the "
"'exclusiveMaximum' value '%expected'.");
case kValidateErrorMinimum:
return RAPIDJSON_ERROR_STRING(
"Number '%actual' is less than the 'minimum' value '%expected'.");
case kValidateErrorExclusiveMinimum:
return RAPIDJSON_ERROR_STRING(
"Number '%actual' is less than or equal to the 'exclusiveMinimum' value '%expected'.");
case kValidateErrorMaxLength: return RAPIDJSON_ERROR_STRING("String '%actual' is longer than the 'maxLength' value '%expected'.");
case kValidateErrorMinLength: return RAPIDJSON_ERROR_STRING("String '%actual' is shorter than the 'minLength' value '%expected'.");
case kValidateErrorPattern: return RAPIDJSON_ERROR_STRING("String '%actual' does not match the 'pattern' regular expression.");
case kValidateErrorMaxLength:
return RAPIDJSON_ERROR_STRING(
"String '%actual' is longer than the 'maxLength' value '%expected'.");
case kValidateErrorMinLength:
return RAPIDJSON_ERROR_STRING(
"String '%actual' is shorter than the 'minLength' value '%expected'.");
case kValidateErrorPattern:
return RAPIDJSON_ERROR_STRING(
"String '%actual' does not match the 'pattern' regular expression.");
case kValidateErrorMaxItems: return RAPIDJSON_ERROR_STRING("Array of length '%actual' is longer than the 'maxItems' value '%expected'.");
case kValidateErrorMinItems: return RAPIDJSON_ERROR_STRING("Array of length '%actual' is shorter than the 'minItems' value '%expected'.");
case kValidateErrorUniqueItems: return RAPIDJSON_ERROR_STRING("Array has duplicate items at indices '%duplicates' but 'uniqueItems' is true.");
case kValidateErrorAdditionalItems: return RAPIDJSON_ERROR_STRING("Array has an additional item at index '%disallowed' that is not allowed by the schema.");
case kValidateErrorMaxItems:
return RAPIDJSON_ERROR_STRING(
"Array of length '%actual' is longer than the 'maxItems' value '%expected'.");
case kValidateErrorMinItems:
return RAPIDJSON_ERROR_STRING(
"Array of length '%actual' is shorter than the 'minItems' value '%expected'.");
case kValidateErrorUniqueItems:
return RAPIDJSON_ERROR_STRING(
"Array has duplicate items at indices '%duplicates' but 'uniqueItems' is true.");
case kValidateErrorAdditionalItems:
return RAPIDJSON_ERROR_STRING("Array has an additional item at index '%disallowed' that is "
"not allowed by the schema.");
case kValidateErrorMaxProperties: return RAPIDJSON_ERROR_STRING("Object has '%actual' members which is more than 'maxProperties' value '%expected'.");
case kValidateErrorMinProperties: return RAPIDJSON_ERROR_STRING("Object has '%actual' members which is less than 'minProperties' value '%expected'.");
case kValidateErrorRequired: return RAPIDJSON_ERROR_STRING("Object is missing the following members required by the schema: '%missing'.");
case kValidateErrorAdditionalProperties: return RAPIDJSON_ERROR_STRING("Object has an additional member '%disallowed' that is not allowed by the schema.");
case kValidateErrorPatternProperties: return RAPIDJSON_ERROR_STRING("Object has 'patternProperties' that are not allowed by the schema.");
case kValidateErrorDependencies: return RAPIDJSON_ERROR_STRING("Object has missing property or schema dependencies, refer to following errors.");
case kValidateErrorMaxProperties:
return RAPIDJSON_ERROR_STRING(
"Object has '%actual' members which is more than 'maxProperties' value '%expected'.");
case kValidateErrorMinProperties:
return RAPIDJSON_ERROR_STRING(
"Object has '%actual' members which is less than 'minProperties' value '%expected'.");
case kValidateErrorRequired:
return RAPIDJSON_ERROR_STRING(
"Object is missing the following members required by the schema: '%missing'.");
case kValidateErrorAdditionalProperties:
return RAPIDJSON_ERROR_STRING(
"Object has an additional member '%disallowed' that is not allowed by the schema.");
case kValidateErrorPatternProperties:
return RAPIDJSON_ERROR_STRING(
"Object has 'patternProperties' that are not allowed by the schema.");
case kValidateErrorDependencies:
return RAPIDJSON_ERROR_STRING(
"Object has missing property or schema dependencies, refer to following errors.");
case kValidateErrorEnum: return RAPIDJSON_ERROR_STRING("Property has a value that is not one of its allowed enumerated values.");
case kValidateErrorType: return RAPIDJSON_ERROR_STRING("Property has a type '%actual' that is not in the following list: '%expected'.");
case kValidateErrorEnum:
return RAPIDJSON_ERROR_STRING(
"Property has a value that is not one of its allowed enumerated values.");
case kValidateErrorType:
return RAPIDJSON_ERROR_STRING(
"Property has a type '%actual' that is not in the following list: '%expected'.");
case kValidateErrorOneOf: return RAPIDJSON_ERROR_STRING("Property did not match any of the sub-schemas specified by 'oneOf', refer to following errors.");
case kValidateErrorOneOfMatch: return RAPIDJSON_ERROR_STRING("Property matched more than one of the sub-schemas specified by 'oneOf', indices '%matches'.");
case kValidateErrorAllOf: return RAPIDJSON_ERROR_STRING("Property did not match all of the sub-schemas specified by 'allOf', refer to following errors.");
case kValidateErrorAnyOf: return RAPIDJSON_ERROR_STRING("Property did not match any of the sub-schemas specified by 'anyOf', refer to following errors.");
case kValidateErrorNot: return RAPIDJSON_ERROR_STRING("Property matched the sub-schema specified by 'not'.");
case kValidateErrorOneOf:
return RAPIDJSON_ERROR_STRING("Property did not match any of the sub-schemas specified by "
"'oneOf', refer to following errors.");
case kValidateErrorOneOfMatch:
return RAPIDJSON_ERROR_STRING("Property matched more than one of the sub-schemas specified "
"by 'oneOf', indices '%matches'.");
case kValidateErrorAllOf:
return RAPIDJSON_ERROR_STRING("Property did not match all of the sub-schemas specified by "
"'allOf', refer to following errors.");
case kValidateErrorAnyOf:
return RAPIDJSON_ERROR_STRING("Property did not match any of the sub-schemas specified by "
"'anyOf', refer to following errors.");
case kValidateErrorNot:
return RAPIDJSON_ERROR_STRING("Property matched the sub-schema specified by 'not'.");
case kValidateErrorReadOnly: return RAPIDJSON_ERROR_STRING("Property is read-only but has been provided when validation is for writing.");
case kValidateErrorWriteOnly: return RAPIDJSON_ERROR_STRING("Property is write-only but has been provided when validation is for reading.");
case kValidateErrorReadOnly:
return RAPIDJSON_ERROR_STRING(
"Property is read-only but has been provided when validation is for writing.");
case kValidateErrorWriteOnly:
return RAPIDJSON_ERROR_STRING(
"Property is write-only but has been provided when validation is for reading.");
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
}
}
@@ -124,27 +196,46 @@ inline const RAPIDJSON_ERROR_CHARTYPE* GetValidateError_En(ValidateErrorCode val
\note User can make a copy of this function for localization.
Using switch-case is safer for future modification of error codes.
*/
inline const RAPIDJSON_ERROR_CHARTYPE* GetSchemaError_En(SchemaErrorCode schemaErrorCode) {
switch (schemaErrorCode) {
case kSchemaErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
inline const RAPIDJSON_ERROR_CHARTYPE* GetSchemaError_En(SchemaErrorCode schemaErrorCode)
{
switch(schemaErrorCode)
{
case kSchemaErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
case kSchemaErrorStartUnknown: return RAPIDJSON_ERROR_STRING("Pointer '%value' to start of schema does not resolve to a location in the document.");
case kSchemaErrorRefPlainName: return RAPIDJSON_ERROR_STRING("$ref fragment '%value' must be a JSON pointer.");
case kSchemaErrorRefInvalid: return RAPIDJSON_ERROR_STRING("$ref must not be an empty string.");
case kSchemaErrorRefPointerInvalid: return RAPIDJSON_ERROR_STRING("$ref fragment '%value' is not a valid JSON pointer at offset '%offset'.");
case kSchemaErrorRefUnknown: return RAPIDJSON_ERROR_STRING("$ref '%value' does not resolve to a location in the target document.");
case kSchemaErrorRefCyclical: return RAPIDJSON_ERROR_STRING("$ref '%value' is cyclical.");
case kSchemaErrorRefNoRemoteProvider: return RAPIDJSON_ERROR_STRING("$ref is remote but there is no remote provider.");
case kSchemaErrorRefNoRemoteSchema: return RAPIDJSON_ERROR_STRING("$ref '%value' is remote but the remote provider did not return a schema.");
case kSchemaErrorRegexInvalid: return RAPIDJSON_ERROR_STRING("Invalid regular expression '%value' in 'pattern' or 'patternProperties'.");
case kSchemaErrorSpecUnknown: return RAPIDJSON_ERROR_STRING("JSON schema draft or OpenAPI version is not recognized.");
case kSchemaErrorSpecUnsupported: return RAPIDJSON_ERROR_STRING("JSON schema draft or OpenAPI version is not supported.");
case kSchemaErrorSpecIllegal: return RAPIDJSON_ERROR_STRING("Both JSON schema draft and OpenAPI version found in document.");
case kSchemaErrorReadOnlyAndWriteOnly: return RAPIDJSON_ERROR_STRING("Property must not be both 'readOnly' and 'writeOnly'.");
case kSchemaErrorStartUnknown:
return RAPIDJSON_ERROR_STRING(
"Pointer '%value' to start of schema does not resolve to a location in the document.");
case kSchemaErrorRefPlainName:
return RAPIDJSON_ERROR_STRING("$ref fragment '%value' must be a JSON pointer.");
case kSchemaErrorRefInvalid: return RAPIDJSON_ERROR_STRING("$ref must not be an empty string.");
case kSchemaErrorRefPointerInvalid:
return RAPIDJSON_ERROR_STRING(
"$ref fragment '%value' is not a valid JSON pointer at offset '%offset'.");
case kSchemaErrorRefUnknown:
return RAPIDJSON_ERROR_STRING(
"$ref '%value' does not resolve to a location in the target document.");
case kSchemaErrorRefCyclical: return RAPIDJSON_ERROR_STRING("$ref '%value' is cyclical.");
case kSchemaErrorRefNoRemoteProvider:
return RAPIDJSON_ERROR_STRING("$ref is remote but there is no remote provider.");
case kSchemaErrorRefNoRemoteSchema:
return RAPIDJSON_ERROR_STRING(
"$ref '%value' is remote but the remote provider did not return a schema.");
case kSchemaErrorRegexInvalid:
return RAPIDJSON_ERROR_STRING(
"Invalid regular expression '%value' in 'pattern' or 'patternProperties'.");
case kSchemaErrorSpecUnknown:
return RAPIDJSON_ERROR_STRING("JSON schema draft or OpenAPI version is not recognized.");
case kSchemaErrorSpecUnsupported:
return RAPIDJSON_ERROR_STRING("JSON schema draft or OpenAPI version is not supported.");
case kSchemaErrorSpecIllegal:
return RAPIDJSON_ERROR_STRING(
"Both JSON schema draft and OpenAPI version found in document.");
case kSchemaErrorReadOnlyAndWriteOnly:
return RAPIDJSON_ERROR_STRING("Property must not be both 'readOnly' and 'writeOnly'.");
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
}
}
}
//! Maps error code of pointer parse into error message.
/*!
@@ -154,16 +245,22 @@ inline const RAPIDJSON_ERROR_CHARTYPE* GetValidateError_En(ValidateErrorCode val
\note User can make a copy of this function for localization.
Using switch-case is safer for future modification of error codes.
*/
inline const RAPIDJSON_ERROR_CHARTYPE* GetPointerParseError_En(PointerParseErrorCode pointerParseErrorCode) {
switch (pointerParseErrorCode) {
case kPointerParseErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
inline const RAPIDJSON_ERROR_CHARTYPE*
GetPointerParseError_En(PointerParseErrorCode pointerParseErrorCode)
{
switch(pointerParseErrorCode)
{
case kPointerParseErrorNone: return RAPIDJSON_ERROR_STRING("No error.");
case kPointerParseErrorTokenMustBeginWithSolidus: return RAPIDJSON_ERROR_STRING("A token must begin with a '/'.");
case kPointerParseErrorInvalidEscape: return RAPIDJSON_ERROR_STRING("Invalid escape.");
case kPointerParseErrorInvalidPercentEncoding: return RAPIDJSON_ERROR_STRING("Invalid percent encoding in URI fragment.");
case kPointerParseErrorCharacterMustPercentEncode: return RAPIDJSON_ERROR_STRING("A character must be percent encoded in a URI fragment.");
case kPointerParseErrorTokenMustBeginWithSolidus:
return RAPIDJSON_ERROR_STRING("A token must begin with a '/'.");
case kPointerParseErrorInvalidEscape: return RAPIDJSON_ERROR_STRING("Invalid escape.");
case kPointerParseErrorInvalidPercentEncoding:
return RAPIDJSON_ERROR_STRING("Invalid percent encoding in URI fragment.");
case kPointerParseErrorCharacterMustPercentEncode:
return RAPIDJSON_ERROR_STRING("A character must be percent encoded in a URI fragment.");
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
default: return RAPIDJSON_ERROR_STRING("Unknown error.");
}
}

View File

@@ -61,32 +61,33 @@ RAPIDJSON_NAMESPACE_BEGIN
/*! \ingroup RAPIDJSON_ERRORS
\see GenericReader::Parse, GenericReader::GetParseErrorCode
*/
enum ParseErrorCode {
kParseErrorNone = 0, //!< No error.
enum ParseErrorCode
{
kParseErrorNone = 0, //!< No error.
kParseErrorDocumentEmpty, //!< The document is empty.
kParseErrorDocumentRootNotSingular, //!< The document root must not follow by other values.
kParseErrorDocumentEmpty, //!< The document is empty.
kParseErrorDocumentRootNotSingular, //!< The document root must not follow by other values.
kParseErrorValueInvalid, //!< Invalid value.
kParseErrorValueInvalid, //!< Invalid value.
kParseErrorObjectMissName, //!< Missing a name for object member.
kParseErrorObjectMissColon, //!< Missing a colon after a name of object member.
kParseErrorObjectMissCommaOrCurlyBracket, //!< Missing a comma or '}' after an object member.
kParseErrorObjectMissName, //!< Missing a name for object member.
kParseErrorObjectMissColon, //!< Missing a colon after a name of object member.
kParseErrorObjectMissCommaOrCurlyBracket, //!< Missing a comma or '}' after an object member.
kParseErrorArrayMissCommaOrSquareBracket, //!< Missing a comma or ']' after an array element.
kParseErrorArrayMissCommaOrSquareBracket, //!< Missing a comma or ']' after an array element.
kParseErrorStringUnicodeEscapeInvalidHex, //!< Incorrect hex digit after \\u escape in string.
kParseErrorStringUnicodeSurrogateInvalid, //!< The surrogate pair in string is invalid.
kParseErrorStringEscapeInvalid, //!< Invalid escape character in string.
kParseErrorStringMissQuotationMark, //!< Missing a closing quotation mark in string.
kParseErrorStringInvalidEncoding, //!< Invalid encoding in string.
kParseErrorStringUnicodeEscapeInvalidHex, //!< Incorrect hex digit after \\u escape in string.
kParseErrorStringUnicodeSurrogateInvalid, //!< The surrogate pair in string is invalid.
kParseErrorStringEscapeInvalid, //!< Invalid escape character in string.
kParseErrorStringMissQuotationMark, //!< Missing a closing quotation mark in string.
kParseErrorStringInvalidEncoding, //!< Invalid encoding in string.
kParseErrorNumberTooBig, //!< Number too big to be stored in double.
kParseErrorNumberMissFraction, //!< Miss fraction part in number.
kParseErrorNumberMissExponent, //!< Miss exponent in number.
kParseErrorNumberTooBig, //!< Number too big to be stored in double.
kParseErrorNumberMissFraction, //!< Miss fraction part in number.
kParseErrorNumberMissExponent, //!< Miss exponent in number.
kParseErrorTermination, //!< Parsing was terminated.
kParseErrorUnspecificSyntaxError //!< Unspecific syntax error.
kParseErrorTermination, //!< Parsing was terminated.
kParseErrorUnspecificSyntaxError //!< Unspecific syntax error.
};
//! Result of parsing (wraps ParseErrorCode)
@@ -103,10 +104,12 @@ enum ParseErrorCode {
\endcode
\see GenericReader::Parse, GenericDocument::Parse
*/
struct ParseResult {
struct ParseResult
{
//!! Unspecified boolean type
typedef bool (ParseResult::*BooleanType)() const;
public:
public:
//! Default constructor, no error.
ParseResult() : code_(kParseErrorNone), offset_(0) {}
//! Constructor to set an error.
@@ -124,18 +127,25 @@ public:
bool operator==(const ParseResult& that) const { return code_ == that.code_; }
bool operator==(ParseErrorCode code) const { return code_ == code; }
friend bool operator==(ParseErrorCode code, const ParseResult & err) { return code == err.code_; }
friend bool operator==(ParseErrorCode code, const ParseResult& err)
{
return code == err.code_;
}
bool operator!=(const ParseResult& that) const { return !(*this == that); }
bool operator!=(ParseErrorCode code) const { return !(*this == code); }
friend bool operator!=(ParseErrorCode code, const ParseResult & err) { return err != code; }
friend bool operator!=(ParseErrorCode code, const ParseResult& err) { return err != code; }
//! Reset error code.
void Clear() { Set(kParseErrorNone); }
//! Update error code and offset.
void Set(ParseErrorCode code, size_t offset = 0) { code_ = code; offset_ = offset; }
void Set(ParseErrorCode code, size_t offset = 0)
{
code_ = code;
offset_ = offset;
}
private:
private:
ParseErrorCode code_;
size_t offset_;
};
@@ -159,43 +169,49 @@ typedef const RAPIDJSON_ERROR_CHARTYPE* (*GetParseErrorFunc)(ParseErrorCode);
/*! \ingroup RAPIDJSON_ERRORS
\see GenericSchemaValidator
*/
enum ValidateErrorCode {
kValidateErrors = -1, //!< Top level error code when kValidateContinueOnErrorsFlag set.
kValidateErrorNone = 0, //!< No error.
enum ValidateErrorCode
{
kValidateErrors = -1, //!< Top level error code when kValidateContinueOnErrorsFlag set.
kValidateErrorNone = 0, //!< No error.
kValidateErrorMultipleOf, //!< Number is not a multiple of the 'multipleOf' value.
kValidateErrorMaximum, //!< Number is greater than the 'maximum' value.
kValidateErrorExclusiveMaximum, //!< Number is greater than or equal to the 'maximum' value.
kValidateErrorMinimum, //!< Number is less than the 'minimum' value.
kValidateErrorExclusiveMinimum, //!< Number is less than or equal to the 'minimum' value.
kValidateErrorMultipleOf, //!< Number is not a multiple of the 'multipleOf' value.
kValidateErrorMaximum, //!< Number is greater than the 'maximum' value.
kValidateErrorExclusiveMaximum, //!< Number is greater than or equal to the 'maximum' value.
kValidateErrorMinimum, //!< Number is less than the 'minimum' value.
kValidateErrorExclusiveMinimum, //!< Number is less than or equal to the 'minimum' value.
kValidateErrorMaxLength, //!< String is longer than the 'maxLength' value.
kValidateErrorMinLength, //!< String is longer than the 'maxLength' value.
kValidateErrorPattern, //!< String does not match the 'pattern' regular expression.
kValidateErrorMaxLength, //!< String is longer than the 'maxLength' value.
kValidateErrorMinLength, //!< String is longer than the 'maxLength' value.
kValidateErrorPattern, //!< String does not match the 'pattern' regular expression.
kValidateErrorMaxItems, //!< Array is longer than the 'maxItems' value.
kValidateErrorMinItems, //!< Array is shorter than the 'minItems' value.
kValidateErrorUniqueItems, //!< Array has duplicate items but 'uniqueItems' is true.
kValidateErrorAdditionalItems, //!< Array has additional items that are not allowed by the schema.
kValidateErrorMaxItems, //!< Array is longer than the 'maxItems' value.
kValidateErrorMinItems, //!< Array is shorter than the 'minItems' value.
kValidateErrorUniqueItems, //!< Array has duplicate items but 'uniqueItems' is true.
kValidateErrorAdditionalItems, //!< Array has additional items that are not allowed by the
//!< schema.
kValidateErrorMaxProperties, //!< Object has more members than 'maxProperties' value.
kValidateErrorMinProperties, //!< Object has less members than 'minProperties' value.
kValidateErrorRequired, //!< Object is missing one or more members required by the schema.
kValidateErrorAdditionalProperties, //!< Object has additional members that are not allowed by the schema.
kValidateErrorPatternProperties, //!< See other errors.
kValidateErrorDependencies, //!< Object has missing property or schema dependencies.
kValidateErrorMaxProperties, //!< Object has more members than 'maxProperties' value.
kValidateErrorMinProperties, //!< Object has less members than 'minProperties' value.
kValidateErrorRequired, //!< Object is missing one or more members required by the schema.
kValidateErrorAdditionalProperties, //!< Object has additional members that are not allowed by
//!< the schema.
kValidateErrorPatternProperties, //!< See other errors.
kValidateErrorDependencies, //!< Object has missing property or schema dependencies.
kValidateErrorEnum, //!< Property has a value that is not one of its allowed enumerated values.
kValidateErrorType, //!< Property has a type that is not allowed by the schema.
kValidateErrorEnum, //!< Property has a value that is not one of its allowed enumerated values.
kValidateErrorType, //!< Property has a type that is not allowed by the schema.
kValidateErrorOneOf, //!< Property did not match any of the sub-schemas specified by 'oneOf'.
kValidateErrorOneOfMatch, //!< Property matched more than one of the sub-schemas specified by 'oneOf'.
kValidateErrorAllOf, //!< Property did not match all of the sub-schemas specified by 'allOf'.
kValidateErrorAnyOf, //!< Property did not match any of the sub-schemas specified by 'anyOf'.
kValidateErrorNot, //!< Property matched the sub-schema specified by 'not'.
kValidateErrorOneOf, //!< Property did not match any of the sub-schemas specified by 'oneOf'.
kValidateErrorOneOfMatch, //!< Property matched more than one of the sub-schemas specified by
//!< 'oneOf'.
kValidateErrorAllOf, //!< Property did not match all of the sub-schemas specified by 'allOf'.
kValidateErrorAnyOf, //!< Property did not match any of the sub-schemas specified by 'anyOf'.
kValidateErrorNot, //!< Property matched the sub-schema specified by 'not'.
kValidateErrorReadOnly, //!< Property is read-only but has been provided when validation is for writing
kValidateErrorWriteOnly //!< Property is write-only but has been provided when validation is for reading
kValidateErrorReadOnly, //!< Property is read-only but has been provided when validation is for
//!< writing
kValidateErrorWriteOnly //!< Property is write-only but has been provided when validation is for
//!< reading
};
//! Function pointer type of GetValidateError().
@@ -217,22 +233,25 @@ typedef const RAPIDJSON_ERROR_CHARTYPE* (*GetValidateErrorFunc)(ValidateErrorCod
/*! \ingroup RAPIDJSON_ERRORS
\see GenericSchemaValidator
*/
enum SchemaErrorCode {
kSchemaErrorNone = 0, //!< No error.
enum SchemaErrorCode
{
kSchemaErrorNone = 0, //!< No error.
kSchemaErrorStartUnknown, //!< Pointer to start of schema does not resolve to a location in the document
kSchemaErrorRefPlainName, //!< $ref fragment must be a JSON pointer
kSchemaErrorRefInvalid, //!< $ref must not be an empty string
kSchemaErrorRefPointerInvalid, //!< $ref fragment is not a valid JSON pointer at offset
kSchemaErrorRefUnknown, //!< $ref does not resolve to a location in the target document
kSchemaErrorRefCyclical, //!< $ref is cyclical
kSchemaErrorRefNoRemoteProvider, //!< $ref is remote but there is no remote provider
kSchemaErrorRefNoRemoteSchema, //!< $ref is remote but the remote provider did not return a schema
kSchemaErrorRegexInvalid, //!< Invalid regular expression in 'pattern' or 'patternProperties'
kSchemaErrorSpecUnknown, //!< JSON schema draft or OpenAPI version is not recognized
kSchemaErrorSpecUnsupported, //!< JSON schema draft or OpenAPI version is not supported
kSchemaErrorSpecIllegal, //!< Both JSON schema draft and OpenAPI version found in document
kSchemaErrorReadOnlyAndWriteOnly //!< Property must not be both 'readOnly' and 'writeOnly'
kSchemaErrorStartUnknown, //!< Pointer to start of schema does not resolve to a location in the
//!< document
kSchemaErrorRefPlainName, //!< $ref fragment must be a JSON pointer
kSchemaErrorRefInvalid, //!< $ref must not be an empty string
kSchemaErrorRefPointerInvalid, //!< $ref fragment is not a valid JSON pointer at offset
kSchemaErrorRefUnknown, //!< $ref does not resolve to a location in the target document
kSchemaErrorRefCyclical, //!< $ref is cyclical
kSchemaErrorRefNoRemoteProvider, //!< $ref is remote but there is no remote provider
kSchemaErrorRefNoRemoteSchema, //!< $ref is remote but the remote provider did not return a
//!< schema
kSchemaErrorRegexInvalid, //!< Invalid regular expression in 'pattern' or 'patternProperties'
kSchemaErrorSpecUnknown, //!< JSON schema draft or OpenAPI version is not recognized
kSchemaErrorSpecUnsupported, //!< JSON schema draft or OpenAPI version is not supported
kSchemaErrorSpecIllegal, //!< Both JSON schema draft and OpenAPI version found in document
kSchemaErrorReadOnlyAndWriteOnly //!< Property must not be both 'readOnly' and 'writeOnly'
};
//! Function pointer type of GetSchemaError().
@@ -254,13 +273,15 @@ typedef const RAPIDJSON_ERROR_CHARTYPE* (*GetSchemaErrorFunc)(SchemaErrorCode);
/*! \ingroup RAPIDJSON_ERRORS
\see GenericPointer::GenericPointer, GenericPointer::GetParseErrorCode
*/
enum PointerParseErrorCode {
kPointerParseErrorNone = 0, //!< The parse is successful
enum PointerParseErrorCode
{
kPointerParseErrorNone = 0, //!< The parse is successful
kPointerParseErrorTokenMustBeginWithSolidus, //!< A token must begin with a '/'
kPointerParseErrorInvalidEscape, //!< Invalid escape
kPointerParseErrorInvalidPercentEncoding, //!< Invalid percent encoding in URI fragment
kPointerParseErrorCharacterMustPercentEncode //!< A character must percent encoded in URI fragment
kPointerParseErrorTokenMustBeginWithSolidus, //!< A token must begin with a '/'
kPointerParseErrorInvalidEscape, //!< Invalid escape
kPointerParseErrorInvalidPercentEncoding, //!< Invalid percent encoding in URI fragment
kPointerParseErrorCharacterMustPercentEncode //!< A character must percent encoded in URI
//!< fragment
};
//! Function pointer type of GetPointerParseError().
@@ -275,7 +296,6 @@ enum PointerParseErrorCode {
*/
typedef const RAPIDJSON_ERROR_CHARTYPE* (*GetPointerParseErrorFunc)(PointerParseErrorCode);
RAPIDJSON_NAMESPACE_END
#ifdef __clang__

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_FILEREADSTREAM_H_
@@ -21,8 +21,8 @@
#ifdef __clang__
RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(padded)
RAPIDJSON_DIAG_OFF(unreachable-code)
RAPIDJSON_DIAG_OFF(missing-noreturn)
RAPIDJSON_DIAG_OFF(unreachable - code)
RAPIDJSON_DIAG_OFF(missing - noreturn)
#endif
RAPIDJSON_NAMESPACE_BEGIN
@@ -31,9 +31,10 @@ RAPIDJSON_NAMESPACE_BEGIN
/*!
\note implements Stream concept
*/
class FileReadStream {
public:
typedef char Ch; //!< Character type (byte).
class FileReadStream
{
public:
typedef char Ch; //!< Character type (byte).
//! Constructor.
/*!
@@ -41,38 +42,61 @@ public:
\param buffer user-supplied buffer.
\param bufferSize size of buffer in bytes. Must >=4 bytes.
*/
FileReadStream(std::FILE* fp, char* buffer, size_t bufferSize) : fp_(fp), buffer_(buffer), bufferSize_(bufferSize), bufferLast_(0), current_(buffer_), readCount_(0), count_(0), eof_(false) {
FileReadStream(std::FILE* fp, char* buffer, size_t bufferSize)
: fp_(fp),
buffer_(buffer),
bufferSize_(bufferSize),
bufferLast_(0),
current_(buffer_),
readCount_(0),
count_(0),
eof_(false)
{
RAPIDJSON_ASSERT(fp_ != 0);
RAPIDJSON_ASSERT(bufferSize >= 4);
Read();
}
Ch Peek() const { return *current_; }
Ch Take() { Ch c = *current_; Read(); return c; }
Ch Take()
{
Ch c = *current_;
Read();
return c;
}
size_t Tell() const { return count_ + static_cast<size_t>(current_ - buffer_); }
// Not implemented
void Put(Ch) { RAPIDJSON_ASSERT(false); }
void Flush() { RAPIDJSON_ASSERT(false); }
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
// For encoding detection only.
const Ch* Peek4() const {
return (current_ + 4 - !eof_ <= bufferLast_) ? current_ : 0;
void Flush() { RAPIDJSON_ASSERT(false); }
Ch* PutBegin()
{
RAPIDJSON_ASSERT(false);
return 0;
}
size_t PutEnd(Ch*)
{
RAPIDJSON_ASSERT(false);
return 0;
}
private:
void Read() {
if (current_ < bufferLast_)
++current_;
else if (!eof_) {
count_ += readCount_;
readCount_ = std::fread(buffer_, 1, bufferSize_, fp_);
bufferLast_ = buffer_ + readCount_ - 1;
current_ = buffer_;
// For encoding detection only.
const Ch* Peek4() const { return (current_ + 4 - !eof_ <= bufferLast_) ? current_ : 0; }
if (readCount_ < bufferSize_) {
private:
void Read()
{
if(current_ < bufferLast_)
++current_;
else if(!eof_)
{
count_ += readCount_;
readCount_ = std::fread(buffer_, 1, bufferSize_, fp_);
bufferLast_ = buffer_ + readCount_ - 1;
current_ = buffer_;
if(readCount_ < bufferSize_)
{
buffer_[readCount_] = '\0';
++bufferLast_;
eof_ = true;
@@ -81,12 +105,12 @@ private:
}
std::FILE* fp_;
Ch *buffer_;
Ch* buffer_;
size_t bufferSize_;
Ch *bufferLast_;
Ch *current_;
Ch* bufferLast_;
Ch* current_;
size_t readCount_;
size_t count_; //!< Number of characters read
size_t count_; //!< Number of characters read
bool eof_;
};

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_FILEWRITESTREAM_H_
@@ -20,7 +20,7 @@
#ifdef __clang__
RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(unreachable-code)
RAPIDJSON_DIAG_OFF(unreachable - code)
#endif
RAPIDJSON_NAMESPACE_BEGIN
@@ -29,24 +29,30 @@ RAPIDJSON_NAMESPACE_BEGIN
/*!
\note implements Stream concept
*/
class FileWriteStream {
public:
typedef char Ch; //!< Character type. Only support char.
class FileWriteStream
{
public:
typedef char Ch; //!< Character type. Only support char.
FileWriteStream(std::FILE* fp, char* buffer, size_t bufferSize) : fp_(fp), buffer_(buffer), bufferEnd_(buffer + bufferSize), current_(buffer_) {
FileWriteStream(std::FILE* fp, char* buffer, size_t bufferSize)
: fp_(fp), buffer_(buffer), bufferEnd_(buffer + bufferSize), current_(buffer_)
{
RAPIDJSON_ASSERT(fp_ != 0);
}
void Put(char c) {
if (current_ >= bufferEnd_)
void Put(char c)
{
if(current_ >= bufferEnd_)
Flush();
*current_++ = c;
}
void PutN(char c, size_t n) {
void PutN(char c, size_t n)
{
size_t avail = static_cast<size_t>(bufferEnd_ - current_);
while (n > avail) {
while(n > avail)
{
std::memset(current_, c, avail);
current_ += avail;
Flush();
@@ -54,16 +60,20 @@ public:
avail = static_cast<size_t>(bufferEnd_ - current_);
}
if (n > 0) {
if(n > 0)
{
std::memset(current_, c, n);
current_ += n;
}
}
void Flush() {
if (current_ != buffer_) {
void Flush()
{
if(current_ != buffer_)
{
size_t result = std::fwrite(buffer_, 1, static_cast<size_t>(current_ - buffer_), fp_);
if (result < static_cast<size_t>(current_ - buffer_)) {
if(result < static_cast<size_t>(current_ - buffer_))
{
// failure deliberately ignored at this time
// added to avoid warn_unused_result build errors
}
@@ -72,26 +82,47 @@ public:
}
// Not implemented
char Peek() const { RAPIDJSON_ASSERT(false); return 0; }
char Take() { RAPIDJSON_ASSERT(false); return 0; }
size_t Tell() const { RAPIDJSON_ASSERT(false); return 0; }
char* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
size_t PutEnd(char*) { RAPIDJSON_ASSERT(false); return 0; }
char Peek() const
{
RAPIDJSON_ASSERT(false);
return 0;
}
char Take()
{
RAPIDJSON_ASSERT(false);
return 0;
}
size_t Tell() const
{
RAPIDJSON_ASSERT(false);
return 0;
}
char* PutBegin()
{
RAPIDJSON_ASSERT(false);
return 0;
}
size_t PutEnd(char*)
{
RAPIDJSON_ASSERT(false);
return 0;
}
private:
private:
// Prohibit copy constructor & assignment operator.
FileWriteStream(const FileWriteStream&);
FileWriteStream& operator=(const FileWriteStream&);
std::FILE* fp_;
char *buffer_;
char *bufferEnd_;
char *current_;
char* buffer_;
char* bufferEnd_;
char* current_;
};
//! Implement specialized version of PutN() with memset() for better performance.
template<>
inline void PutN(FileWriteStream& stream, char c, size_t n) {
template <>
inline void PutN(FileWriteStream& stream, char c, size_t n)
{
stream.PutN(c, n);
}

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_FWD_H_
@@ -21,17 +21,26 @@ RAPIDJSON_NAMESPACE_BEGIN
// encodings.h
template<typename CharType> struct UTF8;
template<typename CharType> struct UTF16;
template<typename CharType> struct UTF16BE;
template<typename CharType> struct UTF16LE;
template<typename CharType> struct UTF32;
template<typename CharType> struct UTF32BE;
template<typename CharType> struct UTF32LE;
template<typename CharType> struct ASCII;
template<typename CharType> struct AutoUTF;
template <typename CharType>
struct UTF8;
template <typename CharType>
struct UTF16;
template <typename CharType>
struct UTF16BE;
template <typename CharType>
struct UTF16LE;
template <typename CharType>
struct UTF32;
template <typename CharType>
struct UTF32BE;
template <typename CharType>
struct UTF32LE;
template <typename CharType>
struct ASCII;
template <typename CharType>
struct AutoUTF;
template<typename SourceEncoding, typename TargetEncoding>
template <typename SourceEncoding, typename TargetEncoding>
struct Transcoder;
// allocators.h
@@ -46,12 +55,12 @@ class MemoryPoolAllocator;
template <typename Encoding>
struct GenericStringStream;
typedef GenericStringStream<UTF8<char> > StringStream;
typedef GenericStringStream<UTF8<char>> StringStream;
template <typename Encoding>
struct GenericInsituStringStream;
typedef GenericInsituStringStream<UTF8<char> > InsituStringStream;
typedef GenericInsituStringStream<UTF8<char>> InsituStringStream;
// stringbuffer.h
@@ -81,7 +90,7 @@ struct MemoryStream;
// reader.h
template<typename Encoding, typename Derived>
template <typename Encoding, typename Derived>
struct BaseReaderHandler;
template <typename SourceEncoding, typename TargetEncoding, typename StackAllocator>
@@ -91,29 +100,37 @@ typedef GenericReader<UTF8<char>, UTF8<char>, CrtAllocator> Reader;
// writer.h
template<typename OutputStream, typename SourceEncoding, typename TargetEncoding, typename StackAllocator, unsigned writeFlags>
template <typename OutputStream,
typename SourceEncoding,
typename TargetEncoding,
typename StackAllocator,
unsigned writeFlags>
class Writer;
// prettywriter.h
template<typename OutputStream, typename SourceEncoding, typename TargetEncoding, typename StackAllocator, unsigned writeFlags>
template <typename OutputStream,
typename SourceEncoding,
typename TargetEncoding,
typename StackAllocator,
unsigned writeFlags>
class PrettyWriter;
// document.h
template <typename Encoding, typename Allocator>
template <typename Encoding, typename Allocator>
class GenericMember;
template <bool Const, typename Encoding, typename Allocator>
class GenericMemberIterator;
template<typename CharType>
template <typename CharType>
struct GenericStringRef;
template <typename Encoding, typename Allocator>
template <typename Encoding, typename Allocator>
class GenericValue;
typedef GenericValue<UTF8<char>, MemoryPoolAllocator<CrtAllocator> > Value;
typedef GenericValue<UTF8<char>, MemoryPoolAllocator<CrtAllocator>> Value;
template <typename Encoding, typename Allocator, typename StackAllocator>
class GenericDocument;
@@ -138,13 +155,11 @@ class GenericSchemaDocument;
typedef GenericSchemaDocument<Value, CrtAllocator> SchemaDocument;
typedef IGenericRemoteSchemaDocumentProvider<SchemaDocument> IRemoteSchemaDocumentProvider;
template <
typename SchemaDocumentType,
typename OutputHandler,
typename StateAllocator>
template <typename SchemaDocumentType, typename OutputHandler, typename StateAllocator>
class GenericSchemaValidator;
typedef GenericSchemaValidator<SchemaDocument, BaseReaderHandler<UTF8<char>, void>, CrtAllocator> SchemaValidator;
typedef GenericSchemaValidator<SchemaDocument, BaseReaderHandler<UTF8<char>, void>, CrtAllocator>
SchemaValidator;
RAPIDJSON_NAMESPACE_END

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_BIGINTEGER_H_
@@ -22,132 +22,153 @@
#if !defined(_ARM64EC_)
#pragma intrinsic(_umul128)
#else
#pragma comment(lib,"softintrin")
#pragma comment(lib, "softintrin")
#endif
#endif
RAPIDJSON_NAMESPACE_BEGIN
namespace internal {
class BigInteger {
public:
class BigInteger
{
public:
typedef uint64_t Type;
BigInteger(const BigInteger& rhs) : count_(rhs.count_) {
BigInteger(const BigInteger& rhs) : count_(rhs.count_)
{
std::memcpy(digits_, rhs.digits_, count_ * sizeof(Type));
}
explicit BigInteger(uint64_t u) : count_(1) {
digits_[0] = u;
}
explicit BigInteger(uint64_t u) : count_(1) { digits_[0] = u; }
template<typename Ch>
BigInteger(const Ch* decimals, size_t length) : count_(1) {
template <typename Ch>
BigInteger(const Ch* decimals, size_t length) : count_(1)
{
RAPIDJSON_ASSERT(length > 0);
digits_[0] = 0;
size_t i = 0;
const size_t kMaxDigitPerIteration = 19; // 2^64 = 18446744073709551616 > 10^19
while (length >= kMaxDigitPerIteration) {
digits_[0] = 0;
size_t i = 0;
const size_t kMaxDigitPerIteration = 19; // 2^64 = 18446744073709551616 > 10^19
while(length >= kMaxDigitPerIteration)
{
AppendDecimal64(decimals + i, decimals + i + kMaxDigitPerIteration);
length -= kMaxDigitPerIteration;
i += kMaxDigitPerIteration;
}
if (length > 0)
if(length > 0)
AppendDecimal64(decimals + i, decimals + i + length);
}
BigInteger& operator=(const BigInteger &rhs)
BigInteger& operator=(const BigInteger& rhs)
{
if (this != &rhs) {
if(this != &rhs)
{
count_ = rhs.count_;
std::memcpy(digits_, rhs.digits_, count_ * sizeof(Type));
}
return *this;
}
BigInteger& operator=(uint64_t u) {
digits_[0] = u;
count_ = 1;
BigInteger& operator=(uint64_t u)
{
digits_[0] = u;
count_ = 1;
return *this;
}
BigInteger& operator+=(uint64_t u) {
BigInteger& operator+=(uint64_t u)
{
Type backup = digits_[0];
digits_[0] += u;
for (size_t i = 0; i < count_ - 1; i++) {
if (digits_[i] >= backup)
for(size_t i = 0; i < count_ - 1; i++)
{
if(digits_[i] >= backup)
return *this; // no carry
backup = digits_[i + 1];
digits_[i + 1] += 1;
}
// Last carry
if (digits_[count_ - 1] < backup)
if(digits_[count_ - 1] < backup)
PushBack(1);
return *this;
}
BigInteger& operator*=(uint64_t u) {
if (u == 0) return *this = 0;
if (u == 1) return *this;
if (*this == 1) return *this = u;
BigInteger& operator*=(uint64_t u)
{
if(u == 0)
return *this = 0;
if(u == 1)
return *this;
if(*this == 1)
return *this = u;
uint64_t k = 0;
for (size_t i = 0; i < count_; i++) {
for(size_t i = 0; i < count_; i++)
{
uint64_t hi;
digits_[i] = MulAdd64(digits_[i], u, k, &hi);
k = hi;
k = hi;
}
if (k > 0)
if(k > 0)
PushBack(k);
return *this;
}
BigInteger& operator*=(uint32_t u) {
if (u == 0) return *this = 0;
if (u == 1) return *this;
if (*this == 1) return *this = u;
BigInteger& operator*=(uint32_t u)
{
if(u == 0)
return *this = 0;
if(u == 1)
return *this;
if(*this == 1)
return *this = u;
uint64_t k = 0;
for (size_t i = 0; i < count_; i++) {
const uint64_t c = digits_[i] >> 32;
const uint64_t d = digits_[i] & 0xFFFFFFFF;
for(size_t i = 0; i < count_; i++)
{
const uint64_t c = digits_[i] >> 32;
const uint64_t d = digits_[i] & 0xFFFFFFFF;
const uint64_t uc = u * c;
const uint64_t ud = u * d;
const uint64_t p0 = ud + k;
const uint64_t p1 = uc + (p0 >> 32);
digits_[i] = (p0 & 0xFFFFFFFF) | (p1 << 32);
k = p1 >> 32;
digits_[i] = (p0 & 0xFFFFFFFF) | (p1 << 32);
k = p1 >> 32;
}
if (k > 0)
if(k > 0)
PushBack(k);
return *this;
}
BigInteger& operator<<=(size_t shift) {
if (IsZero() || shift == 0) return *this;
BigInteger& operator<<=(size_t shift)
{
if(IsZero() || shift == 0)
return *this;
size_t offset = shift / kTypeBit;
size_t offset = shift / kTypeBit;
size_t interShift = shift % kTypeBit;
RAPIDJSON_ASSERT(count_ + offset <= kCapacity);
if (interShift == 0) {
if(interShift == 0)
{
std::memmove(digits_ + offset, digits_, count_ * sizeof(Type));
count_ += offset;
}
else {
else
{
digits_[count_] = 0;
for (size_t i = count_; i > 0; i--)
digits_[i + offset] = (digits_[i] << interShift) | (digits_[i - 1] >> (kTypeBit - interShift));
for(size_t i = count_; i > 0; i--)
digits_[i + offset] =
(digits_[i] << interShift) | (digits_[i - 1] >> (kTypeBit - interShift));
digits_[offset] = digits_[0] << interShift;
count_ += offset;
if (digits_[count_])
if(digits_[count_])
count_++;
}
@@ -156,96 +177,121 @@ public:
return *this;
}
bool operator==(const BigInteger& rhs) const {
return count_ == rhs.count_ && std::memcmp(digits_, rhs.digits_, count_ * sizeof(Type)) == 0;
bool operator==(const BigInteger& rhs) const
{
return count_ == rhs.count_ &&
std::memcmp(digits_, rhs.digits_, count_ * sizeof(Type)) == 0;
}
bool operator==(const Type rhs) const {
return count_ == 1 && digits_[0] == rhs;
}
bool operator==(const Type rhs) const { return count_ == 1 && digits_[0] == rhs; }
BigInteger& MultiplyPow5(unsigned exp) {
static const uint32_t kPow5[12] = {
5,
5 * 5,
5 * 5 * 5,
5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5
};
if (exp == 0) return *this;
for (; exp >= 27; exp -= 27) *this *= RAPIDJSON_UINT64_C2(0X6765C793, 0XFA10079D); // 5^27
for (; exp >= 13; exp -= 13) *this *= static_cast<uint32_t>(1220703125u); // 5^13
if (exp > 0) *this *= kPow5[exp - 1];
BigInteger& MultiplyPow5(unsigned exp)
{
static const uint32_t kPow5[12] = {5,
5 * 5,
5 * 5 * 5,
5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5,
5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5 * 5};
if(exp == 0)
return *this;
for(; exp >= 27; exp -= 27)
*this *= RAPIDJSON_UINT64_C2(0X6765C793, 0XFA10079D); // 5^27
for(; exp >= 13; exp -= 13)
*this *= static_cast<uint32_t>(1220703125u); // 5^13
if(exp > 0)
*this *= kPow5[exp - 1];
return *this;
}
// Compute absolute difference of this and rhs.
// Assume this != rhs
bool Difference(const BigInteger& rhs, BigInteger* out) const {
bool Difference(const BigInteger& rhs, BigInteger* out) const
{
int cmp = Compare(rhs);
RAPIDJSON_ASSERT(cmp != 0);
const BigInteger *a, *b; // Makes a > b
const BigInteger *a, *b; // Makes a > b
bool ret;
if (cmp < 0) { a = &rhs; b = this; ret = true; }
else { a = this; b = &rhs; ret = false; }
if(cmp < 0)
{
a = &rhs;
b = this;
ret = true;
}
else
{
a = this;
b = &rhs;
ret = false;
}
Type borrow = 0;
for (size_t i = 0; i < a->count_; i++) {
for(size_t i = 0; i < a->count_; i++)
{
Type d = a->digits_[i] - borrow;
if (i < b->count_)
if(i < b->count_)
d -= b->digits_[i];
borrow = (d > a->digits_[i]) ? 1 : 0;
borrow = (d > a->digits_[i]) ? 1 : 0;
out->digits_[i] = d;
if (d != 0)
if(d != 0)
out->count_ = i + 1;
}
return ret;
}
int Compare(const BigInteger& rhs) const {
if (count_ != rhs.count_)
int Compare(const BigInteger& rhs) const
{
if(count_ != rhs.count_)
return count_ < rhs.count_ ? -1 : 1;
for (size_t i = count_; i-- > 0;)
if (digits_[i] != rhs.digits_[i])
for(size_t i = count_; i-- > 0;)
if(digits_[i] != rhs.digits_[i])
return digits_[i] < rhs.digits_[i] ? -1 : 1;
return 0;
}
size_t GetCount() const { return count_; }
Type GetDigit(size_t index) const { RAPIDJSON_ASSERT(index < count_); return digits_[index]; }
Type GetDigit(size_t index) const
{
RAPIDJSON_ASSERT(index < count_);
return digits_[index];
}
bool IsZero() const { return count_ == 1 && digits_[0] == 0; }
private:
template<typename Ch>
void AppendDecimal64(const Ch* begin, const Ch* end) {
private:
template <typename Ch>
void AppendDecimal64(const Ch* begin, const Ch* end)
{
uint64_t u = ParseUint64(begin, end);
if (IsZero())
if(IsZero())
*this = u;
else {
else
{
unsigned exp = static_cast<unsigned>(end - begin);
(MultiplyPow5(exp) <<= exp) += u; // *this = *this * 10^exp + u
(MultiplyPow5(exp) <<= exp) += u; // *this = *this * 10^exp + u
}
}
void PushBack(Type digit) {
void PushBack(Type digit)
{
RAPIDJSON_ASSERT(count_ < kCapacity);
digits_[count_++] = digit;
}
template<typename Ch>
static uint64_t ParseUint64(const Ch* begin, const Ch* end) {
template <typename Ch>
static uint64_t ParseUint64(const Ch* begin, const Ch* end)
{
uint64_t r = 0;
for (const Ch* p = begin; p != end; ++p) {
for(const Ch* p = begin; p != end; ++p)
{
RAPIDJSON_ASSERT(*p >= Ch('0') && *p <= Ch('9'));
r = r * 10u + static_cast<unsigned>(*p - Ch('0'));
}
@@ -253,13 +299,15 @@ private:
}
// Assume a * b + k < 2^128
static uint64_t MulAdd64(uint64_t a, uint64_t b, uint64_t k, uint64_t* outHigh) {
static uint64_t MulAdd64(uint64_t a, uint64_t b, uint64_t k, uint64_t* outHigh)
{
#if defined(_MSC_VER) && defined(_M_AMD64)
uint64_t low = _umul128(a, b, outHigh) + k;
if (low < k)
if(low < k)
(*outHigh)++;
return low;
#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && defined(__x86_64__)
#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && \
defined(__x86_64__)
__extension__ typedef unsigned __int128 uint128;
uint128 p = static_cast<uint128>(a) * static_cast<uint128>(b);
p += k;
@@ -270,22 +318,22 @@ private:
uint64_t x0 = a0 * b0, x1 = a0 * b1, x2 = a1 * b0, x3 = a1 * b1;
x1 += (x0 >> 32); // can't give carry
x1 += x2;
if (x1 < x2)
if(x1 < x2)
x3 += (static_cast<uint64_t>(1) << 32);
uint64_t lo = (x1 << 32) + (x0 & 0xFFFFFFFF);
uint64_t hi = x3 + (x1 >> 32);
lo += k;
if (lo < k)
if(lo < k)
hi++;
*outHigh = hi;
return lo;
#endif
}
static const size_t kBitCount = 3328; // 64bit * 54 > 10^1000
static const size_t kBitCount = 3328; // 64bit * 54 > 10^1000
static const size_t kCapacity = kBitCount / sizeof(Type);
static const size_t kTypeBit = sizeof(Type) * 8;
static const size_t kTypeBit = sizeof(Type) * 8;
Type digits_[kCapacity];
size_t count_;

View File

@@ -29,7 +29,8 @@
RAPIDJSON_NAMESPACE_BEGIN
namespace internal {
inline uint32_t clzll(uint64_t x) {
inline uint32_t clzll(uint64_t x)
{
// Passing 0 to __builtin_clzll is UB in GCC and results in an
// infinite loop in the software implementation.
RAPIDJSON_ASSERT(x != 0);
@@ -40,7 +41,7 @@ inline uint32_t clzll(uint64_t x) {
_BitScanReverse64(&r, x);
#else
// Scan the high 32 bits.
if (_BitScanReverse(&r, static_cast<uint32_t>(x >> 32)))
if(_BitScanReverse(&r, static_cast<uint32_t>(x >> 32)))
return 63 - (r + 32);
// Scan the low 32 bits.
@@ -48,13 +49,14 @@ inline uint32_t clzll(uint64_t x) {
#endif // _WIN64
return 63 - r;
#elif (defined(__GNUC__) && __GNUC__ >= 4) || RAPIDJSON_HAS_BUILTIN(__builtin_clzll)
#elif(defined(__GNUC__) && __GNUC__ >= 4) || RAPIDJSON_HAS_BUILTIN(__builtin_clzll)
// __builtin_clzll wrapper
return static_cast<uint32_t>(__builtin_clzll(x));
#else
// naive version
uint32_t r = 0;
while (!(x & (static_cast<uint64_t>(1) << 63))) {
while(!(x & (static_cast<uint64_t>(1) << 63)))
{
x <<= 1;
++r;
}

View File

@@ -28,7 +28,7 @@
#if !defined(_ARM64EC_)
#pragma intrinsic(_umul128)
#else
#pragma comment(lib,"softintrin")
#pragma comment(lib, "softintrin")
#endif
#endif
@@ -45,72 +45,80 @@ RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(padded)
#endif
struct DiyFp {
struct DiyFp
{
DiyFp() : f(), e() {}
DiyFp(uint64_t fp, int exp) : f(fp), e(exp) {}
explicit DiyFp(double d) {
union {
explicit DiyFp(double d)
{
union
{
double d;
uint64_t u64;
} u = { d };
} u = {d};
int biased_e = static_cast<int>((u.u64 & kDpExponentMask) >> kDpSignificandSize);
int biased_e = static_cast<int>((u.u64 & kDpExponentMask) >> kDpSignificandSize);
uint64_t significand = (u.u64 & kDpSignificandMask);
if (biased_e != 0) {
if(biased_e != 0)
{
f = significand + kDpHiddenBit;
e = biased_e - kDpExponentBias;
}
else {
else
{
f = significand;
e = kDpMinExponent + 1;
}
}
DiyFp operator-(const DiyFp& rhs) const {
return DiyFp(f - rhs.f, e);
}
DiyFp operator-(const DiyFp& rhs) const { return DiyFp(f - rhs.f, e); }
DiyFp operator*(const DiyFp& rhs) const {
DiyFp operator*(const DiyFp& rhs) const
{
#if defined(_MSC_VER) && defined(_M_AMD64)
uint64_t h;
uint64_t l = _umul128(f, rhs.f, &h);
if (l & (uint64_t(1) << 63)) // rounding
if(l & (uint64_t(1) << 63)) // rounding
h++;
return DiyFp(h, e + rhs.e + 64);
#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && defined(__x86_64__)
#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) && \
defined(__x86_64__)
__extension__ typedef unsigned __int128 uint128;
uint128 p = static_cast<uint128>(f) * static_cast<uint128>(rhs.f);
uint128 p = static_cast<uint128>(f) * static_cast<uint128>(rhs.f);
uint64_t h = static_cast<uint64_t>(p >> 64);
uint64_t l = static_cast<uint64_t>(p);
if (l & (uint64_t(1) << 63)) // rounding
if(l & (uint64_t(1) << 63)) // rounding
h++;
return DiyFp(h, e + rhs.e + 64);
#else
const uint64_t M32 = 0xFFFFFFFF;
const uint64_t a = f >> 32;
const uint64_t b = f & M32;
const uint64_t c = rhs.f >> 32;
const uint64_t d = rhs.f & M32;
const uint64_t ac = a * c;
const uint64_t bc = b * c;
const uint64_t ad = a * d;
const uint64_t bd = b * d;
uint64_t tmp = (bd >> 32) + (ad & M32) + (bc & M32);
tmp += 1U << 31; /// mult_round
const uint64_t a = f >> 32;
const uint64_t b = f & M32;
const uint64_t c = rhs.f >> 32;
const uint64_t d = rhs.f & M32;
const uint64_t ac = a * c;
const uint64_t bc = b * c;
const uint64_t ad = a * d;
const uint64_t bd = b * d;
uint64_t tmp = (bd >> 32) + (ad & M32) + (bc & M32);
tmp += 1U << 31; /// mult_round
return DiyFp(ac + (ad >> 32) + (bc >> 32) + (tmp >> 32), e + rhs.e + 64);
#endif
}
DiyFp Normalize() const {
DiyFp Normalize() const
{
int s = static_cast<int>(clzll(f));
return DiyFp(f << s, e - s);
}
DiyFp NormalizeBoundary() const {
DiyFp NormalizeBoundary() const
{
DiyFp res = *this;
while (!(res.f & (kDpHiddenBit << 1))) {
while(!(res.f & (kDpHiddenBit << 1)))
{
res.f <<= 1;
res.e--;
}
@@ -119,50 +127,57 @@ struct DiyFp {
return res;
}
void NormalizedBoundaries(DiyFp* minus, DiyFp* plus) const {
void NormalizedBoundaries(DiyFp* minus, DiyFp* plus) const
{
DiyFp pl = DiyFp((f << 1) + 1, e - 1).NormalizeBoundary();
DiyFp mi = (f == kDpHiddenBit) ? DiyFp((f << 2) - 1, e - 2) : DiyFp((f << 1) - 1, e - 1);
mi.f <<= mi.e - pl.e;
mi.e = pl.e;
*plus = pl;
mi.e = pl.e;
*plus = pl;
*minus = mi;
}
double ToDouble() const {
union {
double ToDouble() const
{
union
{
double d;
uint64_t u64;
}u;
} u;
RAPIDJSON_ASSERT(f <= kDpHiddenBit + kDpSignificandMask);
if (e < kDpDenormalExponent) {
if(e < kDpDenormalExponent)
{
// Underflow.
return 0.0;
}
if (e >= kDpMaxExponent) {
if(e >= kDpMaxExponent)
{
// Overflow.
return std::numeric_limits<double>::infinity();
}
const uint64_t be = (e == kDpDenormalExponent && (f & kDpHiddenBit) == 0) ? 0 :
static_cast<uint64_t>(e + kDpExponentBias);
u.u64 = (f & kDpSignificandMask) | (be << kDpSignificandSize);
const uint64_t be = (e == kDpDenormalExponent && (f & kDpHiddenBit) == 0)
? 0
: static_cast<uint64_t>(e + kDpExponentBias);
u.u64 = (f & kDpSignificandMask) | (be << kDpSignificandSize);
return u.d;
}
static const int kDiySignificandSize = 64;
static const int kDpSignificandSize = 52;
static const int kDpExponentBias = 0x3FF + kDpSignificandSize;
static const int kDpMaxExponent = 0x7FF - kDpExponentBias;
static const int kDpMinExponent = -kDpExponentBias;
static const int kDpDenormalExponent = -kDpExponentBias + 1;
static const uint64_t kDpExponentMask = RAPIDJSON_UINT64_C2(0x7FF00000, 0x00000000);
static const int kDiySignificandSize = 64;
static const int kDpSignificandSize = 52;
static const int kDpExponentBias = 0x3FF + kDpSignificandSize;
static const int kDpMaxExponent = 0x7FF - kDpExponentBias;
static const int kDpMinExponent = -kDpExponentBias;
static const int kDpDenormalExponent = -kDpExponentBias + 1;
static const uint64_t kDpExponentMask = RAPIDJSON_UINT64_C2(0x7FF00000, 0x00000000);
static const uint64_t kDpSignificandMask = RAPIDJSON_UINT64_C2(0x000FFFFF, 0xFFFFFFFF);
static const uint64_t kDpHiddenBit = RAPIDJSON_UINT64_C2(0x00100000, 0x00000000);
static const uint64_t kDpHiddenBit = RAPIDJSON_UINT64_C2(0x00100000, 0x00000000);
uint64_t f;
int e;
};
inline DiyFp GetCachedPowerByIndex(size_t index) {
inline DiyFp GetCachedPowerByIndex(size_t index)
{
// 10^-348, 10^-340, ..., 10^340
static const uint64_t kCachedPowers_F[] = {
RAPIDJSON_UINT64_C2(0xfa8fd5a0, 0x081c0288), RAPIDJSON_UINT64_C2(0xbaaee17f, 0xa23ebf76),
@@ -208,41 +223,40 @@ inline DiyFp GetCachedPowerByIndex(size_t index) {
RAPIDJSON_UINT64_C2(0x80444b5e, 0x7aa7cf85), RAPIDJSON_UINT64_C2(0xbf21e440, 0x03acdd2d),
RAPIDJSON_UINT64_C2(0x8e679c2f, 0x5e44ff8f), RAPIDJSON_UINT64_C2(0xd433179d, 0x9c8cb841),
RAPIDJSON_UINT64_C2(0x9e19db92, 0xb4e31ba9), RAPIDJSON_UINT64_C2(0xeb96bf6e, 0xbadf77d9),
RAPIDJSON_UINT64_C2(0xaf87023b, 0x9bf0ee6b)
};
RAPIDJSON_UINT64_C2(0xaf87023b, 0x9bf0ee6b)};
static const int16_t kCachedPowers_E[] = {
-1220, -1193, -1166, -1140, -1113, -1087, -1060, -1034, -1007, -980,
-954, -927, -901, -874, -847, -821, -794, -768, -741, -715,
-688, -661, -635, -608, -582, -555, -529, -502, -475, -449,
-422, -396, -369, -343, -316, -289, -263, -236, -210, -183,
-157, -130, -103, -77, -50, -24, 3, 30, 56, 83,
109, 136, 162, 189, 216, 242, 269, 295, 322, 348,
375, 402, 428, 455, 481, 508, 534, 561, 588, 614,
641, 667, 694, 720, 747, 774, 800, 827, 853, 880,
907, 933, 960, 986, 1013, 1039, 1066
};
-1220, -1193, -1166, -1140, -1113, -1087, -1060, -1034, -1007, -980, -954, -927, -901,
-874, -847, -821, -794, -768, -741, -715, -688, -661, -635, -608, -582, -555,
-529, -502, -475, -449, -422, -396, -369, -343, -316, -289, -263, -236, -210,
-183, -157, -130, -103, -77, -50, -24, 3, 30, 56, 83, 109, 136,
162, 189, 216, 242, 269, 295, 322, 348, 375, 402, 428, 455, 481,
508, 534, 561, 588, 614, 641, 667, 694, 720, 747, 774, 800, 827,
853, 880, 907, 933, 960, 986, 1013, 1039, 1066};
RAPIDJSON_ASSERT(index < 87);
return DiyFp(kCachedPowers_F[index], kCachedPowers_E[index]);
}
inline DiyFp GetCachedPower(int e, int* K) {
inline DiyFp GetCachedPower(int e, int* K)
{
//int k = static_cast<int>(ceil((-61 - e) * 0.30102999566398114)) + 374;
double dk = (-61 - e) * 0.30102999566398114 + 347; // dk must be positive, so can do ceiling in positive
// int k = static_cast<int>(ceil((-61 - e) * 0.30102999566398114)) + 374;
double dk =
(-61 - e) * 0.30102999566398114 + 347; // dk must be positive, so can do ceiling in positive
int k = static_cast<int>(dk);
if (dk - k > 0.0)
if(dk - k > 0.0)
k++;
unsigned index = static_cast<unsigned>((k >> 3) + 1);
*K = -(-348 + static_cast<int>(index << 3)); // decimal exponent no need lookup table
*K = -(-348 + static_cast<int>(index << 3)); // decimal exponent no need lookup table
return GetCachedPowerByIndex(index);
}
inline DiyFp GetCachedPower10(int exp, int *outExp) {
inline DiyFp GetCachedPower10(int exp, int* outExp)
{
RAPIDJSON_ASSERT(exp >= -348);
unsigned index = static_cast<unsigned>(exp + 348) / 8u;
*outExp = -348 + static_cast<int>(index) * 8;
*outExp = -348 + static_cast<int>(index) * 8;
return GetCachedPowerByIndex(index);
}

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
// This is a C++ header-only implementation of Grisu2 algorithm from the publication:
@@ -29,66 +29,126 @@ namespace internal {
#ifdef __GNUC__
RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(effc++)
RAPIDJSON_DIAG_OFF(array-bounds) // some gcc versions generate wrong warnings https://gcc.gnu.org/bugzilla/show_bug.cgi?id=59124
RAPIDJSON_DIAG_OFF(array - bounds) // some gcc versions generate wrong warnings
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=59124
#endif
inline void GrisuRound(char* buffer, int len, uint64_t delta, uint64_t rest, uint64_t ten_kappa, uint64_t wp_w) {
while (rest < wp_w && delta - rest >= ten_kappa &&
(rest + ten_kappa < wp_w || /// closer
wp_w - rest > rest + ten_kappa - wp_w)) {
inline void
GrisuRound(char* buffer, int len, uint64_t delta, uint64_t rest, uint64_t ten_kappa, uint64_t wp_w)
{
while(rest < wp_w && delta - rest >= ten_kappa &&
(rest + ten_kappa < wp_w || /// closer
wp_w - rest > rest + ten_kappa - wp_w))
{
buffer[len - 1]--;
rest += ten_kappa;
}
}
inline int CountDecimalDigit32(uint32_t n) {
inline int CountDecimalDigit32(uint32_t n)
{
// Simple pure C++ implementation was faster than __builtin_clz version in this situation.
if (n < 10) return 1;
if (n < 100) return 2;
if (n < 1000) return 3;
if (n < 10000) return 4;
if (n < 100000) return 5;
if (n < 1000000) return 6;
if (n < 10000000) return 7;
if (n < 100000000) return 8;
if(n < 10)
return 1;
if(n < 100)
return 2;
if(n < 1000)
return 3;
if(n < 10000)
return 4;
if(n < 100000)
return 5;
if(n < 1000000)
return 6;
if(n < 10000000)
return 7;
if(n < 100000000)
return 8;
// Will not reach 10 digits in DigitGen()
//if (n < 1000000000) return 9;
//return 10;
// if (n < 1000000000) return 9;
// return 10;
return 9;
}
inline void DigitGen(const DiyFp& W, const DiyFp& Mp, uint64_t delta, char* buffer, int* len, int* K) {
static const uint64_t kPow10[] = { 1ULL, 10ULL, 100ULL, 1000ULL, 10000ULL, 100000ULL, 1000000ULL, 10000000ULL, 100000000ULL,
1000000000ULL, 10000000000ULL, 100000000000ULL, 1000000000000ULL,
10000000000000ULL, 100000000000000ULL, 1000000000000000ULL,
10000000000000000ULL, 100000000000000000ULL, 1000000000000000000ULL,
10000000000000000000ULL };
inline void
DigitGen(const DiyFp& W, const DiyFp& Mp, uint64_t delta, char* buffer, int* len, int* K)
{
static const uint64_t kPow10[] = {1ULL,
10ULL,
100ULL,
1000ULL,
10000ULL,
100000ULL,
1000000ULL,
10000000ULL,
100000000ULL,
1000000000ULL,
10000000000ULL,
100000000000ULL,
1000000000000ULL,
10000000000000ULL,
100000000000000ULL,
1000000000000000ULL,
10000000000000000ULL,
100000000000000000ULL,
1000000000000000000ULL,
10000000000000000000ULL};
const DiyFp one(uint64_t(1) << -Mp.e, Mp.e);
const DiyFp wp_w = Mp - W;
uint32_t p1 = static_cast<uint32_t>(Mp.f >> -one.e);
uint64_t p2 = Mp.f & (one.f - 1);
int kappa = CountDecimalDigit32(p1); // kappa in [0, 9]
*len = 0;
uint32_t p1 = static_cast<uint32_t>(Mp.f >> -one.e);
uint64_t p2 = Mp.f & (one.f - 1);
int kappa = CountDecimalDigit32(p1); // kappa in [0, 9]
*len = 0;
while (kappa > 0) {
while(kappa > 0)
{
uint32_t d = 0;
switch (kappa) {
case 9: d = p1 / 100000000; p1 %= 100000000; break;
case 8: d = p1 / 10000000; p1 %= 10000000; break;
case 7: d = p1 / 1000000; p1 %= 1000000; break;
case 6: d = p1 / 100000; p1 %= 100000; break;
case 5: d = p1 / 10000; p1 %= 10000; break;
case 4: d = p1 / 1000; p1 %= 1000; break;
case 3: d = p1 / 100; p1 %= 100; break;
case 2: d = p1 / 10; p1 %= 10; break;
case 1: d = p1; p1 = 0; break;
default:;
switch(kappa)
{
case 9:
d = p1 / 100000000;
p1 %= 100000000;
break;
case 8:
d = p1 / 10000000;
p1 %= 10000000;
break;
case 7:
d = p1 / 1000000;
p1 %= 1000000;
break;
case 6:
d = p1 / 100000;
p1 %= 100000;
break;
case 5:
d = p1 / 10000;
p1 %= 10000;
break;
case 4:
d = p1 / 1000;
p1 %= 1000;
break;
case 3:
d = p1 / 100;
p1 %= 100;
break;
case 2:
d = p1 / 10;
p1 %= 10;
break;
case 1:
d = p1;
p1 = 0;
break;
default:;
}
if (d || *len)
if(d || *len)
buffer[(*len)++] = static_cast<char>('0' + static_cast<char>(d));
kappa--;
uint64_t tmp = (static_cast<uint64_t>(p1) << -one.e) + p2;
if (tmp <= delta) {
if(tmp <= delta)
{
*K += kappa;
GrisuRound(buffer, *len, delta, tmp, kPow10[kappa] << -one.e, wp_w.f);
return;
@@ -96,15 +156,17 @@ inline void DigitGen(const DiyFp& W, const DiyFp& Mp, uint64_t delta, char* buff
}
// kappa = 0
for (;;) {
for(;;)
{
p2 *= 10;
delta *= 10;
char d = static_cast<char>(p2 >> -one.e);
if (d || *len)
if(d || *len)
buffer[(*len)++] = static_cast<char>('0' + d);
p2 &= one.f - 1;
kappa--;
if (p2 < delta) {
if(p2 < delta)
{
*K += kappa;
int index = -kappa;
GrisuRound(buffer, *len, delta, p2, one.f, wp_w.f * (index < 20 ? kPow10[index] : 0));
@@ -113,37 +175,42 @@ inline void DigitGen(const DiyFp& W, const DiyFp& Mp, uint64_t delta, char* buff
}
}
inline void Grisu2(double value, char* buffer, int* length, int* K) {
inline void Grisu2(double value, char* buffer, int* length, int* K)
{
const DiyFp v(value);
DiyFp w_m, w_p;
v.NormalizedBoundaries(&w_m, &w_p);
const DiyFp c_mk = GetCachedPower(w_p.e, K);
const DiyFp W = v.Normalize() * c_mk;
DiyFp Wp = w_p * c_mk;
DiyFp Wm = w_m * c_mk;
const DiyFp W = v.Normalize() * c_mk;
DiyFp Wp = w_p * c_mk;
DiyFp Wm = w_m * c_mk;
Wm.f++;
Wp.f--;
DigitGen(W, Wp, Wp.f - Wm.f, buffer, length, K);
}
inline char* WriteExponent(int K, char* buffer) {
if (K < 0) {
inline char* WriteExponent(int K, char* buffer)
{
if(K < 0)
{
*buffer++ = '-';
K = -K;
K = -K;
}
if (K >= 100) {
if(K >= 100)
{
*buffer++ = static_cast<char>('0' + static_cast<char>(K / 100));
K %= 100;
const char* d = GetDigitsLut() + K * 2;
*buffer++ = d[0];
*buffer++ = d[1];
*buffer++ = d[0];
*buffer++ = d[1];
}
else if (K >= 10) {
else if(K >= 10)
{
const char* d = GetDigitsLut() + K * 2;
*buffer++ = d[0];
*buffer++ = d[1];
*buffer++ = d[0];
*buffer++ = d[1];
}
else
*buffer++ = static_cast<char>('0' + static_cast<char>(K));
@@ -151,87 +218,100 @@ inline char* WriteExponent(int K, char* buffer) {
return buffer;
}
inline char* Prettify(char* buffer, int length, int k, int maxDecimalPlaces) {
const int kk = length + k; // 10^(kk-1) <= v < 10^kk
inline char* Prettify(char* buffer, int length, int k, int maxDecimalPlaces)
{
const int kk = length + k; // 10^(kk-1) <= v < 10^kk
if (0 <= k && kk <= 21) {
if(0 <= k && kk <= 21)
{
// 1234e7 -> 12340000000
for (int i = length; i < kk; i++)
for(int i = length; i < kk; i++)
buffer[i] = '0';
buffer[kk] = '.';
buffer[kk] = '.';
buffer[kk + 1] = '0';
return &buffer[kk + 2];
}
else if (0 < kk && kk <= 21) {
else if(0 < kk && kk <= 21)
{
// 1234e-2 -> 12.34
std::memmove(&buffer[kk + 1], &buffer[kk], static_cast<size_t>(length - kk));
buffer[kk] = '.';
if (0 > k + maxDecimalPlaces) {
if(0 > k + maxDecimalPlaces)
{
// When maxDecimalPlaces = 2, 1.2345 -> 1.23, 1.102 -> 1.1
// Remove extra trailing zeros (at least one) after truncation.
for (int i = kk + maxDecimalPlaces; i > kk + 1; i--)
if (buffer[i] != '0')
for(int i = kk + maxDecimalPlaces; i > kk + 1; i--)
if(buffer[i] != '0')
return &buffer[i + 1];
return &buffer[kk + 2]; // Reserve one zero
}
else
return &buffer[length + 1];
}
else if (-6 < kk && kk <= 0) {
else if(-6 < kk && kk <= 0)
{
// 1234e-6 -> 0.001234
const int offset = 2 - kk;
std::memmove(&buffer[offset], &buffer[0], static_cast<size_t>(length));
buffer[0] = '0';
buffer[1] = '.';
for (int i = 2; i < offset; i++)
for(int i = 2; i < offset; i++)
buffer[i] = '0';
if (length - kk > maxDecimalPlaces) {
if(length - kk > maxDecimalPlaces)
{
// When maxDecimalPlaces = 2, 0.123 -> 0.12, 0.102 -> 0.1
// Remove extra trailing zeros (at least one) after truncation.
for (int i = maxDecimalPlaces + 1; i > 2; i--)
if (buffer[i] != '0')
for(int i = maxDecimalPlaces + 1; i > 2; i--)
if(buffer[i] != '0')
return &buffer[i + 1];
return &buffer[3]; // Reserve one zero
}
else
return &buffer[length + offset];
}
else if (kk < -maxDecimalPlaces) {
else if(kk < -maxDecimalPlaces)
{
// Truncate to zero
buffer[0] = '0';
buffer[1] = '.';
buffer[2] = '0';
return &buffer[3];
}
else if (length == 1) {
else if(length == 1)
{
// 1e30
buffer[1] = 'e';
return WriteExponent(kk - 1, &buffer[2]);
}
else {
else
{
// 1234e30 -> 1.234e33
std::memmove(&buffer[2], &buffer[1], static_cast<size_t>(length - 1));
buffer[1] = '.';
buffer[1] = '.';
buffer[length + 1] = 'e';
return WriteExponent(kk - 1, &buffer[0 + length + 2]);
}
}
inline char* dtoa(double value, char* buffer, int maxDecimalPlaces = 324) {
inline char* dtoa(double value, char* buffer, int maxDecimalPlaces = 324)
{
RAPIDJSON_ASSERT(maxDecimalPlaces >= 1);
Double d(value);
if (d.IsZero()) {
if (d.Sign())
*buffer++ = '-'; // -0.0, Issue #289
if(d.IsZero())
{
if(d.Sign())
*buffer++ = '-'; // -0.0, Issue #289
buffer[0] = '0';
buffer[1] = '.';
buffer[2] = '0';
return &buffer[3];
}
else {
if (value < 0) {
else
{
if(value < 0)
{
*buffer++ = '-';
value = -value;
value = -value;
}
int length, K;
Grisu2(value, buffer, &length, &K);

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_IEEE754_
@@ -20,8 +20,9 @@
RAPIDJSON_NAMESPACE_BEGIN
namespace internal {
class Double {
public:
class Double
{
public:
Double() {}
Double(double d) : d_(d) {}
Double(uint64_t u) : u_(u) {}
@@ -29,14 +30,18 @@ public:
double Value() const { return d_; }
uint64_t Uint64Value() const { return u_; }
double NextPositiveDouble() const {
double NextPositiveDouble() const
{
RAPIDJSON_ASSERT(!Sign());
return Double(u_ + 1).Value();
}
bool Sign() const { return (u_ & kSignMask) != 0; }
uint64_t Significand() const { return u_ & kSignificandMask; }
int Exponent() const { return static_cast<int>(((u_ & kExponentMask) >> kSignificandSize) - kExponentBias); }
int Exponent() const
{
return static_cast<int>(((u_ & kExponentMask) >> kSignificandSize) - kExponentBias);
}
bool IsNan() const { return (u_ & kExponentMask) == kExponentMask && Significand() != 0; }
bool IsInf() const { return (u_ & kExponentMask) == kExponentMask && Significand() == 0; }
@@ -44,29 +49,37 @@ public:
bool IsNormal() const { return (u_ & kExponentMask) != 0 || Significand() == 0; }
bool IsZero() const { return (u_ & (kExponentMask | kSignificandMask)) == 0; }
uint64_t IntegerSignificand() const { return IsNormal() ? Significand() | kHiddenBit : Significand(); }
int IntegerExponent() const { return (IsNormal() ? Exponent() : kDenormalExponent) - kSignificandSize; }
uint64_t IntegerSignificand() const
{
return IsNormal() ? Significand() | kHiddenBit : Significand();
}
int IntegerExponent() const
{
return (IsNormal() ? Exponent() : kDenormalExponent) - kSignificandSize;
}
uint64_t ToBias() const { return (u_ & kSignMask) ? ~u_ + 1 : u_ | kSignMask; }
static int EffectiveSignificandSize(int order) {
if (order >= -1021)
static int EffectiveSignificandSize(int order)
{
if(order >= -1021)
return 53;
else if (order <= -1074)
else if(order <= -1074)
return 0;
else
return order + 1074;
}
private:
static const int kSignificandSize = 52;
static const int kExponentBias = 0x3FF;
static const int kDenormalExponent = 1 - kExponentBias;
static const uint64_t kSignMask = RAPIDJSON_UINT64_C2(0x80000000, 0x00000000);
static const uint64_t kExponentMask = RAPIDJSON_UINT64_C2(0x7FF00000, 0x00000000);
private:
static const int kSignificandSize = 52;
static const int kExponentBias = 0x3FF;
static const int kDenormalExponent = 1 - kExponentBias;
static const uint64_t kSignMask = RAPIDJSON_UINT64_C2(0x80000000, 0x00000000);
static const uint64_t kExponentMask = RAPIDJSON_UINT64_C2(0x7FF00000, 0x00000000);
static const uint64_t kSignificandMask = RAPIDJSON_UINT64_C2(0x000FFFFF, 0xFFFFFFFF);
static const uint64_t kHiddenBit = RAPIDJSON_UINT64_C2(0x00100000, 0x00000000);
static const uint64_t kHiddenBit = RAPIDJSON_UINT64_C2(0x00100000, 0x00000000);
union {
union
{
double d_;
uint64_t u_;
};

View File

@@ -20,40 +20,45 @@
RAPIDJSON_NAMESPACE_BEGIN
namespace internal {
inline const char* GetDigitsLut() {
inline const char* GetDigitsLut()
{
static const char cDigitsLut[200] = {
'0','0','0','1','0','2','0','3','0','4','0','5','0','6','0','7','0','8','0','9',
'1','0','1','1','1','2','1','3','1','4','1','5','1','6','1','7','1','8','1','9',
'2','0','2','1','2','2','2','3','2','4','2','5','2','6','2','7','2','8','2','9',
'3','0','3','1','3','2','3','3','3','4','3','5','3','6','3','7','3','8','3','9',
'4','0','4','1','4','2','4','3','4','4','4','5','4','6','4','7','4','8','4','9',
'5','0','5','1','5','2','5','3','5','4','5','5','5','6','5','7','5','8','5','9',
'6','0','6','1','6','2','6','3','6','4','6','5','6','6','6','7','6','8','6','9',
'7','0','7','1','7','2','7','3','7','4','7','5','7','6','7','7','7','8','7','9',
'8','0','8','1','8','2','8','3','8','4','8','5','8','6','8','7','8','8','8','9',
'9','0','9','1','9','2','9','3','9','4','9','5','9','6','9','7','9','8','9','9'
};
'0', '0', '0', '1', '0', '2', '0', '3', '0', '4', '0', '5', '0', '6', '0', '7', '0',
'8', '0', '9', '1', '0', '1', '1', '1', '2', '1', '3', '1', '4', '1', '5', '1', '6',
'1', '7', '1', '8', '1', '9', '2', '0', '2', '1', '2', '2', '2', '3', '2', '4', '2',
'5', '2', '6', '2', '7', '2', '8', '2', '9', '3', '0', '3', '1', '3', '2', '3', '3',
'3', '4', '3', '5', '3', '6', '3', '7', '3', '8', '3', '9', '4', '0', '4', '1', '4',
'2', '4', '3', '4', '4', '4', '5', '4', '6', '4', '7', '4', '8', '4', '9', '5', '0',
'5', '1', '5', '2', '5', '3', '5', '4', '5', '5', '5', '6', '5', '7', '5', '8', '5',
'9', '6', '0', '6', '1', '6', '2', '6', '3', '6', '4', '6', '5', '6', '6', '6', '7',
'6', '8', '6', '9', '7', '0', '7', '1', '7', '2', '7', '3', '7', '4', '7', '5', '7',
'6', '7', '7', '7', '8', '7', '9', '8', '0', '8', '1', '8', '2', '8', '3', '8', '4',
'8', '5', '8', '6', '8', '7', '8', '8', '8', '9', '9', '0', '9', '1', '9', '2', '9',
'3', '9', '4', '9', '5', '9', '6', '9', '7', '9', '8', '9', '9'};
return cDigitsLut;
}
inline char* u32toa(uint32_t value, char* buffer) {
inline char* u32toa(uint32_t value, char* buffer)
{
RAPIDJSON_ASSERT(buffer != 0);
const char* cDigitsLut = GetDigitsLut();
if (value < 10000) {
if(value < 10000)
{
const uint32_t d1 = (value / 100) << 1;
const uint32_t d2 = (value % 100) << 1;
if (value >= 1000)
if(value >= 1000)
*buffer++ = cDigitsLut[d1];
if (value >= 100)
if(value >= 100)
*buffer++ = cDigitsLut[d1 + 1];
if (value >= 10)
if(value >= 10)
*buffer++ = cDigitsLut[d2];
*buffer++ = cDigitsLut[d2 + 1];
}
else if (value < 100000000) {
else if(value < 100000000)
{
// value = bbbbcccc
const uint32_t b = value / 10000;
const uint32_t c = value % 10000;
@@ -64,11 +69,11 @@ inline char* u32toa(uint32_t value, char* buffer) {
const uint32_t d3 = (c / 100) << 1;
const uint32_t d4 = (c % 100) << 1;
if (value >= 10000000)
if(value >= 10000000)
*buffer++ = cDigitsLut[d1];
if (value >= 1000000)
if(value >= 1000000)
*buffer++ = cDigitsLut[d1 + 1];
if (value >= 100000)
if(value >= 100000)
*buffer++ = cDigitsLut[d2];
*buffer++ = cDigitsLut[d2 + 1];
@@ -77,16 +82,18 @@ inline char* u32toa(uint32_t value, char* buffer) {
*buffer++ = cDigitsLut[d4];
*buffer++ = cDigitsLut[d4 + 1];
}
else {
else
{
// value = aabbbbcccc in decimal
const uint32_t a = value / 100000000; // 1 to 42
value %= 100000000;
if (a >= 10) {
if(a >= 10)
{
const unsigned i = a << 1;
*buffer++ = cDigitsLut[i];
*buffer++ = cDigitsLut[i + 1];
*buffer++ = cDigitsLut[i];
*buffer++ = cDigitsLut[i + 1];
}
else
*buffer++ = static_cast<char>('0' + static_cast<char>(a));
@@ -112,45 +119,51 @@ inline char* u32toa(uint32_t value, char* buffer) {
return buffer;
}
inline char* i32toa(int32_t value, char* buffer) {
inline char* i32toa(int32_t value, char* buffer)
{
RAPIDJSON_ASSERT(buffer != 0);
uint32_t u = static_cast<uint32_t>(value);
if (value < 0) {
if(value < 0)
{
*buffer++ = '-';
u = ~u + 1;
u = ~u + 1;
}
return u32toa(u, buffer);
}
inline char* u64toa(uint64_t value, char* buffer) {
inline char* u64toa(uint64_t value, char* buffer)
{
RAPIDJSON_ASSERT(buffer != 0);
const char* cDigitsLut = GetDigitsLut();
const uint64_t kTen8 = 100000000;
const uint64_t kTen9 = kTen8 * 10;
const uint64_t kTen10 = kTen8 * 100;
const uint64_t kTen11 = kTen8 * 1000;
const uint64_t kTen12 = kTen8 * 10000;
const uint64_t kTen13 = kTen8 * 100000;
const uint64_t kTen14 = kTen8 * 1000000;
const uint64_t kTen15 = kTen8 * 10000000;
const uint64_t kTen16 = kTen8 * kTen8;
const uint64_t kTen8 = 100000000;
const uint64_t kTen9 = kTen8 * 10;
const uint64_t kTen10 = kTen8 * 100;
const uint64_t kTen11 = kTen8 * 1000;
const uint64_t kTen12 = kTen8 * 10000;
const uint64_t kTen13 = kTen8 * 100000;
const uint64_t kTen14 = kTen8 * 1000000;
const uint64_t kTen15 = kTen8 * 10000000;
const uint64_t kTen16 = kTen8 * kTen8;
if (value < kTen8) {
if(value < kTen8)
{
uint32_t v = static_cast<uint32_t>(value);
if (v < 10000) {
if(v < 10000)
{
const uint32_t d1 = (v / 100) << 1;
const uint32_t d2 = (v % 100) << 1;
if (v >= 1000)
if(v >= 1000)
*buffer++ = cDigitsLut[d1];
if (v >= 100)
if(v >= 100)
*buffer++ = cDigitsLut[d1 + 1];
if (v >= 10)
if(v >= 10)
*buffer++ = cDigitsLut[d2];
*buffer++ = cDigitsLut[d2 + 1];
}
else {
else
{
// value = bbbbcccc
const uint32_t b = v / 10000;
const uint32_t c = v % 10000;
@@ -161,11 +174,11 @@ inline char* u64toa(uint64_t value, char* buffer) {
const uint32_t d3 = (c / 100) << 1;
const uint32_t d4 = (c % 100) << 1;
if (value >= 10000000)
if(value >= 10000000)
*buffer++ = cDigitsLut[d1];
if (value >= 1000000)
if(value >= 1000000)
*buffer++ = cDigitsLut[d1 + 1];
if (value >= 100000)
if(value >= 100000)
*buffer++ = cDigitsLut[d2];
*buffer++ = cDigitsLut[d2 + 1];
@@ -175,7 +188,8 @@ inline char* u64toa(uint64_t value, char* buffer) {
*buffer++ = cDigitsLut[d4 + 1];
}
}
else if (value < kTen16) {
else if(value < kTen16)
{
const uint32_t v0 = static_cast<uint32_t>(value / kTen8);
const uint32_t v1 = static_cast<uint32_t>(value % kTen8);
@@ -197,19 +211,19 @@ inline char* u64toa(uint64_t value, char* buffer) {
const uint32_t d7 = (c1 / 100) << 1;
const uint32_t d8 = (c1 % 100) << 1;
if (value >= kTen15)
if(value >= kTen15)
*buffer++ = cDigitsLut[d1];
if (value >= kTen14)
if(value >= kTen14)
*buffer++ = cDigitsLut[d1 + 1];
if (value >= kTen13)
if(value >= kTen13)
*buffer++ = cDigitsLut[d2];
if (value >= kTen12)
if(value >= kTen12)
*buffer++ = cDigitsLut[d2 + 1];
if (value >= kTen11)
if(value >= kTen11)
*buffer++ = cDigitsLut[d3];
if (value >= kTen10)
if(value >= kTen10)
*buffer++ = cDigitsLut[d3 + 1];
if (value >= kTen9)
if(value >= kTen9)
*buffer++ = cDigitsLut[d4];
*buffer++ = cDigitsLut[d4 + 1];
@@ -222,31 +236,35 @@ inline char* u64toa(uint64_t value, char* buffer) {
*buffer++ = cDigitsLut[d8];
*buffer++ = cDigitsLut[d8 + 1];
}
else {
else
{
const uint32_t a = static_cast<uint32_t>(value / kTen16); // 1 to 1844
value %= kTen16;
if (a < 10)
if(a < 10)
*buffer++ = static_cast<char>('0' + static_cast<char>(a));
else if (a < 100) {
else if(a < 100)
{
const uint32_t i = a << 1;
*buffer++ = cDigitsLut[i];
*buffer++ = cDigitsLut[i + 1];
*buffer++ = cDigitsLut[i];
*buffer++ = cDigitsLut[i + 1];
}
else if (a < 1000) {
else if(a < 1000)
{
*buffer++ = static_cast<char>('0' + static_cast<char>(a / 100));
const uint32_t i = (a % 100) << 1;
*buffer++ = cDigitsLut[i];
*buffer++ = cDigitsLut[i + 1];
*buffer++ = cDigitsLut[i];
*buffer++ = cDigitsLut[i + 1];
}
else {
else
{
const uint32_t i = (a / 100) << 1;
const uint32_t j = (a % 100) << 1;
*buffer++ = cDigitsLut[i];
*buffer++ = cDigitsLut[i + 1];
*buffer++ = cDigitsLut[j];
*buffer++ = cDigitsLut[j + 1];
*buffer++ = cDigitsLut[i];
*buffer++ = cDigitsLut[i + 1];
*buffer++ = cDigitsLut[j];
*buffer++ = cDigitsLut[j + 1];
}
const uint32_t v0 = static_cast<uint32_t>(value / kTen8);
@@ -291,12 +309,14 @@ inline char* u64toa(uint64_t value, char* buffer) {
return buffer;
}
inline char* i64toa(int64_t value, char* buffer) {
inline char* i64toa(int64_t value, char* buffer)
{
RAPIDJSON_ASSERT(buffer != 0);
uint64_t u = static_cast<uint64_t>(value);
if (value < 0) {
if(value < 0)
{
*buffer++ = '-';
u = ~u + 1;
u = ~u + 1;
}
return u64toa(u, buffer);

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_INTERNAL_META_H_
@@ -36,140 +36,253 @@ RAPIDJSON_NAMESPACE_BEGIN
namespace internal {
// Helper to wrap/convert arbitrary types to void, useful for arbitrary type matching
template <typename T> struct Void { typedef void Type; };
template <typename T>
struct Void
{
typedef void Type;
};
///////////////////////////////////////////////////////////////////////////////
// BoolType, TrueType, FalseType
//
template <bool Cond> struct BoolType {
template <bool Cond>
struct BoolType
{
static const bool Value = Cond;
typedef BoolType Type;
};
typedef BoolType<true> TrueType;
typedef BoolType<false> FalseType;
///////////////////////////////////////////////////////////////////////////////
// SelectIf, BoolExpr, NotExpr, AndExpr, OrExpr
//
template <bool C> struct SelectIfImpl { template <typename T1, typename T2> struct Apply { typedef T1 Type; }; };
template <> struct SelectIfImpl<false> { template <typename T1, typename T2> struct Apply { typedef T2 Type; }; };
template <bool C, typename T1, typename T2> struct SelectIfCond : SelectIfImpl<C>::template Apply<T1,T2> {};
template <typename C, typename T1, typename T2> struct SelectIf : SelectIfCond<C::Value, T1, T2> {};
template <bool C>
struct SelectIfImpl
{
template <typename T1, typename T2>
struct Apply
{
typedef T1 Type;
};
};
template <>
struct SelectIfImpl<false>
{
template <typename T1, typename T2>
struct Apply
{
typedef T2 Type;
};
};
template <bool C, typename T1, typename T2>
struct SelectIfCond : SelectIfImpl<C>::template Apply<T1, T2>
{
};
template <typename C, typename T1, typename T2>
struct SelectIf : SelectIfCond<C::Value, T1, T2>
{
};
template <bool Cond1, bool Cond2> struct AndExprCond : FalseType {};
template <> struct AndExprCond<true, true> : TrueType {};
template <bool Cond1, bool Cond2> struct OrExprCond : TrueType {};
template <> struct OrExprCond<false, false> : FalseType {};
template <typename C> struct BoolExpr : SelectIf<C,TrueType,FalseType>::Type {};
template <typename C> struct NotExpr : SelectIf<C,FalseType,TrueType>::Type {};
template <typename C1, typename C2> struct AndExpr : AndExprCond<C1::Value, C2::Value>::Type {};
template <typename C1, typename C2> struct OrExpr : OrExprCond<C1::Value, C2::Value>::Type {};
template <bool Cond1, bool Cond2>
struct AndExprCond : FalseType
{
};
template <>
struct AndExprCond<true, true> : TrueType
{
};
template <bool Cond1, bool Cond2>
struct OrExprCond : TrueType
{
};
template <>
struct OrExprCond<false, false> : FalseType
{
};
template <typename C>
struct BoolExpr : SelectIf<C, TrueType, FalseType>::Type
{
};
template <typename C>
struct NotExpr : SelectIf<C, FalseType, TrueType>::Type
{
};
template <typename C1, typename C2>
struct AndExpr : AndExprCond<C1::Value, C2::Value>::Type
{
};
template <typename C1, typename C2>
struct OrExpr : OrExprCond<C1::Value, C2::Value>::Type
{
};
///////////////////////////////////////////////////////////////////////////////
// AddConst, MaybeAddConst, RemoveConst
template <typename T> struct AddConst { typedef const T Type; };
template <bool Constify, typename T> struct MaybeAddConst : SelectIfCond<Constify, const T, T> {};
template <typename T> struct RemoveConst { typedef T Type; };
template <typename T> struct RemoveConst<const T> { typedef T Type; };
template <typename T>
struct AddConst
{
typedef const T Type;
};
template <bool Constify, typename T>
struct MaybeAddConst : SelectIfCond<Constify, const T, T>
{
};
template <typename T>
struct RemoveConst
{
typedef T Type;
};
template <typename T>
struct RemoveConst<const T>
{
typedef T Type;
};
///////////////////////////////////////////////////////////////////////////////
// IsSame, IsConst, IsMoreConst, IsPointer
//
template <typename T, typename U> struct IsSame : FalseType {};
template <typename T> struct IsSame<T, T> : TrueType {};
template <typename T, typename U>
struct IsSame : FalseType
{
};
template <typename T>
struct IsSame<T, T> : TrueType
{
};
template <typename T> struct IsConst : FalseType {};
template <typename T> struct IsConst<const T> : TrueType {};
template <typename T>
struct IsConst : FalseType
{
};
template <typename T>
struct IsConst<const T> : TrueType
{
};
template <typename CT, typename T>
struct IsMoreConst
: AndExpr<IsSame<typename RemoveConst<CT>::Type, typename RemoveConst<T>::Type>,
BoolType<IsConst<CT>::Value >= IsConst<T>::Value> >::Type {};
struct IsMoreConst : AndExpr<IsSame<typename RemoveConst<CT>::Type, typename RemoveConst<T>::Type>,
BoolType<IsConst<CT>::Value >= IsConst<T>::Value>>::Type
{
};
template <typename T> struct IsPointer : FalseType {};
template <typename T> struct IsPointer<T*> : TrueType {};
template <typename T>
struct IsPointer : FalseType
{
};
template <typename T>
struct IsPointer<T*> : TrueType
{
};
///////////////////////////////////////////////////////////////////////////////
// IsBaseOf
//
#if RAPIDJSON_HAS_CXX11_TYPETRAITS
template <typename B, typename D> struct IsBaseOf
: BoolType< ::std::is_base_of<B,D>::value> {};
template <typename B, typename D>
struct IsBaseOf : BoolType<::std::is_base_of<B, D>::value>
{
};
#else // simplified version adopted from Boost
template<typename B, typename D> struct IsBaseOfImpl {
template <typename B, typename D>
struct IsBaseOfImpl
{
RAPIDJSON_STATIC_ASSERT(sizeof(B) != 0);
RAPIDJSON_STATIC_ASSERT(sizeof(D) != 0);
typedef char (&Yes)[1];
typedef char (&No) [2];
typedef char (&No)[2];
template <typename T>
static Yes Check(const D*, T);
static No Check(const B*, int);
static No Check(const B*, int);
struct Host {
struct Host
{
operator const B*() const;
operator const D*();
};
enum { Value = (sizeof(Check(Host(), 0)) == sizeof(Yes)) };
enum
{
Value = (sizeof(Check(Host(), 0)) == sizeof(Yes))
};
};
template <typename B, typename D> struct IsBaseOf
: OrExpr<IsSame<B, D>, BoolExpr<IsBaseOfImpl<B, D> > >::Type {};
template <typename B, typename D>
struct IsBaseOf : OrExpr<IsSame<B, D>, BoolExpr<IsBaseOfImpl<B, D>>>::Type
{
};
#endif // RAPIDJSON_HAS_CXX11_TYPETRAITS
//////////////////////////////////////////////////////////////////////////
// EnableIf / DisableIf
//
template <bool Condition, typename T = void> struct EnableIfCond { typedef T Type; };
template <typename T> struct EnableIfCond<false, T> { /* empty */ };
template <bool Condition, typename T = void>
struct EnableIfCond
{
typedef T Type;
};
template <typename T>
struct EnableIfCond<false, T>
{ /* empty */
};
template <bool Condition, typename T = void> struct DisableIfCond { typedef T Type; };
template <typename T> struct DisableIfCond<true, T> { /* empty */ };
template <bool Condition, typename T = void>
struct DisableIfCond
{
typedef T Type;
};
template <typename T>
struct DisableIfCond<true, T>
{ /* empty */
};
template <typename Condition, typename T = void>
struct EnableIf : EnableIfCond<Condition::Value, T> {};
struct EnableIf : EnableIfCond<Condition::Value, T>
{
};
template <typename Condition, typename T = void>
struct DisableIf : DisableIfCond<Condition::Value, T> {};
struct DisableIf : DisableIfCond<Condition::Value, T>
{
};
// SFINAE helpers
struct SfinaeTag {};
template <typename T> struct RemoveSfinaeTag;
template <typename T> struct RemoveSfinaeTag<SfinaeTag&(*)(T)> { typedef T Type; };
struct SfinaeTag
{
};
template <typename T>
struct RemoveSfinaeTag;
template <typename T>
struct RemoveSfinaeTag<SfinaeTag& (*)(T)>
{
typedef T Type;
};
#define RAPIDJSON_REMOVEFPTR_(type) \
typename ::RAPIDJSON_NAMESPACE::internal::RemoveSfinaeTag \
< ::RAPIDJSON_NAMESPACE::internal::SfinaeTag&(*) type>::Type
#define RAPIDJSON_REMOVEFPTR_(type) \
typename ::RAPIDJSON_NAMESPACE::internal::RemoveSfinaeTag< \
::RAPIDJSON_NAMESPACE::internal::SfinaeTag&(*)type>::Type
#define RAPIDJSON_ENABLEIF(cond) \
typename ::RAPIDJSON_NAMESPACE::internal::EnableIf \
<RAPIDJSON_REMOVEFPTR_(cond)>::Type * = NULL
typename ::RAPIDJSON_NAMESPACE::internal::EnableIf<RAPIDJSON_REMOVEFPTR_(cond)>::Type* = NULL
#define RAPIDJSON_DISABLEIF(cond) \
typename ::RAPIDJSON_NAMESPACE::internal::DisableIf \
<RAPIDJSON_REMOVEFPTR_(cond)>::Type * = NULL
typename ::RAPIDJSON_NAMESPACE::internal::DisableIf<RAPIDJSON_REMOVEFPTR_(cond)>::Type* = NULL
#define RAPIDJSON_ENABLEIF_RETURN(cond,returntype) \
typename ::RAPIDJSON_NAMESPACE::internal::EnableIf \
<RAPIDJSON_REMOVEFPTR_(cond), \
RAPIDJSON_REMOVEFPTR_(returntype)>::Type
#define RAPIDJSON_ENABLEIF_RETURN(cond, returntype) \
typename ::RAPIDJSON_NAMESPACE::internal::EnableIf<RAPIDJSON_REMOVEFPTR_(cond), \
RAPIDJSON_REMOVEFPTR_(returntype)>::Type
#define RAPIDJSON_DISABLEIF_RETURN(cond,returntype) \
typename ::RAPIDJSON_NAMESPACE::internal::DisableIf \
<RAPIDJSON_REMOVEFPTR_(cond), \
RAPIDJSON_REMOVEFPTR_(returntype)>::Type
#define RAPIDJSON_DISABLEIF_RETURN(cond, returntype) \
typename ::RAPIDJSON_NAMESPACE::internal::DisableIf<RAPIDJSON_REMOVEFPTR_(cond), \
RAPIDJSON_REMOVEFPTR_(returntype)>::Type
} // namespace internal
RAPIDJSON_NAMESPACE_END

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_POW10_
@@ -25,26 +25,39 @@ namespace internal {
\param n non-negative exponent. Must <= 308.
\return 10.0^n
*/
inline double Pow10(int n) {
static const double e[] = { // 1e-0...1e308: 309 * 8 bytes = 2472 bytes
1e+0,
1e+1, 1e+2, 1e+3, 1e+4, 1e+5, 1e+6, 1e+7, 1e+8, 1e+9, 1e+10, 1e+11, 1e+12, 1e+13, 1e+14, 1e+15, 1e+16, 1e+17, 1e+18, 1e+19, 1e+20,
1e+21, 1e+22, 1e+23, 1e+24, 1e+25, 1e+26, 1e+27, 1e+28, 1e+29, 1e+30, 1e+31, 1e+32, 1e+33, 1e+34, 1e+35, 1e+36, 1e+37, 1e+38, 1e+39, 1e+40,
1e+41, 1e+42, 1e+43, 1e+44, 1e+45, 1e+46, 1e+47, 1e+48, 1e+49, 1e+50, 1e+51, 1e+52, 1e+53, 1e+54, 1e+55, 1e+56, 1e+57, 1e+58, 1e+59, 1e+60,
1e+61, 1e+62, 1e+63, 1e+64, 1e+65, 1e+66, 1e+67, 1e+68, 1e+69, 1e+70, 1e+71, 1e+72, 1e+73, 1e+74, 1e+75, 1e+76, 1e+77, 1e+78, 1e+79, 1e+80,
1e+81, 1e+82, 1e+83, 1e+84, 1e+85, 1e+86, 1e+87, 1e+88, 1e+89, 1e+90, 1e+91, 1e+92, 1e+93, 1e+94, 1e+95, 1e+96, 1e+97, 1e+98, 1e+99, 1e+100,
1e+101,1e+102,1e+103,1e+104,1e+105,1e+106,1e+107,1e+108,1e+109,1e+110,1e+111,1e+112,1e+113,1e+114,1e+115,1e+116,1e+117,1e+118,1e+119,1e+120,
1e+121,1e+122,1e+123,1e+124,1e+125,1e+126,1e+127,1e+128,1e+129,1e+130,1e+131,1e+132,1e+133,1e+134,1e+135,1e+136,1e+137,1e+138,1e+139,1e+140,
1e+141,1e+142,1e+143,1e+144,1e+145,1e+146,1e+147,1e+148,1e+149,1e+150,1e+151,1e+152,1e+153,1e+154,1e+155,1e+156,1e+157,1e+158,1e+159,1e+160,
1e+161,1e+162,1e+163,1e+164,1e+165,1e+166,1e+167,1e+168,1e+169,1e+170,1e+171,1e+172,1e+173,1e+174,1e+175,1e+176,1e+177,1e+178,1e+179,1e+180,
1e+181,1e+182,1e+183,1e+184,1e+185,1e+186,1e+187,1e+188,1e+189,1e+190,1e+191,1e+192,1e+193,1e+194,1e+195,1e+196,1e+197,1e+198,1e+199,1e+200,
1e+201,1e+202,1e+203,1e+204,1e+205,1e+206,1e+207,1e+208,1e+209,1e+210,1e+211,1e+212,1e+213,1e+214,1e+215,1e+216,1e+217,1e+218,1e+219,1e+220,
1e+221,1e+222,1e+223,1e+224,1e+225,1e+226,1e+227,1e+228,1e+229,1e+230,1e+231,1e+232,1e+233,1e+234,1e+235,1e+236,1e+237,1e+238,1e+239,1e+240,
1e+241,1e+242,1e+243,1e+244,1e+245,1e+246,1e+247,1e+248,1e+249,1e+250,1e+251,1e+252,1e+253,1e+254,1e+255,1e+256,1e+257,1e+258,1e+259,1e+260,
1e+261,1e+262,1e+263,1e+264,1e+265,1e+266,1e+267,1e+268,1e+269,1e+270,1e+271,1e+272,1e+273,1e+274,1e+275,1e+276,1e+277,1e+278,1e+279,1e+280,
1e+281,1e+282,1e+283,1e+284,1e+285,1e+286,1e+287,1e+288,1e+289,1e+290,1e+291,1e+292,1e+293,1e+294,1e+295,1e+296,1e+297,1e+298,1e+299,1e+300,
1e+301,1e+302,1e+303,1e+304,1e+305,1e+306,1e+307,1e+308
};
inline double Pow10(int n)
{
static const double e[] = {
// 1e-0...1e308: 309 * 8 bytes = 2472 bytes
1e+0, 1e+1, 1e+2, 1e+3, 1e+4, 1e+5, 1e+6, 1e+7, 1e+8, 1e+9, 1e+10,
1e+11, 1e+12, 1e+13, 1e+14, 1e+15, 1e+16, 1e+17, 1e+18, 1e+19, 1e+20, 1e+21,
1e+22, 1e+23, 1e+24, 1e+25, 1e+26, 1e+27, 1e+28, 1e+29, 1e+30, 1e+31, 1e+32,
1e+33, 1e+34, 1e+35, 1e+36, 1e+37, 1e+38, 1e+39, 1e+40, 1e+41, 1e+42, 1e+43,
1e+44, 1e+45, 1e+46, 1e+47, 1e+48, 1e+49, 1e+50, 1e+51, 1e+52, 1e+53, 1e+54,
1e+55, 1e+56, 1e+57, 1e+58, 1e+59, 1e+60, 1e+61, 1e+62, 1e+63, 1e+64, 1e+65,
1e+66, 1e+67, 1e+68, 1e+69, 1e+70, 1e+71, 1e+72, 1e+73, 1e+74, 1e+75, 1e+76,
1e+77, 1e+78, 1e+79, 1e+80, 1e+81, 1e+82, 1e+83, 1e+84, 1e+85, 1e+86, 1e+87,
1e+88, 1e+89, 1e+90, 1e+91, 1e+92, 1e+93, 1e+94, 1e+95, 1e+96, 1e+97, 1e+98,
1e+99, 1e+100, 1e+101, 1e+102, 1e+103, 1e+104, 1e+105, 1e+106, 1e+107, 1e+108, 1e+109,
1e+110, 1e+111, 1e+112, 1e+113, 1e+114, 1e+115, 1e+116, 1e+117, 1e+118, 1e+119, 1e+120,
1e+121, 1e+122, 1e+123, 1e+124, 1e+125, 1e+126, 1e+127, 1e+128, 1e+129, 1e+130, 1e+131,
1e+132, 1e+133, 1e+134, 1e+135, 1e+136, 1e+137, 1e+138, 1e+139, 1e+140, 1e+141, 1e+142,
1e+143, 1e+144, 1e+145, 1e+146, 1e+147, 1e+148, 1e+149, 1e+150, 1e+151, 1e+152, 1e+153,
1e+154, 1e+155, 1e+156, 1e+157, 1e+158, 1e+159, 1e+160, 1e+161, 1e+162, 1e+163, 1e+164,
1e+165, 1e+166, 1e+167, 1e+168, 1e+169, 1e+170, 1e+171, 1e+172, 1e+173, 1e+174, 1e+175,
1e+176, 1e+177, 1e+178, 1e+179, 1e+180, 1e+181, 1e+182, 1e+183, 1e+184, 1e+185, 1e+186,
1e+187, 1e+188, 1e+189, 1e+190, 1e+191, 1e+192, 1e+193, 1e+194, 1e+195, 1e+196, 1e+197,
1e+198, 1e+199, 1e+200, 1e+201, 1e+202, 1e+203, 1e+204, 1e+205, 1e+206, 1e+207, 1e+208,
1e+209, 1e+210, 1e+211, 1e+212, 1e+213, 1e+214, 1e+215, 1e+216, 1e+217, 1e+218, 1e+219,
1e+220, 1e+221, 1e+222, 1e+223, 1e+224, 1e+225, 1e+226, 1e+227, 1e+228, 1e+229, 1e+230,
1e+231, 1e+232, 1e+233, 1e+234, 1e+235, 1e+236, 1e+237, 1e+238, 1e+239, 1e+240, 1e+241,
1e+242, 1e+243, 1e+244, 1e+245, 1e+246, 1e+247, 1e+248, 1e+249, 1e+250, 1e+251, 1e+252,
1e+253, 1e+254, 1e+255, 1e+256, 1e+257, 1e+258, 1e+259, 1e+260, 1e+261, 1e+262, 1e+263,
1e+264, 1e+265, 1e+266, 1e+267, 1e+268, 1e+269, 1e+270, 1e+271, 1e+272, 1e+273, 1e+274,
1e+275, 1e+276, 1e+277, 1e+278, 1e+279, 1e+280, 1e+281, 1e+282, 1e+283, 1e+284, 1e+285,
1e+286, 1e+287, 1e+288, 1e+289, 1e+290, 1e+291, 1e+292, 1e+293, 1e+294, 1e+295, 1e+296,
1e+297, 1e+298, 1e+299, 1e+300, 1e+301, 1e+302, 1e+303, 1e+304, 1e+305, 1e+306, 1e+307,
1e+308};
RAPIDJSON_ASSERT(n >= 0 && n <= 308);
return e[n];
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_INTERNAL_STACK_H_
@@ -21,7 +21,7 @@
#if defined(__clang__)
RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(c++98-compat)
RAPIDJSON_DIAG_OFF(c++ 98 - compat)
#endif
RAPIDJSON_NAMESPACE_BEGIN
@@ -32,13 +32,21 @@ namespace internal {
//! A type-unsafe stack for storing different types of data.
/*! \tparam Allocator Allocator for allocating stack memory.
*/
*/
template <typename Allocator>
class Stack {
public:
class Stack
{
public:
// Optimization note: Do not allocate memory for stack_ in constructor.
// Do it lazily when first Push() -> Expand() -> Resize().
Stack(Allocator* allocator, size_t stackCapacity) : allocator_(allocator), ownAllocator_(0), stack_(0), stackTop_(0), stackEnd_(0), initialCapacity_(stackCapacity) {
Stack(Allocator* allocator, size_t stackCapacity)
: allocator_(allocator),
ownAllocator_(0),
stack_(0),
stackTop_(0),
stackEnd_(0),
initialCapacity_(stackCapacity)
{
}
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
@@ -50,44 +58,44 @@ public:
stackEnd_(rhs.stackEnd_),
initialCapacity_(rhs.initialCapacity_)
{
rhs.allocator_ = 0;
rhs.ownAllocator_ = 0;
rhs.stack_ = 0;
rhs.stackTop_ = 0;
rhs.stackEnd_ = 0;
rhs.allocator_ = 0;
rhs.ownAllocator_ = 0;
rhs.stack_ = 0;
rhs.stackTop_ = 0;
rhs.stackEnd_ = 0;
rhs.initialCapacity_ = 0;
}
#endif
~Stack() {
Destroy();
}
~Stack() { Destroy(); }
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
Stack& operator=(Stack&& rhs) {
if (&rhs != this)
Stack& operator=(Stack&& rhs)
{
if(&rhs != this)
{
Destroy();
allocator_ = rhs.allocator_;
ownAllocator_ = rhs.ownAllocator_;
stack_ = rhs.stack_;
stackTop_ = rhs.stackTop_;
stackEnd_ = rhs.stackEnd_;
allocator_ = rhs.allocator_;
ownAllocator_ = rhs.ownAllocator_;
stack_ = rhs.stack_;
stackTop_ = rhs.stackTop_;
stackEnd_ = rhs.stackEnd_;
initialCapacity_ = rhs.initialCapacity_;
rhs.allocator_ = 0;
rhs.ownAllocator_ = 0;
rhs.stack_ = 0;
rhs.stackTop_ = 0;
rhs.stackEnd_ = 0;
rhs.allocator_ = 0;
rhs.ownAllocator_ = 0;
rhs.stack_ = 0;
rhs.stackTop_ = 0;
rhs.stackEnd_ = 0;
rhs.initialCapacity_ = 0;
}
return *this;
}
#endif
void Swap(Stack& rhs) RAPIDJSON_NOEXCEPT {
void Swap(Stack& rhs) RAPIDJSON_NOEXCEPT
{
internal::Swap(allocator_, rhs.allocator_);
internal::Swap(ownAllocator_, rhs.ownAllocator_);
internal::Swap(stack_, rhs.stack_);
@@ -98,11 +106,13 @@ public:
void Clear() { stackTop_ = stack_; }
void ShrinkToFit() {
if (Empty()) {
void ShrinkToFit()
{
if(Empty())
{
// If the stack is empty, completely deallocate the memory.
Allocator::Free(stack_); // NOLINT (+clang-analyzer-unix.Malloc)
stack_ = 0;
stack_ = 0;
stackTop_ = 0;
stackEnd_ = 0;
}
@@ -112,21 +122,25 @@ public:
// Optimization note: try to minimize the size of this function for force inline.
// Expansion is run very infrequently, so it is moved to another (probably non-inline) function.
template<typename T>
RAPIDJSON_FORCEINLINE void Reserve(size_t count = 1) {
// Expand the stack if needed
if (RAPIDJSON_UNLIKELY(static_cast<std::ptrdiff_t>(sizeof(T) * count) > (stackEnd_ - stackTop_)))
template <typename T>
RAPIDJSON_FORCEINLINE void Reserve(size_t count = 1)
{
// Expand the stack if needed
if(RAPIDJSON_UNLIKELY(static_cast<std::ptrdiff_t>(sizeof(T) * count) >
(stackEnd_ - stackTop_)))
Expand<T>(count);
}
template<typename T>
RAPIDJSON_FORCEINLINE T* Push(size_t count = 1) {
template <typename T>
RAPIDJSON_FORCEINLINE T* Push(size_t count = 1)
{
Reserve<T>(count);
return PushUnsafe<T>(count);
}
template<typename T>
RAPIDJSON_FORCEINLINE T* PushUnsafe(size_t count = 1) {
template <typename T>
RAPIDJSON_FORCEINLINE T* PushUnsafe(size_t count = 1)
{
RAPIDJSON_ASSERT(stackTop_);
RAPIDJSON_ASSERT(static_cast<std::ptrdiff_t>(sizeof(T) * count) <= (stackEnd_ - stackTop_));
T* ret = reinterpret_cast<T*>(stackTop_);
@@ -134,42 +148,56 @@ public:
return ret;
}
template<typename T>
T* Pop(size_t count) {
template <typename T>
T* Pop(size_t count)
{
RAPIDJSON_ASSERT(GetSize() >= count * sizeof(T));
stackTop_ -= count * sizeof(T);
return reinterpret_cast<T*>(stackTop_);
}
template<typename T>
T* Top() {
template <typename T>
T* Top()
{
RAPIDJSON_ASSERT(GetSize() >= sizeof(T));
return reinterpret_cast<T*>(stackTop_ - sizeof(T));
}
template<typename T>
const T* Top() const {
template <typename T>
const T* Top() const
{
RAPIDJSON_ASSERT(GetSize() >= sizeof(T));
return reinterpret_cast<T*>(stackTop_ - sizeof(T));
}
template<typename T>
T* End() { return reinterpret_cast<T*>(stackTop_); }
template<typename T>
const T* End() const { return reinterpret_cast<T*>(stackTop_); }
template<typename T>
T* Bottom() { return reinterpret_cast<T*>(stack_); }
template<typename T>
const T* Bottom() const { return reinterpret_cast<T*>(stack_); }
bool HasAllocator() const {
return allocator_ != 0;
template <typename T>
T* End()
{
return reinterpret_cast<T*>(stackTop_);
}
Allocator& GetAllocator() {
template <typename T>
const T* End() const
{
return reinterpret_cast<T*>(stackTop_);
}
template <typename T>
T* Bottom()
{
return reinterpret_cast<T*>(stack_);
}
template <typename T>
const T* Bottom() const
{
return reinterpret_cast<T*>(stack_);
}
bool HasAllocator() const { return allocator_ != 0; }
Allocator& GetAllocator()
{
RAPIDJSON_ASSERT(allocator_);
return *allocator_;
}
@@ -178,34 +206,41 @@ public:
size_t GetSize() const { return static_cast<size_t>(stackTop_ - stack_); }
size_t GetCapacity() const { return static_cast<size_t>(stackEnd_ - stack_); }
private:
template<typename T>
void Expand(size_t count) {
// Only expand the capacity if the current stack exists. Otherwise just create a stack with initial capacity.
private:
template <typename T>
void Expand(size_t count)
{
// Only expand the capacity if the current stack exists. Otherwise just create a stack with
// initial capacity.
size_t newCapacity;
if (stack_ == 0) {
if (!allocator_)
if(stack_ == 0)
{
if(!allocator_)
ownAllocator_ = allocator_ = RAPIDJSON_NEW(Allocator)();
newCapacity = initialCapacity_;
} else {
}
else
{
newCapacity = GetCapacity();
newCapacity += (newCapacity + 1) / 2;
}
size_t newSize = GetSize() + sizeof(T) * count;
if (newCapacity < newSize)
if(newCapacity < newSize)
newCapacity = newSize;
Resize(newCapacity);
}
void Resize(size_t newCapacity) {
const size_t size = GetSize(); // Backup the current size
stack_ = static_cast<char*>(allocator_->Realloc(stack_, GetCapacity(), newCapacity));
void Resize(size_t newCapacity)
{
const size_t size = GetSize(); // Backup the current size
stack_ = static_cast<char*>(allocator_->Realloc(stack_, GetCapacity(), newCapacity));
stackTop_ = stack_ + size;
stackEnd_ = stack_ + newCapacity;
}
void Destroy() {
void Destroy()
{
Allocator::Free(stack_);
RAPIDJSON_DELETE(ownAllocator_); // Only delete if it is owned by the stack
}
@@ -216,9 +251,9 @@ private:
Allocator* allocator_;
Allocator* ownAllocator_;
char *stack_;
char *stackTop_;
char *stackEnd_;
char* stack_;
char* stackTop_;
char* stackEnd_;
size_t initialCapacity_;
};

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_INTERNAL_STRFUNC_H_
@@ -24,24 +24,29 @@ namespace internal {
//! Custom strlen() which works on different character types.
/*! \tparam Ch Character type (e.g. char, wchar_t, short)
\param s Null-terminated input string.
\return Number of characters in the string.
\note This has the same semantics as strlen(), the return value is not number of Unicode codepoints.
\return Number of characters in the string.
\note This has the same semantics as strlen(), the return value is not number of Unicode
codepoints.
*/
template <typename Ch>
inline SizeType StrLen(const Ch* s) {
inline SizeType StrLen(const Ch* s)
{
RAPIDJSON_ASSERT(s != 0);
const Ch* p = s;
while (*p) ++p;
while(*p)
++p;
return SizeType(p - s);
}
template <>
inline SizeType StrLen(const char* s) {
inline SizeType StrLen(const char* s)
{
return SizeType(std::strlen(s));
}
template <>
inline SizeType StrLen(const wchar_t* s) {
inline SizeType StrLen(const wchar_t* s)
{
return SizeType(std::wcslen(s));
}
@@ -51,25 +56,34 @@ inline SizeType StrLen(const wchar_t* s) {
\param s2 Null-terminated input string.
\return 0 if equal
*/
template<typename Ch>
inline int StrCmp(const Ch* s1, const Ch* s2) {
template <typename Ch>
inline int StrCmp(const Ch* s1, const Ch* s2)
{
RAPIDJSON_ASSERT(s1 != 0);
RAPIDJSON_ASSERT(s2 != 0);
while(*s1 && (*s1 == *s2)) { s1++; s2++; }
return static_cast<unsigned>(*s1) < static_cast<unsigned>(*s2) ? -1 : static_cast<unsigned>(*s1) > static_cast<unsigned>(*s2);
while(*s1 && (*s1 == *s2))
{
s1++;
s2++;
}
return static_cast<unsigned>(*s1) < static_cast<unsigned>(*s2)
? -1
: static_cast<unsigned>(*s1) > static_cast<unsigned>(*s2);
}
//! Returns number of code points in a encoded string.
template<typename Encoding>
bool CountStringCodePoint(const typename Encoding::Ch* s, SizeType length, SizeType* outCount) {
template <typename Encoding>
bool CountStringCodePoint(const typename Encoding::Ch* s, SizeType length, SizeType* outCount)
{
RAPIDJSON_ASSERT(s != 0);
RAPIDJSON_ASSERT(outCount != 0);
GenericStringStream<Encoding> is(s);
const typename Encoding::Ch* end = s + length;
SizeType count = 0;
while (is.src_ < end) {
SizeType count = 0;
while(is.src_ < end)
{
unsigned codepoint;
if (!Encoding::Decode(is, &codepoint))
if(!Encoding::Decode(is, &codepoint))
return false;
count++;
}

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_STRTOD_
@@ -25,17 +25,20 @@
RAPIDJSON_NAMESPACE_BEGIN
namespace internal {
inline double FastPath(double significand, int exp) {
if (exp < -308)
inline double FastPath(double significand, int exp)
{
if(exp < -308)
return 0.0;
else if (exp >= 0)
else if(exp >= 0)
return significand * internal::Pow10(exp);
else
return significand / internal::Pow10(-exp);
}
inline double StrtodNormalPrecision(double d, int p) {
if (p < -308) {
inline double StrtodNormalPrecision(double d, int p)
{
if(p < -308)
{
// Prevent expSum < -308, making Pow10(p) = 0
d = FastPath(d, -308);
d = FastPath(d, p + 308);
@@ -46,27 +49,33 @@ inline double StrtodNormalPrecision(double d, int p) {
}
template <typename T>
inline T Min3(T a, T b, T c) {
inline T Min3(T a, T b, T c)
{
T m = a;
if (m > b) m = b;
if (m > c) m = c;
if(m > b)
m = b;
if(m > c)
m = c;
return m;
}
inline int CheckWithinHalfULP(double b, const BigInteger& d, int dExp) {
inline int CheckWithinHalfULP(double b, const BigInteger& d, int dExp)
{
const Double db(b);
const uint64_t bInt = db.IntegerSignificand();
const int bExp = db.IntegerExponent();
const int hExp = bExp - 1;
const int bExp = db.IntegerExponent();
const int hExp = bExp - 1;
int dS_Exp2 = 0, dS_Exp5 = 0, bS_Exp2 = 0, bS_Exp5 = 0, hS_Exp2 = 0, hS_Exp5 = 0;
// Adjust for decimal exponent
if (dExp >= 0) {
if(dExp >= 0)
{
dS_Exp2 += dExp;
dS_Exp5 += dExp;
}
else {
else
{
bS_Exp2 -= dExp;
bS_Exp5 -= dExp;
hS_Exp2 -= dExp;
@@ -74,17 +83,19 @@ inline int CheckWithinHalfULP(double b, const BigInteger& d, int dExp) {
}
// Adjust for binary exponent
if (bExp >= 0)
if(bExp >= 0)
bS_Exp2 += bExp;
else {
else
{
dS_Exp2 -= bExp;
hS_Exp2 -= bExp;
}
// Adjust for half ulp exponent
if (hExp >= 0)
if(hExp >= 0)
hS_Exp2 += hExp;
else {
else
{
dS_Exp2 -= hExp;
bS_Exp2 -= hExp;
}
@@ -110,16 +121,19 @@ inline int CheckWithinHalfULP(double b, const BigInteger& d, int dExp) {
return delta.Compare(hS);
}
inline bool StrtodFast(double d, int p, double* result) {
inline bool StrtodFast(double d, int p, double* result)
{
// Use fast path for string-to-double conversion if possible
// see http://www.exploringbinary.com/fast-path-decimal-to-floating-point-conversion/
if (p > 22 && p < 22 + 16) {
if(p > 22 && p < 22 + 16)
{
// Fast Path Cases In Disguise
d *= internal::Pow10(p - 22);
p = 22;
}
if (p >= -22 && p <= 22 && d <= 9007199254740991.0) { // 2^53 - 1
if(p >= -22 && p <= 22 && d <= 9007199254740991.0)
{ // 2^53 - 1
*result = FastPath(d, p);
return true;
}
@@ -128,24 +142,26 @@ inline bool StrtodFast(double d, int p, double* result) {
}
// Compute an approximation and see if it is within 1/2 ULP
template<typename Ch>
inline bool StrtodDiyFp(const Ch* decimals, int dLen, int dExp, double* result) {
template <typename Ch>
inline bool StrtodDiyFp(const Ch* decimals, int dLen, int dExp, double* result)
{
uint64_t significand = 0;
int i = 0; // 2^64 - 1 = 18446744073709551615, 1844674407370955161 = 0x1999999999999999
for (; i < dLen; i++) {
if (significand > RAPIDJSON_UINT64_C2(0x19999999, 0x99999999) ||
(significand == RAPIDJSON_UINT64_C2(0x19999999, 0x99999999) && decimals[i] >= Ch('5')))
int i = 0; // 2^64 - 1 = 18446744073709551615, 1844674407370955161 = 0x1999999999999999
for(; i < dLen; i++)
{
if(significand > RAPIDJSON_UINT64_C2(0x19999999, 0x99999999) ||
(significand == RAPIDJSON_UINT64_C2(0x19999999, 0x99999999) && decimals[i] >= Ch('5')))
break;
significand = significand * 10u + static_cast<unsigned>(decimals[i] - Ch('0'));
}
if (i < dLen && decimals[i] >= Ch('5')) // Rounding
if(i < dLen && decimals[i] >= Ch('5')) // Rounding
significand++;
int remaining = dLen - i;
int remaining = dLen - i;
const int kUlpShift = 3;
const int kUlp = 1 << kUlpShift;
int64_t error = (remaining == 0) ? 0 : kUlp / 2;
const int kUlp = 1 << kUlpShift;
int64_t error = (remaining == 0) ? 0 : kUlp / 2;
DiyFp v(significand, 0);
v = v.Normalize();
@@ -155,20 +171,21 @@ inline bool StrtodDiyFp(const Ch* decimals, int dLen, int dExp, double* result)
int actualExp;
DiyFp cachedPower = GetCachedPower10(dExp, &actualExp);
if (actualExp != dExp) {
if(actualExp != dExp)
{
static const DiyFp kPow10[] = {
DiyFp(RAPIDJSON_UINT64_C2(0xa0000000, 0x00000000), -60), // 10^1
DiyFp(RAPIDJSON_UINT64_C2(0xc8000000, 0x00000000), -57), // 10^2
DiyFp(RAPIDJSON_UINT64_C2(0xfa000000, 0x00000000), -54), // 10^3
DiyFp(RAPIDJSON_UINT64_C2(0x9c400000, 0x00000000), -50), // 10^4
DiyFp(RAPIDJSON_UINT64_C2(0xc3500000, 0x00000000), -47), // 10^5
DiyFp(RAPIDJSON_UINT64_C2(0xf4240000, 0x00000000), -44), // 10^6
DiyFp(RAPIDJSON_UINT64_C2(0x98968000, 0x00000000), -40) // 10^7
DiyFp(RAPIDJSON_UINT64_C2(0xa0000000, 0x00000000), -60), // 10^1
DiyFp(RAPIDJSON_UINT64_C2(0xc8000000, 0x00000000), -57), // 10^2
DiyFp(RAPIDJSON_UINT64_C2(0xfa000000, 0x00000000), -54), // 10^3
DiyFp(RAPIDJSON_UINT64_C2(0x9c400000, 0x00000000), -50), // 10^4
DiyFp(RAPIDJSON_UINT64_C2(0xc3500000, 0x00000000), -47), // 10^5
DiyFp(RAPIDJSON_UINT64_C2(0xf4240000, 0x00000000), -44), // 10^6
DiyFp(RAPIDJSON_UINT64_C2(0x98968000, 0x00000000), -40) // 10^7
};
int adjustment = dExp - actualExp;
RAPIDJSON_ASSERT(adjustment >= 1 && adjustment < 8);
v = v * kPow10[adjustment - 1];
if (dLen + adjustment > 19) // has more digits than decimal digits in 64-bit
if(dLen + adjustment > 19) // has more digits than decimal digits in 64-bit
error += kUlp / 2;
}
@@ -177,25 +194,28 @@ inline bool StrtodDiyFp(const Ch* decimals, int dLen, int dExp, double* result)
error += kUlp + (error == 0 ? 0 : 1);
const int oldExp = v.e;
v = v.Normalize();
v = v.Normalize();
error <<= oldExp - v.e;
const int effectiveSignificandSize = Double::EffectiveSignificandSize(64 + v.e);
int precisionSize = 64 - effectiveSignificandSize;
if (precisionSize + kUlpShift >= 64) {
int precisionSize = 64 - effectiveSignificandSize;
if(precisionSize + kUlpShift >= 64)
{
int scaleExp = (precisionSize + kUlpShift) - 63;
v.f >>= scaleExp;
v.e += scaleExp;
v.e += scaleExp;
error = (error >> scaleExp) + 1 + kUlp;
precisionSize -= scaleExp;
}
DiyFp rounded(v.f >> precisionSize, v.e + precisionSize);
const uint64_t precisionBits = (v.f & ((uint64_t(1) << precisionSize) - 1)) * kUlp;
const uint64_t halfWay = (uint64_t(1) << (precisionSize - 1)) * kUlp;
if (precisionBits >= halfWay + static_cast<unsigned>(error)) {
const uint64_t halfWay = (uint64_t(1) << (precisionSize - 1)) * kUlp;
if(precisionBits >= halfWay + static_cast<unsigned>(error))
{
rounded.f++;
if (rounded.f & (DiyFp::kDpHiddenBit << 1)) { // rounding overflows mantissa (issue #340)
if(rounded.f & (DiyFp::kDpHiddenBit << 1))
{ // rounding overflows mantissa (issue #340)
rounded.f >>= 1;
rounded.e++;
}
@@ -203,20 +223,23 @@ inline bool StrtodDiyFp(const Ch* decimals, int dLen, int dExp, double* result)
*result = rounded.ToDouble();
return halfWay - static_cast<unsigned>(error) >= precisionBits || precisionBits >= halfWay + static_cast<unsigned>(error);
return halfWay - static_cast<unsigned>(error) >= precisionBits ||
precisionBits >= halfWay + static_cast<unsigned>(error);
}
template<typename Ch>
inline double StrtodBigInteger(double approx, const Ch* decimals, int dLen, int dExp) {
template <typename Ch>
inline double StrtodBigInteger(double approx, const Ch* decimals, int dLen, int dExp)
{
RAPIDJSON_ASSERT(dLen >= 0);
const BigInteger dInt(decimals, static_cast<unsigned>(dLen));
Double a(approx);
int cmp = CheckWithinHalfULP(a.Value(), dInt, dExp);
if (cmp < 0)
return a.Value(); // within half ULP
else if (cmp == 0) {
if(cmp < 0)
return a.Value(); // within half ULP
else if(cmp == 0)
{
// Round towards even
if (a.Significand() & 1)
if(a.Significand() & 1)
return a.NextPositiveDouble();
else
return a.Value();
@@ -225,13 +248,15 @@ inline double StrtodBigInteger(double approx, const Ch* decimals, int dLen, int
return a.NextPositiveDouble();
}
template<typename Ch>
inline double StrtodFullPrecision(double d, int p, const Ch* decimals, size_t length, size_t decimalPosition, int exp) {
template <typename Ch>
inline double StrtodFullPrecision(
double d, int p, const Ch* decimals, size_t length, size_t decimalPosition, int exp)
{
RAPIDJSON_ASSERT(d >= 0.0);
RAPIDJSON_ASSERT(length >= 1);
double result = 0.0;
if (StrtodFast(d, p, &result))
if(StrtodFast(d, p, &result))
return result;
RAPIDJSON_ASSERT(length <= INT_MAX);
@@ -248,39 +273,43 @@ inline double StrtodFullPrecision(double d, int p, const Ch* decimals, size_t le
RAPIDJSON_ASSERT(dExp <= INT_MAX - dLen);
// Trim leading zeros
while (dLen > 0 && *decimals == '0') {
while(dLen > 0 && *decimals == '0')
{
dLen--;
decimals++;
}
// Trim trailing zeros
while (dLen > 0 && decimals[dLen - 1] == '0') {
while(dLen > 0 && decimals[dLen - 1] == '0')
{
dLen--;
dExp++;
}
if (dLen == 0) { // Buffer only contains zeros.
if(dLen == 0)
{ // Buffer only contains zeros.
return 0.0;
}
// Trim right-most digits
const int kMaxDecimalDigit = 767 + 1;
if (dLen > kMaxDecimalDigit) {
if(dLen > kMaxDecimalDigit)
{
dExp += dLen - kMaxDecimalDigit;
dLen = kMaxDecimalDigit;
}
// If too small, underflow to zero.
// Any x <= 10^-324 is interpreted as zero.
if (dLen + dExp <= -324)
if(dLen + dExp <= -324)
return 0.0;
// If too large, overflow to infinity.
// Any x >= 10^309 is interpreted as +infinity.
if (dLen + dExp > 309)
if(dLen + dExp > 309)
return std::numeric_limits<double>::infinity();
if (StrtodDiyFp(decimals, dLen, dExp, &result))
if(StrtodDiyFp(decimals, dLen, dExp, &result))
return result;
// Use approximation from StrtodDiyFp and make adjustment with BigInteger comparison

View File

@@ -19,7 +19,7 @@
#if defined(__clang__)
RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(c++98-compat)
RAPIDJSON_DIAG_OFF(c++ 98 - compat)
#endif
RAPIDJSON_NAMESPACE_BEGIN
@@ -30,10 +30,11 @@ namespace internal {
\note This has the same semantics as std::swap().
*/
template <typename T>
inline void Swap(T& a, T& b) RAPIDJSON_NOEXCEPT {
inline void Swap(T& a, T& b) RAPIDJSON_NOEXCEPT
{
T tmp = a;
a = b;
b = tmp;
a = b;
b = tmp;
}
} // namespace internal

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_ISTREAMWRAPPER_H_
@@ -44,17 +44,27 @@ RAPIDJSON_NAMESPACE_BEGIN
\tparam StreamType Class derived from \c std::basic_istream.
*/
template <typename StreamType>
class BasicIStreamWrapper {
public:
class BasicIStreamWrapper
{
public:
typedef typename StreamType::char_type Ch;
//! Constructor.
/*!
\param stream stream opened for read.
*/
BasicIStreamWrapper(StreamType &stream) : stream_(stream), buffer_(peekBuffer_), bufferSize_(4), bufferLast_(0), current_(buffer_), readCount_(0), count_(0), eof_(false) {
BasicIStreamWrapper(StreamType& stream)
: stream_(stream),
buffer_(peekBuffer_),
bufferSize_(4),
bufferLast_(0),
current_(buffer_),
readCount_(0),
count_(0),
eof_(false)
{
Read();
}
@@ -64,55 +74,78 @@ public:
\param buffer user-supplied buffer.
\param bufferSize size of buffer in bytes. Must >=4 bytes.
*/
BasicIStreamWrapper(StreamType &stream, char* buffer, size_t bufferSize) : stream_(stream), buffer_(buffer), bufferSize_(bufferSize), bufferLast_(0), current_(buffer_), readCount_(0), count_(0), eof_(false) {
BasicIStreamWrapper(StreamType& stream, char* buffer, size_t bufferSize)
: stream_(stream),
buffer_(buffer),
bufferSize_(bufferSize),
bufferLast_(0),
current_(buffer_),
readCount_(0),
count_(0),
eof_(false)
{
RAPIDJSON_ASSERT(bufferSize >= 4);
Read();
}
Ch Peek() const { return *current_; }
Ch Take() { Ch c = *current_; Read(); return c; }
Ch Take()
{
Ch c = *current_;
Read();
return c;
}
size_t Tell() const { return count_ + static_cast<size_t>(current_ - buffer_); }
// Not implemented
void Put(Ch) { RAPIDJSON_ASSERT(false); }
void Flush() { RAPIDJSON_ASSERT(false); }
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
// For encoding detection only.
const Ch* Peek4() const {
return (current_ + 4 - !eof_ <= bufferLast_) ? current_ : 0;
void Flush() { RAPIDJSON_ASSERT(false); }
Ch* PutBegin()
{
RAPIDJSON_ASSERT(false);
return 0;
}
size_t PutEnd(Ch*)
{
RAPIDJSON_ASSERT(false);
return 0;
}
private:
// For encoding detection only.
const Ch* Peek4() const { return (current_ + 4 - !eof_ <= bufferLast_) ? current_ : 0; }
private:
BasicIStreamWrapper();
BasicIStreamWrapper(const BasicIStreamWrapper&);
BasicIStreamWrapper& operator=(const BasicIStreamWrapper&);
void Read() {
if (current_ < bufferLast_)
void Read()
{
if(current_ < bufferLast_)
++current_;
else if (!eof_) {
else if(!eof_)
{
count_ += readCount_;
readCount_ = bufferSize_;
readCount_ = bufferSize_;
bufferLast_ = buffer_ + readCount_ - 1;
current_ = buffer_;
current_ = buffer_;
if (!stream_.read(buffer_, static_cast<std::streamsize>(bufferSize_))) {
readCount_ = static_cast<size_t>(stream_.gcount());
if(!stream_.read(buffer_, static_cast<std::streamsize>(bufferSize_)))
{
readCount_ = static_cast<size_t>(stream_.gcount());
*(bufferLast_ = buffer_ + readCount_) = '\0';
eof_ = true;
eof_ = true;
}
}
}
StreamType &stream_;
StreamType& stream_;
Ch peekBuffer_[4], *buffer_;
size_t bufferSize_;
Ch *bufferLast_;
Ch *current_;
Ch* bufferLast_;
Ch* current_;
size_t readCount_;
size_t count_; //!< Number of characters read
size_t count_; //!< Number of characters read
bool eof_;
};

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_MEMORYBUFFER_H_
@@ -27,17 +27,22 @@ RAPIDJSON_NAMESPACE_BEGIN
It is similar to FileWriteBuffer but the destination is an in-memory buffer instead of a file.
Differences between MemoryBuffer and StringBuffer:
1. StringBuffer has Encoding but MemoryBuffer is only a byte buffer.
2. StringBuffer::GetString() returns a null-terminated string. MemoryBuffer::GetBuffer() returns a buffer without terminator.
1. StringBuffer has Encoding but MemoryBuffer is only a byte buffer.
2. StringBuffer::GetString() returns a null-terminated string. MemoryBuffer::GetBuffer() returns
a buffer without terminator.
\tparam Allocator type for allocating memory buffer.
\note implements Stream concept
*/
template <typename Allocator = CrtAllocator>
struct GenericMemoryBuffer {
struct GenericMemoryBuffer
{
typedef char Ch; // byte
GenericMemoryBuffer(Allocator* allocator = 0, size_t capacity = kDefaultCapacity) : stack_(allocator, capacity) {}
GenericMemoryBuffer(Allocator* allocator = 0, size_t capacity = kDefaultCapacity)
: stack_(allocator, capacity)
{
}
void Put(Ch c) { *stack_.template Push<Ch>() = c; }
void Flush() {}
@@ -47,9 +52,7 @@ struct GenericMemoryBuffer {
Ch* Push(size_t count) { return stack_.template Push<Ch>(count); }
void Pop(size_t count) { stack_.template Pop<Ch>(count); }
const Ch* GetBuffer() const {
return stack_.template Bottom<Ch>();
}
const Ch* GetBuffer() const { return stack_.template Bottom<Ch>(); }
size_t GetSize() const { return stack_.GetSize(); }
@@ -60,8 +63,9 @@ struct GenericMemoryBuffer {
typedef GenericMemoryBuffer<> MemoryBuffer;
//! Implement specialized version of PutN() with memset() for better performance.
template<>
inline void PutN(MemoryBuffer& memoryBuffer, char c, size_t n) {
template <>
inline void PutN(MemoryBuffer& memoryBuffer, char c, size_t n)
{
std::memset(memoryBuffer.stack_.Push<char>(n), c, n * sizeof(c));
}

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_MEMORYSTREAM_H_
@@ -19,8 +19,8 @@
#ifdef __clang__
RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(unreachable-code)
RAPIDJSON_DIAG_OFF(missing-noreturn)
RAPIDJSON_DIAG_OFF(unreachable - code)
RAPIDJSON_DIAG_OFF(missing - noreturn)
#endif
RAPIDJSON_NAMESPACE_BEGIN
@@ -33,33 +33,43 @@ RAPIDJSON_NAMESPACE_BEGIN
Differences between MemoryStream and StringStream:
1. StringStream has encoding but MemoryStream is a byte stream.
2. MemoryStream needs size of the source buffer and the buffer don't need to be null terminated. StringStream assume null-terminated string as source.
3. MemoryStream supports Peek4() for encoding detection. StringStream is specified with an encoding so it should not have Peek4().
\note implements Stream concept
2. MemoryStream needs size of the source buffer and the buffer don't need to be null terminated.
StringStream assume null-terminated string as source.
3. MemoryStream supports Peek4() for encoding detection. StringStream is specified with an
encoding so it should not have Peek4(). \note implements Stream concept
*/
struct MemoryStream {
struct MemoryStream
{
typedef char Ch; // byte
MemoryStream(const Ch *src, size_t size) : src_(src), begin_(src), end_(src + size), size_(size) {}
MemoryStream(const Ch* src, size_t size) : src_(src), begin_(src), end_(src + size), size_(size)
{
}
Ch Peek() const { return RAPIDJSON_UNLIKELY(src_ == end_) ? '\0' : *src_; }
Ch Take() { return RAPIDJSON_UNLIKELY(src_ == end_) ? '\0' : *src_++; }
size_t Tell() const { return static_cast<size_t>(src_ - begin_); }
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
Ch* PutBegin()
{
RAPIDJSON_ASSERT(false);
return 0;
}
void Put(Ch) { RAPIDJSON_ASSERT(false); }
void Flush() { RAPIDJSON_ASSERT(false); }
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
// For encoding detection only.
const Ch* Peek4() const {
return Tell() + 4 <= size_ ? src_ : 0;
size_t PutEnd(Ch*)
{
RAPIDJSON_ASSERT(false);
return 0;
}
const Ch* src_; //!< Current read position.
const Ch* begin_; //!< Original head of the string.
const Ch* end_; //!< End of stream.
size_t size_; //!< Size of the stream.
// For encoding detection only.
const Ch* Peek4() const { return Tell() + 4 <= size_ ? src_ : 0; }
const Ch* src_; //!< Current read position.
const Ch* begin_; //!< Original head of the string.
const Ch* end_; //!< End of stream.
size_t size_; //!< Size of the stream.
};
RAPIDJSON_NAMESPACE_END

View File

@@ -1,37 +1,37 @@
// ISO C9x compliant inttypes.h for Microsoft Visual Studio
// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124
//
// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124
//
// Copyright (c) 2006-2013 Alexander Chemeris
//
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
//
// 3. Neither the name of the product nor the names of its contributors may
// be used to endorse or promote products derived from this software
// without specific prior written permission.
//
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
// EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
// ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
//
///////////////////////////////////////////////////////////////////////////////
// The above software in this distribution may have been modified by
// THL A29 Limited ("Tencent Modifications").
// The above software in this distribution may have been modified by
// THL A29 Limited ("Tencent Modifications").
// All Tencent Modifications are Copyright (C) 2015 THL A29 Limited.
#ifndef _MSC_VER // [
@@ -54,9 +54,10 @@
// 7.8 Format conversion of integer types
typedef struct {
intmax_t quot;
intmax_t rem;
typedef struct
{
intmax_t quot;
intmax_t rem;
} imaxdiv_t;
// 7.8.1 Macros for format specifiers
@@ -64,212 +65,212 @@ typedef struct {
#if !defined(__cplusplus) || defined(__STDC_FORMAT_MACROS) // [ See footnote 185 at page 198
// The fprintf macros for signed integers are:
#define PRId8 "d"
#define PRIi8 "i"
#define PRIdLEAST8 "d"
#define PRIiLEAST8 "i"
#define PRIdFAST8 "d"
#define PRIiFAST8 "i"
#define PRId8 "d"
#define PRIi8 "i"
#define PRIdLEAST8 "d"
#define PRIiLEAST8 "i"
#define PRIdFAST8 "d"
#define PRIiFAST8 "i"
#define PRId16 "hd"
#define PRIi16 "hi"
#define PRIdLEAST16 "hd"
#define PRIiLEAST16 "hi"
#define PRIdFAST16 "hd"
#define PRIiFAST16 "hi"
#define PRId16 "hd"
#define PRIi16 "hi"
#define PRIdLEAST16 "hd"
#define PRIiLEAST16 "hi"
#define PRIdFAST16 "hd"
#define PRIiFAST16 "hi"
#define PRId32 "I32d"
#define PRIi32 "I32i"
#define PRIdLEAST32 "I32d"
#define PRIiLEAST32 "I32i"
#define PRIdFAST32 "I32d"
#define PRIiFAST32 "I32i"
#define PRId32 "I32d"
#define PRIi32 "I32i"
#define PRIdLEAST32 "I32d"
#define PRIiLEAST32 "I32i"
#define PRIdFAST32 "I32d"
#define PRIiFAST32 "I32i"
#define PRId64 "I64d"
#define PRIi64 "I64i"
#define PRIdLEAST64 "I64d"
#define PRIiLEAST64 "I64i"
#define PRIdFAST64 "I64d"
#define PRIiFAST64 "I64i"
#define PRId64 "I64d"
#define PRIi64 "I64i"
#define PRIdLEAST64 "I64d"
#define PRIiLEAST64 "I64i"
#define PRIdFAST64 "I64d"
#define PRIiFAST64 "I64i"
#define PRIdMAX "I64d"
#define PRIiMAX "I64i"
#define PRIdMAX "I64d"
#define PRIiMAX "I64i"
#define PRIdPTR "Id"
#define PRIiPTR "Ii"
#define PRIdPTR "Id"
#define PRIiPTR "Ii"
// The fprintf macros for unsigned integers are:
#define PRIo8 "o"
#define PRIu8 "u"
#define PRIx8 "x"
#define PRIX8 "X"
#define PRIoLEAST8 "o"
#define PRIuLEAST8 "u"
#define PRIxLEAST8 "x"
#define PRIXLEAST8 "X"
#define PRIoFAST8 "o"
#define PRIuFAST8 "u"
#define PRIxFAST8 "x"
#define PRIXFAST8 "X"
#define PRIo8 "o"
#define PRIu8 "u"
#define PRIx8 "x"
#define PRIX8 "X"
#define PRIoLEAST8 "o"
#define PRIuLEAST8 "u"
#define PRIxLEAST8 "x"
#define PRIXLEAST8 "X"
#define PRIoFAST8 "o"
#define PRIuFAST8 "u"
#define PRIxFAST8 "x"
#define PRIXFAST8 "X"
#define PRIo16 "ho"
#define PRIu16 "hu"
#define PRIx16 "hx"
#define PRIX16 "hX"
#define PRIoLEAST16 "ho"
#define PRIuLEAST16 "hu"
#define PRIxLEAST16 "hx"
#define PRIXLEAST16 "hX"
#define PRIoFAST16 "ho"
#define PRIuFAST16 "hu"
#define PRIxFAST16 "hx"
#define PRIXFAST16 "hX"
#define PRIo16 "ho"
#define PRIu16 "hu"
#define PRIx16 "hx"
#define PRIX16 "hX"
#define PRIoLEAST16 "ho"
#define PRIuLEAST16 "hu"
#define PRIxLEAST16 "hx"
#define PRIXLEAST16 "hX"
#define PRIoFAST16 "ho"
#define PRIuFAST16 "hu"
#define PRIxFAST16 "hx"
#define PRIXFAST16 "hX"
#define PRIo32 "I32o"
#define PRIu32 "I32u"
#define PRIx32 "I32x"
#define PRIX32 "I32X"
#define PRIoLEAST32 "I32o"
#define PRIuLEAST32 "I32u"
#define PRIxLEAST32 "I32x"
#define PRIXLEAST32 "I32X"
#define PRIoFAST32 "I32o"
#define PRIuFAST32 "I32u"
#define PRIxFAST32 "I32x"
#define PRIXFAST32 "I32X"
#define PRIo32 "I32o"
#define PRIu32 "I32u"
#define PRIx32 "I32x"
#define PRIX32 "I32X"
#define PRIoLEAST32 "I32o"
#define PRIuLEAST32 "I32u"
#define PRIxLEAST32 "I32x"
#define PRIXLEAST32 "I32X"
#define PRIoFAST32 "I32o"
#define PRIuFAST32 "I32u"
#define PRIxFAST32 "I32x"
#define PRIXFAST32 "I32X"
#define PRIo64 "I64o"
#define PRIu64 "I64u"
#define PRIx64 "I64x"
#define PRIX64 "I64X"
#define PRIoLEAST64 "I64o"
#define PRIuLEAST64 "I64u"
#define PRIxLEAST64 "I64x"
#define PRIXLEAST64 "I64X"
#define PRIoFAST64 "I64o"
#define PRIuFAST64 "I64u"
#define PRIxFAST64 "I64x"
#define PRIXFAST64 "I64X"
#define PRIo64 "I64o"
#define PRIu64 "I64u"
#define PRIx64 "I64x"
#define PRIX64 "I64X"
#define PRIoLEAST64 "I64o"
#define PRIuLEAST64 "I64u"
#define PRIxLEAST64 "I64x"
#define PRIXLEAST64 "I64X"
#define PRIoFAST64 "I64o"
#define PRIuFAST64 "I64u"
#define PRIxFAST64 "I64x"
#define PRIXFAST64 "I64X"
#define PRIoMAX "I64o"
#define PRIuMAX "I64u"
#define PRIxMAX "I64x"
#define PRIXMAX "I64X"
#define PRIoMAX "I64o"
#define PRIuMAX "I64u"
#define PRIxMAX "I64x"
#define PRIXMAX "I64X"
#define PRIoPTR "Io"
#define PRIuPTR "Iu"
#define PRIxPTR "Ix"
#define PRIXPTR "IX"
#define PRIoPTR "Io"
#define PRIuPTR "Iu"
#define PRIxPTR "Ix"
#define PRIXPTR "IX"
// The fscanf macros for signed integers are:
#define SCNd8 "d"
#define SCNi8 "i"
#define SCNdLEAST8 "d"
#define SCNiLEAST8 "i"
#define SCNdFAST8 "d"
#define SCNiFAST8 "i"
#define SCNd8 "d"
#define SCNi8 "i"
#define SCNdLEAST8 "d"
#define SCNiLEAST8 "i"
#define SCNdFAST8 "d"
#define SCNiFAST8 "i"
#define SCNd16 "hd"
#define SCNi16 "hi"
#define SCNdLEAST16 "hd"
#define SCNiLEAST16 "hi"
#define SCNdFAST16 "hd"
#define SCNiFAST16 "hi"
#define SCNd16 "hd"
#define SCNi16 "hi"
#define SCNdLEAST16 "hd"
#define SCNiLEAST16 "hi"
#define SCNdFAST16 "hd"
#define SCNiFAST16 "hi"
#define SCNd32 "ld"
#define SCNi32 "li"
#define SCNdLEAST32 "ld"
#define SCNiLEAST32 "li"
#define SCNdFAST32 "ld"
#define SCNiFAST32 "li"
#define SCNd32 "ld"
#define SCNi32 "li"
#define SCNdLEAST32 "ld"
#define SCNiLEAST32 "li"
#define SCNdFAST32 "ld"
#define SCNiFAST32 "li"
#define SCNd64 "I64d"
#define SCNi64 "I64i"
#define SCNdLEAST64 "I64d"
#define SCNiLEAST64 "I64i"
#define SCNdFAST64 "I64d"
#define SCNiFAST64 "I64i"
#define SCNd64 "I64d"
#define SCNi64 "I64i"
#define SCNdLEAST64 "I64d"
#define SCNiLEAST64 "I64i"
#define SCNdFAST64 "I64d"
#define SCNiFAST64 "I64i"
#define SCNdMAX "I64d"
#define SCNiMAX "I64i"
#define SCNdMAX "I64d"
#define SCNiMAX "I64i"
#ifdef _WIN64 // [
# define SCNdPTR "I64d"
# define SCNiPTR "I64i"
#else // _WIN64 ][
# define SCNdPTR "ld"
# define SCNiPTR "li"
#endif // _WIN64 ]
#define SCNdPTR "I64d"
#define SCNiPTR "I64i"
#else // _WIN64 ][
#define SCNdPTR "ld"
#define SCNiPTR "li"
#endif // _WIN64 ]
// The fscanf macros for unsigned integers are:
#define SCNo8 "o"
#define SCNu8 "u"
#define SCNx8 "x"
#define SCNX8 "X"
#define SCNoLEAST8 "o"
#define SCNuLEAST8 "u"
#define SCNxLEAST8 "x"
#define SCNXLEAST8 "X"
#define SCNoFAST8 "o"
#define SCNuFAST8 "u"
#define SCNxFAST8 "x"
#define SCNXFAST8 "X"
#define SCNo8 "o"
#define SCNu8 "u"
#define SCNx8 "x"
#define SCNX8 "X"
#define SCNoLEAST8 "o"
#define SCNuLEAST8 "u"
#define SCNxLEAST8 "x"
#define SCNXLEAST8 "X"
#define SCNoFAST8 "o"
#define SCNuFAST8 "u"
#define SCNxFAST8 "x"
#define SCNXFAST8 "X"
#define SCNo16 "ho"
#define SCNu16 "hu"
#define SCNx16 "hx"
#define SCNX16 "hX"
#define SCNoLEAST16 "ho"
#define SCNuLEAST16 "hu"
#define SCNxLEAST16 "hx"
#define SCNXLEAST16 "hX"
#define SCNoFAST16 "ho"
#define SCNuFAST16 "hu"
#define SCNxFAST16 "hx"
#define SCNXFAST16 "hX"
#define SCNo16 "ho"
#define SCNu16 "hu"
#define SCNx16 "hx"
#define SCNX16 "hX"
#define SCNoLEAST16 "ho"
#define SCNuLEAST16 "hu"
#define SCNxLEAST16 "hx"
#define SCNXLEAST16 "hX"
#define SCNoFAST16 "ho"
#define SCNuFAST16 "hu"
#define SCNxFAST16 "hx"
#define SCNXFAST16 "hX"
#define SCNo32 "lo"
#define SCNu32 "lu"
#define SCNx32 "lx"
#define SCNX32 "lX"
#define SCNoLEAST32 "lo"
#define SCNuLEAST32 "lu"
#define SCNxLEAST32 "lx"
#define SCNXLEAST32 "lX"
#define SCNoFAST32 "lo"
#define SCNuFAST32 "lu"
#define SCNxFAST32 "lx"
#define SCNXFAST32 "lX"
#define SCNo32 "lo"
#define SCNu32 "lu"
#define SCNx32 "lx"
#define SCNX32 "lX"
#define SCNoLEAST32 "lo"
#define SCNuLEAST32 "lu"
#define SCNxLEAST32 "lx"
#define SCNXLEAST32 "lX"
#define SCNoFAST32 "lo"
#define SCNuFAST32 "lu"
#define SCNxFAST32 "lx"
#define SCNXFAST32 "lX"
#define SCNo64 "I64o"
#define SCNu64 "I64u"
#define SCNx64 "I64x"
#define SCNX64 "I64X"
#define SCNoLEAST64 "I64o"
#define SCNuLEAST64 "I64u"
#define SCNxLEAST64 "I64x"
#define SCNXLEAST64 "I64X"
#define SCNoFAST64 "I64o"
#define SCNuFAST64 "I64u"
#define SCNxFAST64 "I64x"
#define SCNXFAST64 "I64X"
#define SCNo64 "I64o"
#define SCNu64 "I64u"
#define SCNx64 "I64x"
#define SCNX64 "I64X"
#define SCNoLEAST64 "I64o"
#define SCNuLEAST64 "I64u"
#define SCNxLEAST64 "I64x"
#define SCNXLEAST64 "I64X"
#define SCNoFAST64 "I64o"
#define SCNuFAST64 "I64u"
#define SCNxFAST64 "I64x"
#define SCNXFAST64 "I64X"
#define SCNoMAX "I64o"
#define SCNuMAX "I64u"
#define SCNxMAX "I64x"
#define SCNXMAX "I64X"
#define SCNoMAX "I64o"
#define SCNuMAX "I64u"
#define SCNxMAX "I64x"
#define SCNXMAX "I64X"
#ifdef _WIN64 // [
# define SCNoPTR "I64o"
# define SCNuPTR "I64u"
# define SCNxPTR "I64x"
# define SCNXPTR "I64X"
#else // _WIN64 ][
# define SCNoPTR "lo"
# define SCNuPTR "lu"
# define SCNxPTR "lx"
# define SCNXPTR "lX"
#endif // _WIN64 ]
#define SCNoPTR "I64o"
#define SCNuPTR "I64u"
#define SCNxPTR "I64x"
#define SCNXPTR "I64X"
#else // _WIN64 ][
#define SCNoPTR "lo"
#define SCNuPTR "lu"
#define SCNxPTR "lx"
#define SCNXPTR "lX"
#endif // _WIN64 ]
#endif // __STDC_FORMAT_MACROS ]
@@ -284,23 +285,24 @@ typedef struct {
// in %MSVC.NET%\crt\src\div.c
#ifdef STATIC_IMAXDIV // [
static
#else // STATIC_IMAXDIV ][
#else // STATIC_IMAXDIV ][
_inline
#endif // STATIC_IMAXDIV ]
imaxdiv_t __cdecl imaxdiv(intmax_t numer, intmax_t denom)
#endif // STATIC_IMAXDIV ]
imaxdiv_t __cdecl imaxdiv(intmax_t numer, intmax_t denom)
{
imaxdiv_t result;
imaxdiv_t result;
result.quot = numer / denom;
result.rem = numer % denom;
result.quot = numer / denom;
result.rem = numer % denom;
if (numer < 0 && result.rem > 0) {
// did division wrong; must fix up
++result.quot;
result.rem -= denom;
}
if(numer < 0 && result.rem > 0)
{
// did division wrong; must fix up
++result.quot;
result.rem -= denom;
}
return result;
return result;
}
// 7.8.2.3 The strtoimax and strtoumax functions

View File

@@ -1,37 +1,37 @@
// ISO C9x compliant stdint.h for Microsoft Visual Studio
// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124
//
// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124
//
// Copyright (c) 2006-2013 Alexander Chemeris
//
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
//
// 3. Neither the name of the product nor the names of its contributors may
// be used to endorse or promote products derived from this software
// without specific prior written permission.
//
//
// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
// EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
// ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
//
///////////////////////////////////////////////////////////////////////////////
// The above software in this distribution may have been modified by
// THL A29 Limited ("Tencent Modifications").
// The above software in this distribution may have been modified by
// THL A29 Limited ("Tencent Modifications").
// All Tencent Modifications are Copyright (C) 2015 THL A29 Limited.
#ifndef _MSC_VER // [
@@ -45,7 +45,8 @@
#pragma once
#endif
// miloyip: Originally Visual Studio 2010 uses its own stdint.h. However it generates warning with INT64_C(), so change to use this file for vs2010.
// miloyip: Originally Visual Studio 2010 uses its own stdint.h. However it generates warning with
// INT64_C(), so change to use this file for vs2010.
#if _MSC_VER >= 1600 // [
#include <stdint.h>
@@ -62,12 +63,12 @@
// 7.18.4.1 Macros for minimum-width integer constants
#define INT8_C(val) val##i8
#define INT8_C(val) val##i8
#define INT16_C(val) val##i16
#define INT32_C(val) val##i32
#define INT64_C(val) val##i64
#define UINT8_C(val) val##ui8
#define UINT8_C(val) val##ui8
#define UINT16_C(val) val##ui16
#define UINT32_C(val) val##ui32
#define UINT64_C(val) val##ui64
@@ -76,10 +77,10 @@
// These #ifndef's are needed to prevent collisions with <boost/cstdint.hpp>.
// Check out Issue 9 for the details.
#ifndef INTMAX_C // [
# define INTMAX_C INT64_C
#endif // INTMAX_C ]
#define INTMAX_C INT64_C
#endif // INTMAX_C ]
#ifndef UINTMAX_C // [
# define UINTMAX_C UINT64_C
#define UINTMAX_C UINT64_C
#endif // UINTMAX_C ]
#endif // __STDC_CONSTANT_MACROS ]
@@ -95,20 +96,19 @@
#if defined(__cplusplus) && !defined(_M_ARM)
extern "C" {
#endif
# include <wchar.h>
#include <wchar.h>
#if defined(__cplusplus) && !defined(_M_ARM)
}
#endif
// Define _W64 macros to mark types changing their size, like intptr_t.
#ifndef _W64
# if !defined(__midl) && (defined(_X86_) || defined(_M_IX86)) && _MSC_VER >= 1300
# define _W64 __w64
# else
# define _W64
# endif
#if !defined(__midl) && (defined(_X86_) || defined(_M_IX86)) && _MSC_VER >= 1300
#define _W64 __w64
#else
#define _W64
#endif
#endif
// 7.18.1 Integer types
@@ -117,168 +117,166 @@ extern "C" {
// Visual Studio 6 and Embedded Visual C++ 4 doesn't
// realize that, e.g. char has the same size as __int8
// so we give up on __intX for them.
#if (_MSC_VER < 1300)
typedef signed char int8_t;
typedef signed short int16_t;
typedef signed int int32_t;
typedef unsigned char uint8_t;
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
#if(_MSC_VER < 1300)
typedef signed char int8_t;
typedef signed short int16_t;
typedef signed int int32_t;
typedef unsigned char uint8_t;
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
#else
typedef signed __int8 int8_t;
typedef signed __int16 int16_t;
typedef signed __int32 int32_t;
typedef unsigned __int8 uint8_t;
typedef unsigned __int16 uint16_t;
typedef unsigned __int32 uint32_t;
typedef signed __int8 int8_t;
typedef signed __int16 int16_t;
typedef signed __int32 int32_t;
typedef unsigned __int8 uint8_t;
typedef unsigned __int16 uint16_t;
typedef unsigned __int32 uint32_t;
#endif
typedef signed __int64 int64_t;
typedef unsigned __int64 uint64_t;
typedef signed __int64 int64_t;
typedef unsigned __int64 uint64_t;
// 7.18.1.2 Minimum-width integer types
typedef int8_t int_least8_t;
typedef int16_t int_least16_t;
typedef int32_t int_least32_t;
typedef int64_t int_least64_t;
typedef uint8_t uint_least8_t;
typedef uint16_t uint_least16_t;
typedef uint32_t uint_least32_t;
typedef uint64_t uint_least64_t;
typedef int8_t int_least8_t;
typedef int16_t int_least16_t;
typedef int32_t int_least32_t;
typedef int64_t int_least64_t;
typedef uint8_t uint_least8_t;
typedef uint16_t uint_least16_t;
typedef uint32_t uint_least32_t;
typedef uint64_t uint_least64_t;
// 7.18.1.3 Fastest minimum-width integer types
typedef int8_t int_fast8_t;
typedef int16_t int_fast16_t;
typedef int32_t int_fast32_t;
typedef int64_t int_fast64_t;
typedef uint8_t uint_fast8_t;
typedef uint16_t uint_fast16_t;
typedef uint32_t uint_fast32_t;
typedef uint64_t uint_fast64_t;
typedef int8_t int_fast8_t;
typedef int16_t int_fast16_t;
typedef int32_t int_fast32_t;
typedef int64_t int_fast64_t;
typedef uint8_t uint_fast8_t;
typedef uint16_t uint_fast16_t;
typedef uint32_t uint_fast32_t;
typedef uint64_t uint_fast64_t;
// 7.18.1.4 Integer types capable of holding object pointers
#ifdef _WIN64 // [
typedef signed __int64 intptr_t;
typedef unsigned __int64 uintptr_t;
#else // _WIN64 ][
typedef _W64 signed int intptr_t;
typedef _W64 unsigned int uintptr_t;
#endif // _WIN64 ]
typedef signed __int64 intptr_t;
typedef unsigned __int64 uintptr_t;
#else // _WIN64 ][
typedef _W64 signed int intptr_t;
typedef _W64 unsigned int uintptr_t;
#endif // _WIN64 ]
// 7.18.1.5 Greatest-width integer types
typedef int64_t intmax_t;
typedef uint64_t uintmax_t;
typedef int64_t intmax_t;
typedef uint64_t uintmax_t;
// 7.18.2 Limits of specified-width integer types
#if !defined(__cplusplus) || defined(__STDC_LIMIT_MACROS) // [ See footnote 220 at page 257 and footnote 221 at page 259
#if !defined(__cplusplus) || \
defined(__STDC_LIMIT_MACROS) // [ See footnote 220 at page 257 and footnote 221 at page 259
// 7.18.2.1 Limits of exact-width integer types
#define INT8_MIN ((int8_t)_I8_MIN)
#define INT8_MAX _I8_MAX
#define INT16_MIN ((int16_t)_I16_MIN)
#define INT16_MAX _I16_MAX
#define INT32_MIN ((int32_t)_I32_MIN)
#define INT32_MAX _I32_MAX
#define INT64_MIN ((int64_t)_I64_MIN)
#define INT64_MAX _I64_MAX
#define UINT8_MAX _UI8_MAX
#define UINT16_MAX _UI16_MAX
#define UINT32_MAX _UI32_MAX
#define UINT64_MAX _UI64_MAX
#define INT8_MIN ((int8_t)_I8_MIN)
#define INT8_MAX _I8_MAX
#define INT16_MIN ((int16_t)_I16_MIN)
#define INT16_MAX _I16_MAX
#define INT32_MIN ((int32_t)_I32_MIN)
#define INT32_MAX _I32_MAX
#define INT64_MIN ((int64_t)_I64_MIN)
#define INT64_MAX _I64_MAX
#define UINT8_MAX _UI8_MAX
#define UINT16_MAX _UI16_MAX
#define UINT32_MAX _UI32_MAX
#define UINT64_MAX _UI64_MAX
// 7.18.2.2 Limits of minimum-width integer types
#define INT_LEAST8_MIN INT8_MIN
#define INT_LEAST8_MAX INT8_MAX
#define INT_LEAST16_MIN INT16_MIN
#define INT_LEAST16_MAX INT16_MAX
#define INT_LEAST32_MIN INT32_MIN
#define INT_LEAST32_MAX INT32_MAX
#define INT_LEAST64_MIN INT64_MIN
#define INT_LEAST64_MAX INT64_MAX
#define UINT_LEAST8_MAX UINT8_MAX
#define UINT_LEAST16_MAX UINT16_MAX
#define UINT_LEAST32_MAX UINT32_MAX
#define UINT_LEAST64_MAX UINT64_MAX
#define INT_LEAST8_MIN INT8_MIN
#define INT_LEAST8_MAX INT8_MAX
#define INT_LEAST16_MIN INT16_MIN
#define INT_LEAST16_MAX INT16_MAX
#define INT_LEAST32_MIN INT32_MIN
#define INT_LEAST32_MAX INT32_MAX
#define INT_LEAST64_MIN INT64_MIN
#define INT_LEAST64_MAX INT64_MAX
#define UINT_LEAST8_MAX UINT8_MAX
#define UINT_LEAST16_MAX UINT16_MAX
#define UINT_LEAST32_MAX UINT32_MAX
#define UINT_LEAST64_MAX UINT64_MAX
// 7.18.2.3 Limits of fastest minimum-width integer types
#define INT_FAST8_MIN INT8_MIN
#define INT_FAST8_MAX INT8_MAX
#define INT_FAST16_MIN INT16_MIN
#define INT_FAST16_MAX INT16_MAX
#define INT_FAST32_MIN INT32_MIN
#define INT_FAST32_MAX INT32_MAX
#define INT_FAST64_MIN INT64_MIN
#define INT_FAST64_MAX INT64_MAX
#define UINT_FAST8_MAX UINT8_MAX
#define UINT_FAST16_MAX UINT16_MAX
#define UINT_FAST32_MAX UINT32_MAX
#define UINT_FAST64_MAX UINT64_MAX
#define INT_FAST8_MIN INT8_MIN
#define INT_FAST8_MAX INT8_MAX
#define INT_FAST16_MIN INT16_MIN
#define INT_FAST16_MAX INT16_MAX
#define INT_FAST32_MIN INT32_MIN
#define INT_FAST32_MAX INT32_MAX
#define INT_FAST64_MIN INT64_MIN
#define INT_FAST64_MAX INT64_MAX
#define UINT_FAST8_MAX UINT8_MAX
#define UINT_FAST16_MAX UINT16_MAX
#define UINT_FAST32_MAX UINT32_MAX
#define UINT_FAST64_MAX UINT64_MAX
// 7.18.2.4 Limits of integer types capable of holding object pointers
#ifdef _WIN64 // [
# define INTPTR_MIN INT64_MIN
# define INTPTR_MAX INT64_MAX
# define UINTPTR_MAX UINT64_MAX
#define INTPTR_MIN INT64_MIN
#define INTPTR_MAX INT64_MAX
#define UINTPTR_MAX UINT64_MAX
#else // _WIN64 ][
# define INTPTR_MIN INT32_MIN
# define INTPTR_MAX INT32_MAX
# define UINTPTR_MAX UINT32_MAX
#define INTPTR_MIN INT32_MIN
#define INTPTR_MAX INT32_MAX
#define UINTPTR_MAX UINT32_MAX
#endif // _WIN64 ]
// 7.18.2.5 Limits of greatest-width integer types
#define INTMAX_MIN INT64_MIN
#define INTMAX_MAX INT64_MAX
#define UINTMAX_MAX UINT64_MAX
#define INTMAX_MIN INT64_MIN
#define INTMAX_MAX INT64_MAX
#define UINTMAX_MAX UINT64_MAX
// 7.18.3 Limits of other integer types
#ifdef _WIN64 // [
# define PTRDIFF_MIN _I64_MIN
# define PTRDIFF_MAX _I64_MAX
#else // _WIN64 ][
# define PTRDIFF_MIN _I32_MIN
# define PTRDIFF_MAX _I32_MAX
#endif // _WIN64 ]
#define PTRDIFF_MIN _I64_MIN
#define PTRDIFF_MAX _I64_MAX
#else // _WIN64 ][
#define PTRDIFF_MIN _I32_MIN
#define PTRDIFF_MAX _I32_MAX
#endif // _WIN64 ]
#define SIG_ATOMIC_MIN INT_MIN
#define SIG_ATOMIC_MAX INT_MAX
#define SIG_ATOMIC_MIN INT_MIN
#define SIG_ATOMIC_MAX INT_MAX
#ifndef SIZE_MAX // [
# ifdef _WIN64 // [
# define SIZE_MAX _UI64_MAX
# else // _WIN64 ][
# define SIZE_MAX _UI32_MAX
# endif // _WIN64 ]
#endif // SIZE_MAX ]
#ifdef _WIN64 // [
#define SIZE_MAX _UI64_MAX
#else // _WIN64 ][
#define SIZE_MAX _UI32_MAX
#endif // _WIN64 ]
#endif // SIZE_MAX ]
// WCHAR_MIN and WCHAR_MAX are also defined in <wchar.h>
#ifndef WCHAR_MIN // [
# define WCHAR_MIN 0
#endif // WCHAR_MIN ]
#define WCHAR_MIN 0
#endif // WCHAR_MIN ]
#ifndef WCHAR_MAX // [
# define WCHAR_MAX _UI16_MAX
#endif // WCHAR_MAX ]
#define WCHAR_MAX _UI16_MAX
#endif // WCHAR_MAX ]
#define WINT_MIN 0
#define WINT_MAX _UI16_MAX
#define WINT_MIN 0
#define WINT_MAX _UI16_MAX
#endif // __STDC_LIMIT_MACROS ]
// 7.18.4 Limits of other integer types
#if !defined(__cplusplus) || defined(__STDC_CONSTANT_MACROS) // [ See footnote 224 at page 260
// 7.18.4.1 Macros for minimum-width integer constants
#define INT8_C(val) val##i8
#define INT8_C(val) val##i8
#define INT16_C(val) val##i16
#define INT32_C(val) val##i32
#define INT64_C(val) val##i64
#define UINT8_C(val) val##ui8
#define UINT8_C(val) val##ui8
#define UINT16_C(val) val##ui16
#define UINT32_C(val) val##ui32
#define UINT64_C(val) val##ui64
@@ -287,10 +285,10 @@ typedef uint64_t uintmax_t;
// These #ifndef's are needed to prevent collisions with <boost/cstdint.hpp>.
// Check out Issue 9 for the details.
#ifndef INTMAX_C // [
# define INTMAX_C INT64_C
#endif // INTMAX_C ]
#define INTMAX_C INT64_C
#endif // INTMAX_C ]
#ifndef UINTMAX_C // [
# define UINTMAX_C UINT64_C
#define UINTMAX_C UINT64_C
#endif // UINTMAX_C ]
#endif // __STDC_CONSTANT_MACROS ]

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_OSTREAMWRAPPER_H_
@@ -40,29 +40,46 @@ RAPIDJSON_NAMESPACE_BEGIN
\tparam StreamType Class derived from \c std::basic_ostream.
*/
template <typename StreamType>
class BasicOStreamWrapper {
public:
class BasicOStreamWrapper
{
public:
typedef typename StreamType::char_type Ch;
BasicOStreamWrapper(StreamType& stream) : stream_(stream) {}
void Put(Ch c) {
stream_.put(c);
}
void Put(Ch c) { stream_.put(c); }
void Flush() {
stream_.flush();
}
void Flush() { stream_.flush(); }
// Not implemented
char Peek() const { RAPIDJSON_ASSERT(false); return 0; }
char Take() { RAPIDJSON_ASSERT(false); return 0; }
size_t Tell() const { RAPIDJSON_ASSERT(false); return 0; }
char* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
size_t PutEnd(char*) { RAPIDJSON_ASSERT(false); return 0; }
char Peek() const
{
RAPIDJSON_ASSERT(false);
return 0;
}
char Take()
{
RAPIDJSON_ASSERT(false);
return 0;
}
size_t Tell() const
{
RAPIDJSON_ASSERT(false);
return 0;
}
char* PutBegin()
{
RAPIDJSON_ASSERT(false);
return 0;
}
size_t PutEnd(char*)
{
RAPIDJSON_ASSERT(false);
return 0;
}
private:
private:
BasicOStreamWrapper(const BasicOStreamWrapper&);
BasicOStreamWrapper& operator=(const BasicOStreamWrapper&);

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_PRETTYWRITER_H_
@@ -24,7 +24,7 @@ RAPIDJSON_DIAG_OFF(effc++)
#if defined(__clang__)
RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(c++98-compat)
RAPIDJSON_DIAG_OFF(c++ 98 - compat)
#endif
RAPIDJSON_NAMESPACE_BEGIN
@@ -32,8 +32,9 @@ RAPIDJSON_NAMESPACE_BEGIN
//! Combination of PrettyWriter format flags.
/*! \see PrettyWriter::SetFormatOptions
*/
enum PrettyFormatOptions {
kFormatDefault = 0, //!< Default pretty formatting.
enum PrettyFormatOptions
{
kFormatDefault = 0, //!< Default pretty formatting.
kFormatSingleLineArray = 1 //!< Format arrays on a single line.
};
@@ -44,9 +45,15 @@ enum PrettyFormatOptions {
\tparam TargetEncoding Encoding of output stream.
\tparam StackAllocator Type of allocator for allocating memory of stack.
*/
template<typename OutputStream, typename SourceEncoding = UTF8<>, typename TargetEncoding = UTF8<>, typename StackAllocator = CrtAllocator, unsigned writeFlags = kWriteDefaultFlags>
class PrettyWriter : public Writer<OutputStream, SourceEncoding, TargetEncoding, StackAllocator, writeFlags> {
public:
template <typename OutputStream,
typename SourceEncoding = UTF8<>,
typename TargetEncoding = UTF8<>,
typename StackAllocator = CrtAllocator,
unsigned writeFlags = kWriteDefaultFlags>
class PrettyWriter
: public Writer<OutputStream, SourceEncoding, TargetEncoding, StackAllocator, writeFlags>
{
public:
typedef Writer<OutputStream, SourceEncoding, TargetEncoding, StackAllocator, writeFlags> Base;
typedef typename Base::Ch Ch;
@@ -55,34 +62,54 @@ public:
\param allocator User supplied allocator. If it is null, it will create a private one.
\param levelDepth Initial capacity of stack.
*/
explicit PrettyWriter(OutputStream& os, StackAllocator* allocator = 0, size_t levelDepth = Base::kDefaultLevelDepth) :
Base(os, allocator, levelDepth), indentChar_(' '), indentCharCount_(4), formatOptions_(kFormatDefault) {}
explicit PrettyWriter(OutputStream& os,
StackAllocator* allocator = 0,
size_t levelDepth = Base::kDefaultLevelDepth)
: Base(os, allocator, levelDepth),
indentChar_(' '),
indentCharCount_(4),
formatOptions_(kFormatDefault)
{
}
explicit PrettyWriter(StackAllocator* allocator = 0, size_t levelDepth = Base::kDefaultLevelDepth) :
Base(allocator, levelDepth), indentChar_(' '), indentCharCount_(4), formatOptions_(kFormatDefault) {}
explicit PrettyWriter(StackAllocator* allocator = 0,
size_t levelDepth = Base::kDefaultLevelDepth)
: Base(allocator, levelDepth),
indentChar_(' '),
indentCharCount_(4),
formatOptions_(kFormatDefault)
{
}
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
PrettyWriter(PrettyWriter&& rhs) :
Base(std::forward<PrettyWriter>(rhs)), indentChar_(rhs.indentChar_), indentCharCount_(rhs.indentCharCount_), formatOptions_(rhs.formatOptions_) {}
PrettyWriter(PrettyWriter&& rhs)
: Base(std::forward<PrettyWriter>(rhs)),
indentChar_(rhs.indentChar_),
indentCharCount_(rhs.indentCharCount_),
formatOptions_(rhs.formatOptions_)
{
}
#endif
//! Set custom indentation.
/*! \param indentChar Character for indentation. Must be whitespace character (' ', '\\t', '\\n', '\\r').
\param indentCharCount Number of indent characters for each indentation level.
\note The default indentation is 4 spaces.
/*! \param indentChar Character for indentation. Must be whitespace character (' ', '\\t',
'\\n', '\\r'). \param indentCharCount Number of indent characters for each indentation
level. \note The default indentation is 4 spaces.
*/
PrettyWriter& SetIndent(Ch indentChar, unsigned indentCharCount) {
RAPIDJSON_ASSERT(indentChar == ' ' || indentChar == '\t' || indentChar == '\n' || indentChar == '\r');
indentChar_ = indentChar;
PrettyWriter& SetIndent(Ch indentChar, unsigned indentCharCount)
{
RAPIDJSON_ASSERT(indentChar == ' ' || indentChar == '\t' || indentChar == '\n' ||
indentChar == '\r');
indentChar_ = indentChar;
indentCharCount_ = indentCharCount;
return *this;
}
//! Set pretty writer formatting options.
/*! \param options Formatting options.
*/
PrettyWriter& SetFormatOptions(PrettyFormatOptions options) {
*/
PrettyWriter& SetFormatOptions(PrettyFormatOptions options)
{
formatOptions_ = options;
return *this;
}
@@ -92,22 +119,52 @@ public:
*/
//@{
bool Null() { PrettyPrefix(kNullType); return Base::EndValue(Base::WriteNull()); }
bool Bool(bool b) { PrettyPrefix(b ? kTrueType : kFalseType); return Base::EndValue(Base::WriteBool(b)); }
bool Int(int i) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteInt(i)); }
bool Uint(unsigned u) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteUint(u)); }
bool Int64(int64_t i64) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteInt64(i64)); }
bool Uint64(uint64_t u64) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteUint64(u64)); }
bool Double(double d) { PrettyPrefix(kNumberType); return Base::EndValue(Base::WriteDouble(d)); }
bool Null()
{
PrettyPrefix(kNullType);
return Base::EndValue(Base::WriteNull());
}
bool Bool(bool b)
{
PrettyPrefix(b ? kTrueType : kFalseType);
return Base::EndValue(Base::WriteBool(b));
}
bool Int(int i)
{
PrettyPrefix(kNumberType);
return Base::EndValue(Base::WriteInt(i));
}
bool Uint(unsigned u)
{
PrettyPrefix(kNumberType);
return Base::EndValue(Base::WriteUint(u));
}
bool Int64(int64_t i64)
{
PrettyPrefix(kNumberType);
return Base::EndValue(Base::WriteInt64(i64));
}
bool Uint64(uint64_t u64)
{
PrettyPrefix(kNumberType);
return Base::EndValue(Base::WriteUint64(u64));
}
bool Double(double d)
{
PrettyPrefix(kNumberType);
return Base::EndValue(Base::WriteDouble(d));
}
bool RawNumber(const Ch* str, SizeType length, bool copy = false) {
bool RawNumber(const Ch* str, SizeType length, bool copy = false)
{
RAPIDJSON_ASSERT(str != 0);
(void)copy;
PrettyPrefix(kNumberType);
return Base::EndValue(Base::WriteString(str, length));
}
bool String(const Ch* str, SizeType length, bool copy = false) {
bool String(const Ch* str, SizeType length, bool copy = false)
{
RAPIDJSON_ASSERT(str != 0);
(void)copy;
PrettyPrefix(kStringType);
@@ -115,65 +172,76 @@ public:
}
#if RAPIDJSON_HAS_STDSTRING
bool String(const std::basic_string<Ch>& str) {
bool String(const std::basic_string<Ch>& str)
{
return String(str.data(), SizeType(str.size()));
}
#endif
bool StartObject() {
bool StartObject()
{
PrettyPrefix(kObjectType);
new (Base::level_stack_.template Push<typename Base::Level>()) typename Base::Level(false);
new(Base::level_stack_.template Push<typename Base::Level>()) typename Base::Level(false);
return Base::WriteStartObject();
}
bool Key(const Ch* str, SizeType length, bool copy = false) { return String(str, length, copy); }
bool Key(const Ch* str, SizeType length, bool copy = false)
{
return String(str, length, copy);
}
#if RAPIDJSON_HAS_STDSTRING
bool Key(const std::basic_string<Ch>& str) {
return Key(str.data(), SizeType(str.size()));
}
bool Key(const std::basic_string<Ch>& str) { return Key(str.data(), SizeType(str.size())); }
#endif
bool EndObject(SizeType memberCount = 0) {
bool EndObject(SizeType memberCount = 0)
{
(void)memberCount;
RAPIDJSON_ASSERT(Base::level_stack_.GetSize() >= sizeof(typename Base::Level)); // not inside an Object
RAPIDJSON_ASSERT(!Base::level_stack_.template Top<typename Base::Level>()->inArray); // currently inside an Array, not Object
RAPIDJSON_ASSERT(0 == Base::level_stack_.template Top<typename Base::Level>()->valueCount % 2); // Object has a Key without a Value
RAPIDJSON_ASSERT(Base::level_stack_.GetSize() >=
sizeof(typename Base::Level)); // not inside an Object
RAPIDJSON_ASSERT(!Base::level_stack_.template Top<typename Base::Level>()
->inArray); // currently inside an Array, not Object
RAPIDJSON_ASSERT(0 == Base::level_stack_.template Top<typename Base::Level>()->valueCount %
2); // Object has a Key without a Value
bool empty = Base::level_stack_.template Pop<typename Base::Level>(1)->valueCount == 0;
if (!empty) {
if(!empty)
{
Base::os_->Put('\n');
WriteIndent();
}
bool ret = Base::EndValue(Base::WriteEndObject());
(void)ret;
RAPIDJSON_ASSERT(ret == true);
if (Base::level_stack_.Empty()) // end of json text
if(Base::level_stack_.Empty()) // end of json text
Base::Flush();
return true;
}
bool StartArray() {
bool StartArray()
{
PrettyPrefix(kArrayType);
new (Base::level_stack_.template Push<typename Base::Level>()) typename Base::Level(true);
new(Base::level_stack_.template Push<typename Base::Level>()) typename Base::Level(true);
return Base::WriteStartArray();
}
bool EndArray(SizeType memberCount = 0) {
bool EndArray(SizeType memberCount = 0)
{
(void)memberCount;
RAPIDJSON_ASSERT(Base::level_stack_.GetSize() >= sizeof(typename Base::Level));
RAPIDJSON_ASSERT(Base::level_stack_.template Top<typename Base::Level>()->inArray);
bool empty = Base::level_stack_.template Pop<typename Base::Level>(1)->valueCount == 0;
if (!empty && !(formatOptions_ & kFormatSingleLineArray)) {
if(!empty && !(formatOptions_ & kFormatSingleLineArray))
{
Base::os_->Put('\n');
WriteIndent();
}
bool ret = Base::EndValue(Base::WriteEndArray());
(void)ret;
RAPIDJSON_ASSERT(ret == true);
if (Base::level_stack_.Empty()) // end of json text
if(Base::level_stack_.Empty()) // end of json text
Base::Flush();
return true;
}
@@ -193,42 +261,51 @@ public:
/*!
For user to write a stringified JSON as a value.
\param json A well-formed JSON value. It should not contain null character within [0, length - 1] range.
\param length Length of the json.
\param type Type of the root of json.
\note When using PrettyWriter::RawValue(), the result json may not be indented correctly.
\param json A well-formed JSON value. It should not contain null character within [0, length
- 1] range. \param length Length of the json. \param type Type of the root of json. \note
When using PrettyWriter::RawValue(), the result json may not be indented correctly.
*/
bool RawValue(const Ch* json, size_t length, Type type) {
bool RawValue(const Ch* json, size_t length, Type type)
{
RAPIDJSON_ASSERT(json != 0);
PrettyPrefix(type);
return Base::EndValue(Base::WriteRawValue(json, length));
}
protected:
void PrettyPrefix(Type type) {
protected:
void PrettyPrefix(Type type)
{
(void)type;
if (Base::level_stack_.GetSize() != 0) { // this value is not at root
if(Base::level_stack_.GetSize() != 0)
{ // this value is not at root
typename Base::Level* level = Base::level_stack_.template Top<typename Base::Level>();
if (level->inArray) {
if (level->valueCount > 0) {
if(level->inArray)
{
if(level->valueCount > 0)
{
Base::os_->Put(','); // add comma if it is not the first element in array
if (formatOptions_ & kFormatSingleLineArray)
if(formatOptions_ & kFormatSingleLineArray)
Base::os_->Put(' ');
}
if (!(formatOptions_ & kFormatSingleLineArray)) {
if(!(formatOptions_ & kFormatSingleLineArray))
{
Base::os_->Put('\n');
WriteIndent();
}
}
else { // in object
if (level->valueCount > 0) {
if (level->valueCount % 2 == 0) {
else
{ // in object
if(level->valueCount > 0)
{
if(level->valueCount % 2 == 0)
{
Base::os_->Put(',');
Base::os_->Put('\n');
}
else {
else
{
Base::os_->Put(':');
Base::os_->Put(' ');
}
@@ -236,21 +313,25 @@ protected:
else
Base::os_->Put('\n');
if (level->valueCount % 2 == 0)
if(level->valueCount % 2 == 0)
WriteIndent();
}
if (!level->inArray && level->valueCount % 2 == 0)
RAPIDJSON_ASSERT(type == kStringType); // if it's in object, then even number should be a name
if(!level->inArray && level->valueCount % 2 == 0)
RAPIDJSON_ASSERT(
type == kStringType); // if it's in object, then even number should be a name
level->valueCount++;
}
else {
RAPIDJSON_ASSERT(!Base::hasRoot_); // Should only has one and only one root.
else
{
RAPIDJSON_ASSERT(!Base::hasRoot_); // Should only has one and only one root.
Base::hasRoot_ = true;
}
}
void WriteIndent() {
size_t count = (Base::level_stack_.GetSize() / sizeof(typename Base::Level)) * indentCharCount_;
void WriteIndent()
{
size_t count =
(Base::level_stack_.GetSize() / sizeof(typename Base::Level)) * indentCharCount_;
PutN(*Base::os_, static_cast<typename OutputStream::Ch>(indentChar_), count);
}
@@ -258,7 +339,7 @@ protected:
unsigned indentCharCount_;
PrettyFormatOptions formatOptions_;
private:
private:
// Prohibit copy constructor & assignment operator.
PrettyWriter(const PrettyWriter&);
PrettyWriter& operator=(const PrettyWriter&);

View File

@@ -36,8 +36,8 @@
different translation units of a single application.
*/
#include <cstdlib> // malloc(), realloc(), free(), size_t
#include <cstring> // memset(), memcpy(), memmove(), memcmp()
#include <cstdlib> // malloc(), realloc(), free(), size_t
#include <cstring> // memset(), memcpy(), memmove(), memcmp()
///////////////////////////////////////////////////////////////////////////////
// RAPIDJSON_VERSION_STRING
@@ -226,8 +226,8 @@
///////////////////////////////////////////////////////////////////////////////
// RAPIDJSON_ENDIAN
#define RAPIDJSON_LITTLEENDIAN 0 //!< Little endian machine
#define RAPIDJSON_BIGENDIAN 1 //!< Big endian machine
#define RAPIDJSON_LITTLEENDIAN 0 //!< Little endian machine
#define RAPIDJSON_BIGENDIAN 1 //!< Big endian machine
//! Endianness of the machine.
/*!
@@ -244,41 +244,46 @@
*/
#ifndef RAPIDJSON_ENDIAN
// Detect with GCC 4.6's macro
# ifdef __BYTE_ORDER__
# if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
# elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
# define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
# else
# error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN.
# endif // __BYTE_ORDER__
#ifdef __BYTE_ORDER__
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
#define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
#else
#error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN.
#endif // __BYTE_ORDER__
// Detect with GLIBC's endian.h
# elif defined(__GLIBC__)
# include <endian.h>
# if (__BYTE_ORDER == __LITTLE_ENDIAN)
# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
# elif (__BYTE_ORDER == __BIG_ENDIAN)
# define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
# else
# error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN.
# endif // __GLIBC__
#elif defined(__GLIBC__)
#include <endian.h>
#if(__BYTE_ORDER == __LITTLE_ENDIAN)
#define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
#elif(__BYTE_ORDER == __BIG_ENDIAN)
#define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
#else
#error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN.
#endif // __GLIBC__
// Detect with _LITTLE_ENDIAN and _BIG_ENDIAN macro
# elif defined(_LITTLE_ENDIAN) && !defined(_BIG_ENDIAN)
# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
# elif defined(_BIG_ENDIAN) && !defined(_LITTLE_ENDIAN)
# define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
#elif defined(_LITTLE_ENDIAN) && !defined(_BIG_ENDIAN)
#define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
#elif defined(_BIG_ENDIAN) && !defined(_LITTLE_ENDIAN)
#define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
// Detect with architecture macros
# elif defined(__sparc) || defined(__sparc__) || defined(_POWER) || defined(__powerpc__) || defined(__ppc__) || defined(__ppc64__) || defined(__hpux) || defined(__hppa) || defined(_MIPSEB) || defined(_POWER) || defined(__s390__)
# define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
# elif defined(__i386__) || defined(__alpha__) || defined(__ia64) || defined(__ia64__) || defined(_M_IX86) || defined(_M_IA64) || defined(_M_ALPHA) || defined(__amd64) || defined(__amd64__) || defined(_M_AMD64) || defined(__x86_64) || defined(__x86_64__) || defined(_M_X64) || defined(__bfin__)
# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
# elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64))
# define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
# elif defined(RAPIDJSON_DOXYGEN_RUNNING)
# define RAPIDJSON_ENDIAN
# else
# error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN.
# endif
#elif defined(__sparc) || defined(__sparc__) || defined(_POWER) || defined(__powerpc__) || \
defined(__ppc__) || defined(__ppc64__) || defined(__hpux) || defined(__hppa) || \
defined(_MIPSEB) || defined(_POWER) || defined(__s390__)
#define RAPIDJSON_ENDIAN RAPIDJSON_BIGENDIAN
#elif defined(__i386__) || defined(__alpha__) || defined(__ia64) || defined(__ia64__) || \
defined(_M_IX86) || defined(_M_IA64) || defined(_M_ALPHA) || defined(__amd64) || \
defined(__amd64__) || defined(_M_AMD64) || defined(__x86_64) || defined(__x86_64__) || \
defined(_M_X64) || defined(__bfin__)
#define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
#elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64))
#define RAPIDJSON_ENDIAN RAPIDJSON_LITTLEENDIAN
#elif defined(RAPIDJSON_DOXYGEN_RUNNING)
#define RAPIDJSON_ENDIAN
#else
#error Unknown machine endianness detected. User needs to define RAPIDJSON_ENDIAN.
#endif
#endif // RAPIDJSON_ENDIAN
///////////////////////////////////////////////////////////////////////////////
@@ -286,7 +291,8 @@
//! Whether using 64-bit architecture
#ifndef RAPIDJSON_64BIT
#if defined(__LP64__) || (defined(__x86_64__) && defined(__ILP32__)) || defined(_WIN64) || defined(__EMSCRIPTEN__)
#if defined(__LP64__) || (defined(__x86_64__) && defined(__ILP32__)) || defined(_WIN64) || \
defined(__EMSCRIPTEN__)
#define RAPIDJSON_64BIT 1
#else
#define RAPIDJSON_64BIT 0
@@ -317,7 +323,8 @@
Use this macro to define 64-bit constants by a pair of 32-bit integer.
*/
#ifndef RAPIDJSON_UINT64_C2
#define RAPIDJSON_UINT64_C2(high32, low32) ((static_cast<uint64_t>(high32) << 32) | static_cast<uint64_t>(low32))
#define RAPIDJSON_UINT64_C2(high32, low32) \
((static_cast<uint64_t>(high32) << 32) | static_cast<uint64_t>(low32))
#endif
///////////////////////////////////////////////////////////////////////////////
@@ -327,12 +334,13 @@
/*!
\ingroup RAPIDJSON_CONFIG
This optimization uses the fact that current X86-64 architecture only implement lower 48-bit virtual address.
The higher 16-bit can be used for storing other data.
\c GenericValue uses this optimization to reduce its size form 24 bytes to 16 bytes in 64-bit architecture.
This optimization uses the fact that current X86-64 architecture only implement lower 48-bit
virtual address. The higher 16-bit can be used for storing other data. \c GenericValue uses this
optimization to reduce its size form 24 bytes to 16 bytes in 64-bit architecture.
*/
#ifndef RAPIDJSON_48BITPOINTER_OPTIMIZATION
#if defined(__amd64__) || defined(__amd64) || defined(__x86_64__) || defined(__x86_64) || defined(_M_X64) || defined(_M_AMD64)
#if defined(__amd64__) || defined(__amd64) || defined(__x86_64__) || defined(__x86_64) || \
defined(_M_X64) || defined(_M_AMD64)
#define RAPIDJSON_48BITPOINTER_OPTIMIZATION 1
#else
#define RAPIDJSON_48BITPOINTER_OPTIMIZATION 0
@@ -343,8 +351,14 @@
#if RAPIDJSON_64BIT != 1
#error RAPIDJSON_48BITPOINTER_OPTIMIZATION can only be set to 1 when RAPIDJSON_64BIT=1
#endif
#define RAPIDJSON_SETPOINTER(type, p, x) (p = reinterpret_cast<type *>((reinterpret_cast<uintptr_t>(p) & static_cast<uintptr_t>(RAPIDJSON_UINT64_C2(0xFFFF0000, 0x00000000))) | reinterpret_cast<uintptr_t>(reinterpret_cast<const void*>(x))))
#define RAPIDJSON_GETPOINTER(type, p) (reinterpret_cast<type *>(reinterpret_cast<uintptr_t>(p) & static_cast<uintptr_t>(RAPIDJSON_UINT64_C2(0x0000FFFF, 0xFFFFFFFF))))
#define RAPIDJSON_SETPOINTER(type, p, x) \
(p = reinterpret_cast<type*>( \
(reinterpret_cast<uintptr_t>(p) & \
static_cast<uintptr_t>(RAPIDJSON_UINT64_C2(0xFFFF0000, 0x00000000))) | \
reinterpret_cast<uintptr_t>(reinterpret_cast<const void*>(x))))
#define RAPIDJSON_GETPOINTER(type, p) \
(reinterpret_cast<type*>(reinterpret_cast<uintptr_t>(p) & \
static_cast<uintptr_t>(RAPIDJSON_UINT64_C2(0x0000FFFF, 0xFFFFFFFF))))
#else
#define RAPIDJSON_SETPOINTER(type, p, x) (p = (x))
#define RAPIDJSON_GETPOINTER(type, p) (p)
@@ -379,8 +393,8 @@
If any of these symbols is defined, RapidJSON defines the macro
\c RAPIDJSON_SIMD to indicate the availability of the optimized code.
*/
#if defined(RAPIDJSON_SSE2) || defined(RAPIDJSON_SSE42) \
|| defined(RAPIDJSON_NEON) || defined(RAPIDJSON_DOXYGEN_RUNNING)
#if defined(RAPIDJSON_SSE2) || defined(RAPIDJSON_SSE42) || defined(RAPIDJSON_NEON) || \
defined(RAPIDJSON_DOXYGEN_RUNNING)
#define RAPIDJSON_SIMD
#endif
@@ -442,9 +456,8 @@ RAPIDJSON_NAMESPACE_END
// Prefer C++11 static_assert, if available
#ifndef RAPIDJSON_STATIC_ASSERT
#if RAPIDJSON_CPLUSPLUS >= 201103L || ( defined(_MSC_VER) && _MSC_VER >= 1800 )
#define RAPIDJSON_STATIC_ASSERT(x) \
static_assert(x, RAPIDJSON_STRINGIFY(x))
#if RAPIDJSON_CPLUSPLUS >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1800)
#define RAPIDJSON_STATIC_ASSERT(x) static_assert(x, RAPIDJSON_STRINGIFY(x))
#endif // C++11
#endif // RAPIDJSON_STATIC_ASSERT
@@ -454,15 +467,26 @@ RAPIDJSON_NAMESPACE_END
//!@cond RAPIDJSON_HIDDEN_FROM_DOXYGEN
#endif
RAPIDJSON_NAMESPACE_BEGIN
template <bool x> struct STATIC_ASSERTION_FAILURE;
template <> struct STATIC_ASSERTION_FAILURE<true> { enum { value = 1 }; };
template <size_t x> struct StaticAssertTest {};
template <bool x>
struct STATIC_ASSERTION_FAILURE;
template <>
struct STATIC_ASSERTION_FAILURE<true>
{
enum
{
value = 1
};
};
template <size_t x>
struct StaticAssertTest
{
};
RAPIDJSON_NAMESPACE_END
#if defined(__GNUC__) || defined(__clang__)
#define RAPIDJSON_STATIC_ASSERT_UNUSED_ATTRIBUTE __attribute__((unused))
#else
#define RAPIDJSON_STATIC_ASSERT_UNUSED_ATTRIBUTE
#define RAPIDJSON_STATIC_ASSERT_UNUSED_ATTRIBUTE
#endif
#ifndef __clang__
//!@endcond
@@ -473,9 +497,9 @@ RAPIDJSON_NAMESPACE_END
\param x compile-time condition
\hideinitializer
*/
#define RAPIDJSON_STATIC_ASSERT(x) \
typedef ::RAPIDJSON_NAMESPACE::StaticAssertTest< \
sizeof(::RAPIDJSON_NAMESPACE::STATIC_ASSERTION_FAILURE<bool(x) >)> \
#define RAPIDJSON_STATIC_ASSERT(x) \
typedef ::RAPIDJSON_NAMESPACE::StaticAssertTest<sizeof( \
::RAPIDJSON_NAMESPACE::STATIC_ASSERTION_FAILURE<bool(x)>)> \
RAPIDJSON_JOIN(StaticAssertTypedef, __LINE__) RAPIDJSON_STATIC_ASSERT_UNUSED_ATTRIBUTE
#endif // RAPIDJSON_STATIC_ASSERT
@@ -513,13 +537,15 @@ RAPIDJSON_NAMESPACE_END
//!@cond RAPIDJSON_HIDDEN_FROM_DOXYGEN
#define RAPIDJSON_MULTILINEMACRO_BEGIN do {
#define RAPIDJSON_MULTILINEMACRO_BEGIN \
do \
{
#define RAPIDJSON_MULTILINEMACRO_END \
} while((void)0, 0)
} \
while((void)0, 0)
// adopted from Boost
#define RAPIDJSON_VERSION_CODE(x,y,z) \
(((x)*100000) + ((y)*100) + (z))
#define RAPIDJSON_VERSION_CODE(x, y, z) (((x) * 100000) + ((y) * 100) + (z))
#if defined(__has_builtin)
#define RAPIDJSON_HAS_BUILTIN(x) __has_builtin(x)
@@ -531,24 +557,25 @@ RAPIDJSON_NAMESPACE_END
// RAPIDJSON_DIAG_PUSH/POP, RAPIDJSON_DIAG_OFF
#if defined(__GNUC__)
#define RAPIDJSON_GNUC \
RAPIDJSON_VERSION_CODE(__GNUC__,__GNUC_MINOR__,__GNUC_PATCHLEVEL__)
#define RAPIDJSON_GNUC RAPIDJSON_VERSION_CODE(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__)
#endif
#if defined(__clang__) || (defined(RAPIDJSON_GNUC) && RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,2,0))
#if defined(__clang__) || \
(defined(RAPIDJSON_GNUC) && RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4, 2, 0))
#define RAPIDJSON_PRAGMA(x) _Pragma(RAPIDJSON_STRINGIFY(x))
#define RAPIDJSON_DIAG_PRAGMA(x) RAPIDJSON_PRAGMA(GCC diagnostic x)
#define RAPIDJSON_DIAG_OFF(x) \
RAPIDJSON_DIAG_PRAGMA(ignored RAPIDJSON_STRINGIFY(RAPIDJSON_JOIN(-W,x)))
RAPIDJSON_DIAG_PRAGMA(ignored RAPIDJSON_STRINGIFY(RAPIDJSON_JOIN(-W, x)))
// push/pop support in Clang and GCC>=4.6
#if defined(__clang__) || (defined(RAPIDJSON_GNUC) && RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,6,0))
#if defined(__clang__) || \
(defined(RAPIDJSON_GNUC) && RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4, 6, 0))
#define RAPIDJSON_DIAG_PUSH RAPIDJSON_DIAG_PRAGMA(push)
#define RAPIDJSON_DIAG_POP RAPIDJSON_DIAG_PRAGMA(pop)
#else // GCC >= 4.2, < 4.6
#define RAPIDJSON_DIAG_POP RAPIDJSON_DIAG_PRAGMA(pop)
#else // GCC >= 4.2, < 4.6
#define RAPIDJSON_DIAG_PUSH /* ignored */
#define RAPIDJSON_DIAG_POP /* ignored */
#define RAPIDJSON_DIAG_POP /* ignored */
#endif
#elif defined(_MSC_VER)
@@ -557,9 +584,9 @@ RAPIDJSON_NAMESPACE_END
#define RAPIDJSON_PRAGMA(x) __pragma(x)
#define RAPIDJSON_DIAG_PRAGMA(x) RAPIDJSON_PRAGMA(warning(x))
#define RAPIDJSON_DIAG_OFF(x) RAPIDJSON_DIAG_PRAGMA(disable: x)
#define RAPIDJSON_DIAG_OFF(x) RAPIDJSON_DIAG_PRAGMA(disable : x)
#define RAPIDJSON_DIAG_PUSH RAPIDJSON_DIAG_PRAGMA(push)
#define RAPIDJSON_DIAG_POP RAPIDJSON_DIAG_PRAGMA(pop)
#define RAPIDJSON_DIAG_POP RAPIDJSON_DIAG_PRAGMA(pop)
#else
@@ -580,15 +607,16 @@ RAPIDJSON_NAMESPACE_END
#if RAPIDJSON_HAS_CXX11
#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 1
#elif defined(__clang__)
#if __has_feature(cxx_rvalue_references) && \
(defined(_MSC_VER) || defined(_LIBCPP_VERSION) || defined(__GLIBCXX__) && __GLIBCXX__ >= 20080306)
#if __has_feature(cxx_rvalue_references) && (defined(_MSC_VER) || defined(_LIBCPP_VERSION) || \
defined(__GLIBCXX__) && __GLIBCXX__ >= 20080306)
#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 1
#else
#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 0
#endif
#elif (defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,3,0)) && defined(__GXX_EXPERIMENTAL_CXX0X__)) || \
(defined(_MSC_VER) && _MSC_VER >= 1600) || \
(defined(__SUNPRO_CC) && __SUNPRO_CC >= 0x5140 && defined(__GXX_EXPERIMENTAL_CXX0X__))
#elif(defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4, 3, 0)) && \
defined(__GXX_EXPERIMENTAL_CXX0X__)) || \
(defined(_MSC_VER) && _MSC_VER >= 1600) || \
(defined(__SUNPRO_CC) && __SUNPRO_CC >= 0x5140 && defined(__GXX_EXPERIMENTAL_CXX0X__))
#define RAPIDJSON_HAS_CXX11_RVALUE_REFS 1
#else
@@ -605,8 +633,9 @@ RAPIDJSON_NAMESPACE_END
#define RAPIDJSON_HAS_CXX11_NOEXCEPT 1
#elif defined(__clang__)
#define RAPIDJSON_HAS_CXX11_NOEXCEPT __has_feature(cxx_noexcept)
#elif (defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,6,0)) && defined(__GXX_EXPERIMENTAL_CXX0X__)) || \
(defined(_MSC_VER) && _MSC_VER >= 1900) || \
#elif(defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4, 6, 0)) && \
defined(__GXX_EXPERIMENTAL_CXX0X__)) || \
(defined(_MSC_VER) && _MSC_VER >= 1900) || \
(defined(__SUNPRO_CC) && __SUNPRO_CC >= 0x5140 && defined(__GXX_EXPERIMENTAL_CXX0X__))
#define RAPIDJSON_HAS_CXX11_NOEXCEPT 1
#else
@@ -623,7 +652,7 @@ RAPIDJSON_NAMESPACE_END
// no automatic detection, yet
#ifndef RAPIDJSON_HAS_CXX11_TYPETRAITS
#if (defined(_MSC_VER) && _MSC_VER >= 1700)
#if(defined(_MSC_VER) && _MSC_VER >= 1700)
#define RAPIDJSON_HAS_CXX11_TYPETRAITS 1
#else
#define RAPIDJSON_HAS_CXX11_TYPETRAITS 0
@@ -633,9 +662,10 @@ RAPIDJSON_NAMESPACE_END
#ifndef RAPIDJSON_HAS_CXX11_RANGE_FOR
#if defined(__clang__)
#define RAPIDJSON_HAS_CXX11_RANGE_FOR __has_feature(cxx_range_for)
#elif (defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4,6,0)) && defined(__GXX_EXPERIMENTAL_CXX0X__)) || \
(defined(_MSC_VER) && _MSC_VER >= 1700) || \
(defined(__SUNPRO_CC) && __SUNPRO_CC >= 0x5140 && defined(__GXX_EXPERIMENTAL_CXX0X__))
#elif(defined(RAPIDJSON_GNUC) && (RAPIDJSON_GNUC >= RAPIDJSON_VERSION_CODE(4, 6, 0)) && \
defined(__GXX_EXPERIMENTAL_CXX0X__)) || \
(defined(_MSC_VER) && _MSC_VER >= 1700) || \
(defined(__SUNPRO_CC) && __SUNPRO_CC >= 0x5140 && defined(__GXX_EXPERIMENTAL_CXX0X__))
#define RAPIDJSON_HAS_CXX11_RANGE_FOR 1
#else
#define RAPIDJSON_HAS_CXX11_RANGE_FOR 0
@@ -650,31 +680,31 @@ RAPIDJSON_NAMESPACE_END
#endif
#if RAPIDJSON_HAS_CXX17
# define RAPIDJSON_DELIBERATE_FALLTHROUGH [[fallthrough]]
#define RAPIDJSON_DELIBERATE_FALLTHROUGH [[fallthrough]]
#elif defined(__has_cpp_attribute)
# if __has_cpp_attribute(clang::fallthrough)
# define RAPIDJSON_DELIBERATE_FALLTHROUGH [[clang::fallthrough]]
# elif __has_cpp_attribute(fallthrough)
# define RAPIDJSON_DELIBERATE_FALLTHROUGH __attribute__((fallthrough))
# else
# define RAPIDJSON_DELIBERATE_FALLTHROUGH
# endif
#if __has_cpp_attribute(clang::fallthrough)
#define RAPIDJSON_DELIBERATE_FALLTHROUGH [[clang::fallthrough]]
#elif __has_cpp_attribute(fallthrough)
#define RAPIDJSON_DELIBERATE_FALLTHROUGH __attribute__((fallthrough))
#else
# define RAPIDJSON_DELIBERATE_FALLTHROUGH
#define RAPIDJSON_DELIBERATE_FALLTHROUGH
#endif
#else
#define RAPIDJSON_DELIBERATE_FALLTHROUGH
#endif
//!@endcond
//! Assertion (in non-throwing contexts).
/*! \ingroup RAPIDJSON_CONFIG
Some functions provide a \c noexcept guarantee, if the compiler supports it.
In these cases, the \ref RAPIDJSON_ASSERT macro cannot be overridden to
throw an exception. This macro adds a separate customization point for
such cases.
/*! \ingroup RAPIDJSON_CONFIG
Some functions provide a \c noexcept guarantee, if the compiler supports it.
In these cases, the \ref RAPIDJSON_ASSERT macro cannot be overridden to
throw an exception. This macro adds a separate customization point for
such cases.
Defaults to C \c assert() (as \ref RAPIDJSON_ASSERT), if \c noexcept is
supported, and to \ref RAPIDJSON_ASSERT otherwise.
*/
Defaults to C \c assert() (as \ref RAPIDJSON_ASSERT), if \c noexcept is
supported, and to \ref RAPIDJSON_ASSERT otherwise.
*/
///////////////////////////////////////////////////////////////////////////////
// RAPIDJSON_NOEXCEPT_ASSERT
@@ -726,14 +756,15 @@ RAPIDJSON_NAMESPACE_END
RAPIDJSON_NAMESPACE_BEGIN
//! Type of JSON value
enum Type {
kNullType = 0, //!< null
kFalseType = 1, //!< false
kTrueType = 2, //!< true
kObjectType = 3, //!< object
kArrayType = 4, //!< array
kStringType = 5, //!< string
kNumberType = 6 //!< number
enum Type
{
kNullType = 0, //!< null
kFalseType = 1, //!< false
kTrueType = 2, //!< true
kObjectType = 3, //!< object
kArrayType = 4, //!< array
kStringType = 5, //!< string
kNumberType = 6 //!< number
};
RAPIDJSON_NAMESPACE_END

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -69,34 +69,41 @@ concept Stream {
For custom stream, this type can be specialized for other configuration.
See TEST(Reader, CustomStringStream) in readertest.cpp for example.
*/
template<typename Stream>
struct StreamTraits {
template <typename Stream>
struct StreamTraits
{
//! Whether to make local copy of stream for optimization during parsing.
/*!
By default, for safety, streams do not use local copy optimization.
Stream that can be copied fast should specialize this, like StreamTraits<StringStream>.
*/
enum { copyOptimization = 0 };
enum
{
copyOptimization = 0
};
};
//! Reserve n characters for writing to a stream.
template<typename Stream>
inline void PutReserve(Stream& stream, size_t count) {
template <typename Stream>
inline void PutReserve(Stream& stream, size_t count)
{
(void)stream;
(void)count;
}
//! Write character to a stream, presuming buffer is reserved.
template<typename Stream>
inline void PutUnsafe(Stream& stream, typename Stream::Ch c) {
template <typename Stream>
inline void PutUnsafe(Stream& stream, typename Stream::Ch c)
{
stream.Put(c);
}
//! Put N copies of a character to a stream.
template<typename Stream, typename Ch>
inline void PutN(Stream& stream, Ch c, size_t n) {
template <typename Stream, typename Ch>
inline void PutN(Stream& stream, Ch c, size_t n)
{
PutReserve(stream, n);
for (size_t i = 0; i < n; i++)
for(size_t i = 0; i < n; i++)
PutUnsafe(stream, c);
}
@@ -111,15 +118,16 @@ inline void PutN(Stream& stream, Ch c, size_t n) {
#if defined(_MSC_VER) && _MSC_VER <= 1800
RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(4702) // unreachable code
RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated
RAPIDJSON_DIAG_OFF(4702) // unreachable code
RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated
#endif
template <typename InputStream, typename Encoding = UTF8<> >
class GenericStreamWrapper {
public:
template <typename InputStream, typename Encoding = UTF8<>>
class GenericStreamWrapper
{
public:
typedef typename Encoding::Ch Ch;
GenericStreamWrapper(InputStream& is): is_(is) {}
GenericStreamWrapper(InputStream& is) : is_(is) {}
Ch Peek() const { return is_.Peek(); }
Ch Take() { return is_.Take(); }
@@ -136,7 +144,7 @@ public:
UTFType GetType() const { return is_.GetType(); }
bool HasBOM() const { return is_.HasBOM(); }
protected:
protected:
InputStream& is_;
};
@@ -149,33 +157,46 @@ RAPIDJSON_DIAG_POP
//! Read-only string stream.
/*! \note implements Stream concept
*/
*/
template <typename Encoding>
struct GenericStringStream {
struct GenericStringStream
{
typedef typename Encoding::Ch Ch;
GenericStringStream(const Ch *src) : src_(src), head_(src) {}
GenericStringStream(const Ch* src) : src_(src), head_(src) {}
Ch Peek() const { return *src_; }
Ch Take() { return *src_++; }
size_t Tell() const { return static_cast<size_t>(src_ - head_); }
Ch* PutBegin() { RAPIDJSON_ASSERT(false); return 0; }
Ch* PutBegin()
{
RAPIDJSON_ASSERT(false);
return 0;
}
void Put(Ch) { RAPIDJSON_ASSERT(false); }
void Flush() { RAPIDJSON_ASSERT(false); }
size_t PutEnd(Ch*) { RAPIDJSON_ASSERT(false); return 0; }
size_t PutEnd(Ch*)
{
RAPIDJSON_ASSERT(false);
return 0;
}
const Ch* src_; //!< Current read position.
const Ch* head_; //!< Original head of the string.
const Ch* src_; //!< Current read position.
const Ch* head_; //!< Original head of the string.
};
template <typename Encoding>
struct StreamTraits<GenericStringStream<Encoding> > {
enum { copyOptimization = 1 };
struct StreamTraits<GenericStringStream<Encoding>>
{
enum
{
copyOptimization = 1
};
};
//! String stream with UTF8 encoding.
typedef GenericStringStream<UTF8<> > StringStream;
typedef GenericStringStream<UTF8<>> StringStream;
///////////////////////////////////////////////////////////////////////////////
// InsituStringStream
@@ -185,10 +206,11 @@ typedef GenericStringStream<UTF8<> > StringStream;
\note implements Stream concept
*/
template <typename Encoding>
struct GenericInsituStringStream {
struct GenericInsituStringStream
{
typedef typename Encoding::Ch Ch;
GenericInsituStringStream(Ch *src) : src_(src), dst_(0), head_(src) {}
GenericInsituStringStream(Ch* src) : src_(src), dst_(0), head_(src) {}
// Read
Ch Peek() { return *src_; }
@@ -196,13 +218,22 @@ struct GenericInsituStringStream {
size_t Tell() { return static_cast<size_t>(src_ - head_); }
// Write
void Put(Ch c) { RAPIDJSON_ASSERT(dst_ != 0); *dst_++ = c; }
void Put(Ch c)
{
RAPIDJSON_ASSERT(dst_ != 0);
*dst_++ = c;
}
Ch* PutBegin() { return dst_ = src_; }
size_t PutEnd(Ch* begin) { return static_cast<size_t>(dst_ - begin); }
void Flush() {}
Ch* Push(size_t count) { Ch* begin = dst_; dst_ += count; return begin; }
Ch* Push(size_t count)
{
Ch* begin = dst_;
dst_ += count;
return begin;
}
void Pop(size_t count) { dst_ -= count; }
Ch* src_;
@@ -211,12 +242,16 @@ struct GenericInsituStringStream {
};
template <typename Encoding>
struct StreamTraits<GenericInsituStringStream<Encoding> > {
enum { copyOptimization = 1 };
struct StreamTraits<GenericInsituStringStream<Encoding>>
{
enum
{
copyOptimization = 1
};
};
//! Insitu string stream with UTF8 encoding.
typedef GenericInsituStringStream<UTF8<> > InsituStringStream;
typedef GenericInsituStringStream<UTF8<>> InsituStringStream;
RAPIDJSON_NAMESPACE_END

View File

@@ -1,5 +1,5 @@
// Tencent is pleased to support the open source community by making RapidJSON available.
//
//
// Copyright (C) 2015 THL A29 Limited, a Tencent company, and Milo Yip.
//
// Licensed under the MIT License (the "License"); you may not use this file except
@@ -7,9 +7,9 @@
//
// http://opensource.org/licenses/MIT
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#ifndef RAPIDJSON_STRINGBUFFER_H_
@@ -26,7 +26,7 @@
#if defined(__clang__)
RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(c++98-compat)
RAPIDJSON_DIAG_OFF(c++ 98 - compat)
#endif
RAPIDJSON_NAMESPACE_BEGIN
@@ -38,16 +38,21 @@ RAPIDJSON_NAMESPACE_BEGIN
\note implements Stream concept
*/
template <typename Encoding, typename Allocator = CrtAllocator>
class GenericStringBuffer {
public:
class GenericStringBuffer
{
public:
typedef typename Encoding::Ch Ch;
GenericStringBuffer(Allocator* allocator = 0, size_t capacity = kDefaultCapacity) : stack_(allocator, capacity) {}
GenericStringBuffer(Allocator* allocator = 0, size_t capacity = kDefaultCapacity)
: stack_(allocator, capacity)
{
}
#if RAPIDJSON_HAS_CXX11_RVALUE_REFS
GenericStringBuffer(GenericStringBuffer&& rhs) : stack_(std::move(rhs.stack_)) {}
GenericStringBuffer& operator=(GenericStringBuffer&& rhs) {
if (&rhs != this)
GenericStringBuffer& operator=(GenericStringBuffer&& rhs)
{
if(&rhs != this)
stack_ = std::move(rhs.stack_);
return *this;
}
@@ -58,7 +63,8 @@ public:
void Flush() {}
void Clear() { stack_.Clear(); }
void ShrinkToFit() {
void ShrinkToFit()
{
// Push and pop a null terminator. This is safe.
*stack_.template Push<Ch>() = '\0';
stack_.ShrinkToFit();
@@ -70,7 +76,8 @@ public:
Ch* PushUnsafe(size_t count) { return stack_.template PushUnsafe<Ch>(count); }
void Pop(size_t count) { stack_.template Pop<Ch>(count); }
const Ch* GetString() const {
const Ch* GetString() const
{
// Push and pop a null terminator. This is safe.
*stack_.template Push<Ch>() = '\0';
stack_.template Pop<Ch>(1);
@@ -87,28 +94,31 @@ public:
static const size_t kDefaultCapacity = 256;
mutable internal::Stack<Allocator> stack_;
private:
private:
// Prohibit copy constructor & assignment operator.
GenericStringBuffer(const GenericStringBuffer&);
GenericStringBuffer& operator=(const GenericStringBuffer&);
};
//! String buffer with UTF8 encoding
typedef GenericStringBuffer<UTF8<> > StringBuffer;
typedef GenericStringBuffer<UTF8<>> StringBuffer;
template<typename Encoding, typename Allocator>
inline void PutReserve(GenericStringBuffer<Encoding, Allocator>& stream, size_t count) {
template <typename Encoding, typename Allocator>
inline void PutReserve(GenericStringBuffer<Encoding, Allocator>& stream, size_t count)
{
stream.Reserve(count);
}
template<typename Encoding, typename Allocator>
inline void PutUnsafe(GenericStringBuffer<Encoding, Allocator>& stream, typename Encoding::Ch c) {
template <typename Encoding, typename Allocator>
inline void PutUnsafe(GenericStringBuffer<Encoding, Allocator>& stream, typename Encoding::Ch c)
{
stream.PutUnsafe(c);
}
//! Implement specialized version of PutN() with memset() for better performance.
template<>
inline void PutN(GenericStringBuffer<UTF8<> >& stream, char c, size_t n) {
template <>
inline void PutN(GenericStringBuffer<UTF8<>>& stream, char c, size_t n)
{
std::memset(stream.stack_.Push<char>(n), c, n * sizeof(c));
}

View File

@@ -19,7 +19,7 @@
#if defined(__clang__)
RAPIDJSON_DIAG_PUSH
RAPIDJSON_DIAG_OFF(c++98-compat)
RAPIDJSON_DIAG_OFF(c++ 98 - compat)
#elif defined(_MSC_VER)
RAPIDJSON_DIAG_OFF(4512) // assignment operator could not be generated
#endif
@@ -29,66 +29,141 @@ RAPIDJSON_NAMESPACE_BEGIN
///////////////////////////////////////////////////////////////////////////////
// GenericUri
template <typename ValueType, typename Allocator=CrtAllocator>
class GenericUri {
public:
template <typename ValueType, typename Allocator = CrtAllocator>
class GenericUri
{
public:
typedef typename ValueType::Ch Ch;
#if RAPIDJSON_HAS_STDSTRING
typedef std::basic_string<Ch> String;
#endif
//! Constructors
GenericUri(Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() {
GenericUri(Allocator* allocator = 0)
: uri_(),
base_(),
scheme_(),
auth_(),
path_(),
query_(),
frag_(),
allocator_(allocator),
ownAllocator_()
{
}
GenericUri(const Ch* uri, SizeType len, Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() {
GenericUri(const Ch* uri, SizeType len, Allocator* allocator = 0)
: uri_(),
base_(),
scheme_(),
auth_(),
path_(),
query_(),
frag_(),
allocator_(allocator),
ownAllocator_()
{
Parse(uri, len);
}
GenericUri(const Ch* uri, Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() {
GenericUri(const Ch* uri, Allocator* allocator = 0)
: uri_(),
base_(),
scheme_(),
auth_(),
path_(),
query_(),
frag_(),
allocator_(allocator),
ownAllocator_()
{
Parse(uri, internal::StrLen<Ch>(uri));
}
// Use with specializations of GenericValue
template<typename T> GenericUri(const T& uri, Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() {
template <typename T>
GenericUri(const T& uri, Allocator* allocator = 0)
: uri_(),
base_(),
scheme_(),
auth_(),
path_(),
query_(),
frag_(),
allocator_(allocator),
ownAllocator_()
{
const Ch* u = uri.template Get<const Ch*>(); // TypeHelper from document.h
Parse(u, internal::StrLen<Ch>(u));
}
#if RAPIDJSON_HAS_STDSTRING
GenericUri(const String& uri, Allocator* allocator = 0) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() {
GenericUri(const String& uri, Allocator* allocator = 0)
: uri_(),
base_(),
scheme_(),
auth_(),
path_(),
query_(),
frag_(),
allocator_(allocator),
ownAllocator_()
{
Parse(uri.c_str(), internal::StrLen<Ch>(uri.c_str()));
}
#endif
//! Copy constructor
GenericUri(const GenericUri& rhs) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(), ownAllocator_() {
GenericUri(const GenericUri& rhs)
: uri_(),
base_(),
scheme_(),
auth_(),
path_(),
query_(),
frag_(),
allocator_(),
ownAllocator_()
{
*this = rhs;
}
//! Copy constructor
GenericUri(const GenericUri& rhs, Allocator* allocator) : uri_(), base_(), scheme_(), auth_(), path_(), query_(), frag_(), allocator_(allocator), ownAllocator_() {
GenericUri(const GenericUri& rhs, Allocator* allocator)
: uri_(),
base_(),
scheme_(),
auth_(),
path_(),
query_(),
frag_(),
allocator_(allocator),
ownAllocator_()
{
*this = rhs;
}
//! Destructor.
~GenericUri() {
~GenericUri()
{
Free();
RAPIDJSON_DELETE(ownAllocator_);
}
//! Assignment operator
GenericUri& operator=(const GenericUri& rhs) {
if (this != &rhs) {
GenericUri& operator=(const GenericUri& rhs)
{
if(this != &rhs)
{
// Do not delete ownAllocator
Free();
Allocate(rhs.GetStringLength());
auth_ = CopyPart(scheme_, rhs.scheme_, rhs.GetSchemeStringLength());
path_ = CopyPart(auth_, rhs.auth_, rhs.GetAuthStringLength());
auth_ = CopyPart(scheme_, rhs.scheme_, rhs.GetSchemeStringLength());
path_ = CopyPart(auth_, rhs.auth_, rhs.GetAuthStringLength());
query_ = CopyPart(path_, rhs.path_, rhs.GetPathStringLength());
frag_ = CopyPart(query_, rhs.query_, rhs.GetQueryStringLength());
base_ = CopyPart(frag_, rhs.frag_, rhs.GetFragStringLength());
uri_ = CopyPart(base_, rhs.base_, rhs.GetBaseStringLength());
frag_ = CopyPart(query_, rhs.query_, rhs.GetQueryStringLength());
base_ = CopyPart(frag_, rhs.frag_, rhs.GetFragStringLength());
uri_ = CopyPart(base_, rhs.base_, rhs.GetBaseStringLength());
CopyPart(uri_, rhs.uri_, rhs.GetStringLength());
}
return *this;
@@ -96,7 +171,9 @@ public:
//! Getters
// Use with specializations of GenericValue
template<typename T> void Get(T& uri, Allocator& allocator) {
template <typename T>
void Get(T& uri, Allocator& allocator)
{
uri.template Set<const Ch*>(this->GetString(), allocator); // TypeHelper from document.h
}
@@ -105,7 +182,10 @@ public:
const Ch* GetBaseString() const { return base_; }
SizeType GetBaseStringLength() const { return base_ == 0 ? 0 : internal::StrLen<Ch>(base_); }
const Ch* GetSchemeString() const { return scheme_; }
SizeType GetSchemeStringLength() const { return scheme_ == 0 ? 0 : internal::StrLen<Ch>(scheme_); }
SizeType GetSchemeStringLength() const
{
return scheme_ == 0 ? 0 : internal::StrLen<Ch>(scheme_);
}
const Ch* GetAuthString() const { return auth_; }
SizeType GetAuthStringLength() const { return auth_ == 0 ? 0 : internal::StrLen<Ch>(auth_); }
const Ch* GetPathString() const { return path_; }
@@ -116,36 +196,59 @@ public:
SizeType GetFragStringLength() const { return frag_ == 0 ? 0 : internal::StrLen<Ch>(frag_); }
#if RAPIDJSON_HAS_STDSTRING
static String Get(const GenericUri& uri) { return String(uri.GetString(), uri.GetStringLength()); }
static String GetBase(const GenericUri& uri) { return String(uri.GetBaseString(), uri.GetBaseStringLength()); }
static String GetScheme(const GenericUri& uri) { return String(uri.GetSchemeString(), uri.GetSchemeStringLength()); }
static String GetAuth(const GenericUri& uri) { return String(uri.GetAuthString(), uri.GetAuthStringLength()); }
static String GetPath(const GenericUri& uri) { return String(uri.GetPathString(), uri.GetPathStringLength()); }
static String GetQuery(const GenericUri& uri) { return String(uri.GetQueryString(), uri.GetQueryStringLength()); }
static String GetFrag(const GenericUri& uri) { return String(uri.GetFragString(), uri.GetFragStringLength()); }
static String Get(const GenericUri& uri)
{
return String(uri.GetString(), uri.GetStringLength());
}
static String GetBase(const GenericUri& uri)
{
return String(uri.GetBaseString(), uri.GetBaseStringLength());
}
static String GetScheme(const GenericUri& uri)
{
return String(uri.GetSchemeString(), uri.GetSchemeStringLength());
}
static String GetAuth(const GenericUri& uri)
{
return String(uri.GetAuthString(), uri.GetAuthStringLength());
}
static String GetPath(const GenericUri& uri)
{
return String(uri.GetPathString(), uri.GetPathStringLength());
}
static String GetQuery(const GenericUri& uri)
{
return String(uri.GetQueryString(), uri.GetQueryStringLength());
}
static String GetFrag(const GenericUri& uri)
{
return String(uri.GetFragString(), uri.GetFragStringLength());
}
#endif
//! Equality operators
bool operator==(const GenericUri& rhs) const {
return Match(rhs, true);
}
bool operator==(const GenericUri& rhs) const { return Match(rhs, true); }
bool operator!=(const GenericUri& rhs) const {
return !Match(rhs, true);
}
bool operator!=(const GenericUri& rhs) const { return !Match(rhs, true); }
bool Match(const GenericUri& uri, bool full = true) const {
bool Match(const GenericUri& uri, bool full = true) const
{
Ch* s1;
Ch* s2;
if (full) {
if(full)
{
s1 = uri_;
s2 = uri.uri_;
} else {
}
else
{
s1 = base_;
s2 = uri.base_;
}
if (s1 == s2) return true;
if (s1 == 0 || s2 == 0) return false;
if(s1 == s2)
return true;
if(s1 == 0 || s2 == 0)
return false;
return internal::StrCmp<Ch>(s1, s2) == 0;
}
@@ -153,56 +256,80 @@ public:
// See https://tools.ietf.org/html/rfc3986
// Use for resolving an id or $ref with an in-scope id.
// Returns a new GenericUri for the resolved URI.
GenericUri Resolve(const GenericUri& baseuri, Allocator* allocator = 0) {
GenericUri Resolve(const GenericUri& baseuri, Allocator* allocator = 0)
{
GenericUri resuri;
resuri.allocator_ = allocator;
// Ensure enough space for combining paths
resuri.Allocate(GetStringLength() + baseuri.GetStringLength() + 1); // + 1 for joining slash
if (!(GetSchemeStringLength() == 0)) {
if(!(GetSchemeStringLength() == 0))
{
// Use all of this URI
resuri.auth_ = CopyPart(resuri.scheme_, scheme_, GetSchemeStringLength());
resuri.path_ = CopyPart(resuri.auth_, auth_, GetAuthStringLength());
resuri.auth_ = CopyPart(resuri.scheme_, scheme_, GetSchemeStringLength());
resuri.path_ = CopyPart(resuri.auth_, auth_, GetAuthStringLength());
resuri.query_ = CopyPart(resuri.path_, path_, GetPathStringLength());
resuri.frag_ = CopyPart(resuri.query_, query_, GetQueryStringLength());
resuri.frag_ = CopyPart(resuri.query_, query_, GetQueryStringLength());
resuri.RemoveDotSegments();
} else {
}
else
{
// Use the base scheme
resuri.auth_ = CopyPart(resuri.scheme_, baseuri.scheme_, baseuri.GetSchemeStringLength());
if (!(GetAuthStringLength() == 0)) {
resuri.auth_ =
CopyPart(resuri.scheme_, baseuri.scheme_, baseuri.GetSchemeStringLength());
if(!(GetAuthStringLength() == 0))
{
// Use this auth, path, query
resuri.path_ = CopyPart(resuri.auth_, auth_, GetAuthStringLength());
resuri.path_ = CopyPart(resuri.auth_, auth_, GetAuthStringLength());
resuri.query_ = CopyPart(resuri.path_, path_, GetPathStringLength());
resuri.frag_ = CopyPart(resuri.query_, query_, GetQueryStringLength());
resuri.frag_ = CopyPart(resuri.query_, query_, GetQueryStringLength());
resuri.RemoveDotSegments();
} else {
}
else
{
// Use the base auth
resuri.path_ = CopyPart(resuri.auth_, baseuri.auth_, baseuri.GetAuthStringLength());
if (GetPathStringLength() == 0) {
if(GetPathStringLength() == 0)
{
// Use the base path
resuri.query_ = CopyPart(resuri.path_, baseuri.path_, baseuri.GetPathStringLength());
if (GetQueryStringLength() == 0) {
resuri.query_ =
CopyPart(resuri.path_, baseuri.path_, baseuri.GetPathStringLength());
if(GetQueryStringLength() == 0)
{
// Use the base query
resuri.frag_ = CopyPart(resuri.query_, baseuri.query_, baseuri.GetQueryStringLength());
} else {
resuri.frag_ =
CopyPart(resuri.query_, baseuri.query_, baseuri.GetQueryStringLength());
}
else
{
// Use this query
resuri.frag_ = CopyPart(resuri.query_, query_, GetQueryStringLength());
}
} else {
if (path_[0] == '/') {
}
else
{
if(path_[0] == '/')
{
// Absolute path - use all of this path
resuri.query_ = CopyPart(resuri.path_, path_, GetPathStringLength());
resuri.RemoveDotSegments();
} else {
// Relative path - append this path to base path after base path's last slash
}
else
{
// Relative path - append this path to base path after base path's last
// slash
size_t pos = 0;
if (!(baseuri.GetAuthStringLength() == 0) && baseuri.GetPathStringLength() == 0) {
if(!(baseuri.GetAuthStringLength() == 0) &&
baseuri.GetPathStringLength() == 0)
{
resuri.path_[pos] = '/';
pos++;
}
size_t lastslashpos = baseuri.GetPathStringLength();
while (lastslashpos > 0) {
if (baseuri.path_[lastslashpos - 1] == '/') break;
while(lastslashpos > 0)
{
if(baseuri.path_[lastslashpos - 1] == '/')
break;
lastslashpos--;
}
std::memcpy(&resuri.path_[pos], baseuri.path_, lastslashpos * sizeof(Ch));
@@ -228,74 +355,87 @@ public:
//! Get the allocator of this GenericUri.
Allocator& GetAllocator() { return *allocator_; }
private:
private:
// Allocate memory for a URI
// Returns total amount allocated
std::size_t Allocate(std::size_t len) {
std::size_t Allocate(std::size_t len)
{
// Create own allocator if user did not supply.
if (!allocator_)
ownAllocator_ = allocator_ = RAPIDJSON_NEW(Allocator)();
if(!allocator_)
ownAllocator_ = allocator_ = RAPIDJSON_NEW(Allocator)();
// Allocate one block containing each part of the URI (5) plus base plus full URI, all null terminated.
// Order: scheme, auth, path, query, frag, base, uri
// Note need to set, increment, assign in 3 stages to avoid compiler warning bug.
// Allocate one block containing each part of the URI (5) plus base plus full URI, all null
// terminated. Order: scheme, auth, path, query, frag, base, uri Note need to set,
// increment, assign in 3 stages to avoid compiler warning bug.
size_t total = (3 * len + 7) * sizeof(Ch);
scheme_ = static_cast<Ch*>(allocator_->Malloc(total));
*scheme_ = '\0';
auth_ = scheme_;
scheme_ = static_cast<Ch*>(allocator_->Malloc(total));
*scheme_ = '\0';
auth_ = scheme_;
auth_++;
*auth_ = '\0';
path_ = auth_;
path_ = auth_;
path_++;
*path_ = '\0';
query_ = path_;
query_++;
*query_ = '\0';
frag_ = query_;
frag_ = query_;
frag_++;
*frag_ = '\0';
base_ = frag_;
base_ = frag_;
base_++;
*base_ = '\0';
uri_ = base_;
uri_ = base_;
uri_++;
*uri_ = '\0';
return total;
}
// Free memory for a URI
void Free() {
if (scheme_) {
void Free()
{
if(scheme_)
{
Allocator::Free(scheme_);
scheme_ = 0;
}
}
// Parse a URI into constituent scheme, authority, path, query, & fragment parts
// Supports URIs that match regex ^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))? as per
// https://tools.ietf.org/html/rfc3986
void Parse(const Ch* uri, std::size_t len) {
// Supports URIs that match regex ^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))? as
// per https://tools.ietf.org/html/rfc3986
void Parse(const Ch* uri, std::size_t len)
{
std::size_t start = 0, pos1 = 0, pos2 = 0;
Allocate(len);
// Look for scheme ([^:/?#]+):)?
if (start < len) {
while (pos1 < len) {
if (uri[pos1] == ':') break;
if(start < len)
{
while(pos1 < len)
{
if(uri[pos1] == ':')
break;
pos1++;
}
if (pos1 != len) {
while (pos2 < len) {
if (uri[pos2] == '/') break;
if (uri[pos2] == '?') break;
if (uri[pos2] == '#') break;
if(pos1 != len)
{
while(pos2 < len)
{
if(uri[pos2] == '/')
break;
if(uri[pos2] == '?')
break;
if(uri[pos2] == '#')
break;
pos2++;
}
if (pos1 < pos2) {
if(pos1 < pos2)
{
pos1++;
std::memcpy(scheme_, &uri[start], pos1 * sizeof(Ch));
scheme_[pos1] = '\0';
start = pos1;
start = pos1;
}
}
}
@@ -304,35 +444,45 @@ private:
auth_ = scheme_ + GetSchemeStringLength();
auth_++;
*auth_ = '\0';
if (start < len - 1 && uri[start] == '/' && uri[start + 1] == '/') {
if(start < len - 1 && uri[start] == '/' && uri[start + 1] == '/')
{
pos2 = start + 2;
while (pos2 < len) {
if (uri[pos2] == '/') break;
if (uri[pos2] == '?') break;
if (uri[pos2] == '#') break;
while(pos2 < len)
{
if(uri[pos2] == '/')
break;
if(uri[pos2] == '?')
break;
if(uri[pos2] == '#')
break;
pos2++;
}
std::memcpy(auth_, &uri[start], (pos2 - start) * sizeof(Ch));
auth_[pos2 - start] = '\0';
start = pos2;
start = pos2;
}
// Look for path ([^?#]*)
// Note need to set, increment, assign in 3 stages to avoid compiler warning bug.
path_ = auth_ + GetAuthStringLength();
path_++;
*path_ = '\0';
if (start < len) {
if(start < len)
{
pos2 = start;
while (pos2 < len) {
if (uri[pos2] == '?') break;
if (uri[pos2] == '#') break;
while(pos2 < len)
{
if(uri[pos2] == '?')
break;
if(uri[pos2] == '#')
break;
pos2++;
}
if (start != pos2) {
if(start != pos2)
{
std::memcpy(path_, &uri[start], (pos2 - start) * sizeof(Ch));
path_[pos2 - start] = '\0';
if (path_[0] == '/')
RemoveDotSegments(); // absolute path - normalize
if(path_[0] == '/')
RemoveDotSegments(); // absolute path - normalize
start = pos2;
}
}
@@ -341,16 +491,20 @@ private:
query_ = path_ + GetPathStringLength();
query_++;
*query_ = '\0';
if (start < len && uri[start] == '?') {
if(start < len && uri[start] == '?')
{
pos2 = start + 1;
while (pos2 < len) {
if (uri[pos2] == '#') break;
while(pos2 < len)
{
if(uri[pos2] == '#')
break;
pos2++;
}
if (start != pos2) {
if(start != pos2)
{
std::memcpy(query_, &uri[start], (pos2 - start) * sizeof(Ch));
query_[pos2 - start] = '\0';
start = pos2;
start = pos2;
}
}
// Look for fragment (#(.*))?
@@ -358,7 +512,8 @@ private:
frag_ = query_ + GetQueryStringLength();
frag_++;
*frag_ = '\0';
if (start < len && uri[start] == '#') {
if(start < len && uri[start] == '#')
{
std::memcpy(frag_, &uri[start], (len - start) * sizeof(Ch));
frag_[len - start] = '\0';
}
@@ -371,36 +526,39 @@ private:
}
// Reconstitute base
void SetBase() {
void SetBase()
{
Ch* next = base_;
std::memcpy(next, scheme_, GetSchemeStringLength() * sizeof(Ch));
next+= GetSchemeStringLength();
next += GetSchemeStringLength();
std::memcpy(next, auth_, GetAuthStringLength() * sizeof(Ch));
next+= GetAuthStringLength();
next += GetAuthStringLength();
std::memcpy(next, path_, GetPathStringLength() * sizeof(Ch));
next+= GetPathStringLength();
next += GetPathStringLength();
std::memcpy(next, query_, GetQueryStringLength() * sizeof(Ch));
next+= GetQueryStringLength();
next += GetQueryStringLength();
*next = '\0';
}
// Reconstitute uri
void SetUri() {
void SetUri()
{
Ch* next = uri_;
std::memcpy(next, base_, GetBaseStringLength() * sizeof(Ch));
next+= GetBaseStringLength();
next += GetBaseStringLength();
std::memcpy(next, frag_, GetFragStringLength() * sizeof(Ch));
next+= GetFragStringLength();
next += GetFragStringLength();
*next = '\0';
}
// Copy a part from one GenericUri to another
// Return the pointer to the next part to be copied to
Ch* CopyPart(Ch* to, Ch* from, std::size_t len) {
Ch* CopyPart(Ch* to, Ch* from, std::size_t len)
{
RAPIDJSON_ASSERT(to != 0);
RAPIDJSON_ASSERT(from != 0);
std::memcpy(to, from, len * sizeof(Ch));
to[len] = '\0';
to[len] = '\0';
Ch* next = to + len + 1;
return next;
}
@@ -408,45 +566,58 @@ private:
// Remove . and .. segments from the path_ member.
// https://tools.ietf.org/html/rfc3986
// This is done in place as we are only removing segments.
void RemoveDotSegments() {
void RemoveDotSegments()
{
std::size_t pathlen = GetPathStringLength();
std::size_t pathpos = 0; // Position in path_
std::size_t newpos = 0; // Position in new path_
std::size_t pathpos = 0; // Position in path_
std::size_t newpos = 0; // Position in new path_
// Loop through each segment in original path_
while (pathpos < pathlen) {
while(pathpos < pathlen)
{
// Get next segment, bounded by '/' or end
size_t slashpos = 0;
while ((pathpos + slashpos) < pathlen) {
if (path_[pathpos + slashpos] == '/') break;
while((pathpos + slashpos) < pathlen)
{
if(path_[pathpos + slashpos] == '/')
break;
slashpos++;
}
// Check for .. and . segments
if (slashpos == 2 && path_[pathpos] == '.' && path_[pathpos + 1] == '.') {
if(slashpos == 2 && path_[pathpos] == '.' && path_[pathpos + 1] == '.')
{
// Backup a .. segment in the new path_
// We expect to find a previously added slash at the end or nothing
RAPIDJSON_ASSERT(newpos == 0 || path_[newpos - 1] == '/');
size_t lastslashpos = newpos;
// Make sure we don't go beyond the start segment
if (lastslashpos > 1) {
if(lastslashpos > 1)
{
// Find the next to last slash and back up to it
lastslashpos--;
while (lastslashpos > 0) {
if (path_[lastslashpos - 1] == '/') break;
while(lastslashpos > 0)
{
if(path_[lastslashpos - 1] == '/')
break;
lastslashpos--;
}
// Set the new path_ position
newpos = lastslashpos;
}
} else if (slashpos == 1 && path_[pathpos] == '.') {
}
else if(slashpos == 1 && path_[pathpos] == '.')
{
// Discard . segment, leaves new path_ unchanged
} else {
}
else
{
// Move any other kind of segment to the new path_
RAPIDJSON_ASSERT(newpos <= pathpos);
std::memmove(&path_[newpos], &path_[pathpos], slashpos * sizeof(Ch));
newpos += slashpos;
// Add slash if not at end
if ((pathpos + slashpos) < pathlen) {
if((pathpos + slashpos) < pathlen)
{
path_[newpos] = '/';
newpos++;
}
@@ -465,8 +636,9 @@ private:
Ch* query_; // Includes the ?
Ch* frag_; // Includes the #
Allocator* allocator_; //!< The current allocator. It is either user-supplied or equal to ownAllocator_.
Allocator* ownAllocator_; //!< Allocator owned by this Uri.
Allocator* allocator_; //!< The current allocator. It is either user-supplied or equal to
//!< ownAllocator_.
Allocator* ownAllocator_; //!< Allocator owned by this Uri.
};
//! GenericUri for Value (UTF-8, default allocator).

File diff suppressed because it is too large Load Diff

View File

@@ -8,12 +8,12 @@ def __version__():
hash = subprocess.check_output("git rev-parse HEAD", shell=True, text=True)[
:hash_width
]
except:
except Exception:
hash = "0" * hash_width
try:
change_count = subprocess.check_output(
f"git rev-list rocm-{rocm_version}..HEAD --count", shell=True, text=True
).strip()
except:
except Exception:
change_count = "0"
return f"{rocm_version}.dev{change_count}+g{hash}"

View File

@@ -14,43 +14,69 @@ Features:
import argparse
import sys
import os
def run_dependency_parser(args):
from src.enhanced_ninja_parser import main as ninja_main
sys.argv = ["enhanced_ninja_parser.py"] + args
ninja_main()
def run_selective_test_filter(args):
from src.selective_test_filter import main as filter_main
sys.argv = ["selective_test_filter.py"] + args
filter_main()
def main():
parser = argparse.ArgumentParser(description="Unified Ninja Dependency & Selective Testing Tool")
parser = argparse.ArgumentParser(
description="Unified Ninja Dependency & Selective Testing Tool"
)
subparsers = parser.add_subparsers(dest="command", required=True)
# Dependency parsing
parser_parse = subparsers.add_parser("parse", help="Parse build.ninja and generate dependency mapping")
parser_parse = subparsers.add_parser(
"parse", help="Parse build.ninja and generate dependency mapping"
)
parser_parse.add_argument("build_ninja", help="Path to build.ninja")
parser_parse.add_argument("--ninja", help="Path to ninja executable", default="ninja")
parser_parse.add_argument("--workspace-root", help="Path to workspace root", default=None)
parser_parse.add_argument(
"--ninja", help="Path to ninja executable", default="ninja"
)
parser_parse.add_argument(
"--workspace-root", help="Path to workspace root", default=None
)
# Selective testing
parser_test = subparsers.add_parser("select", help="Selective test filtering between git refs")
parser_test = subparsers.add_parser(
"select", help="Selective test filtering between git refs"
)
parser_test.add_argument("depmap_json", help="Path to dependency mapping JSON")
parser_test.add_argument("ref1", help="Source git ref")
parser_test.add_argument("ref2", help="Target git ref")
parser_test.add_argument("--all", action="store_true", help="Include all executables")
parser_test.add_argument("--test-prefix", action="store_true", help="Only include executables starting with 'test_'")
parser_test.add_argument("--output", help="Output JSON file", default="tests_to_run.json")
parser_test.add_argument(
"--all", action="store_true", help="Include all executables"
)
parser_test.add_argument(
"--test-prefix",
action="store_true",
help="Only include executables starting with 'test_'",
)
parser_test.add_argument(
"--output", help="Output JSON file", default="tests_to_run.json"
)
# Code auditing
parser_audit = subparsers.add_parser("audit", help="List all files and their dependent executables")
parser_audit = subparsers.add_parser(
"audit", help="List all files and their dependent executables"
)
parser_audit.add_argument("depmap_json", help="Path to dependency mapping JSON")
# Build optimization
parser_opt = subparsers.add_parser("optimize", help="List affected executables for changed files")
parser_opt = subparsers.add_parser(
"optimize", help="List affected executables for changed files"
)
parser_opt.add_argument("depmap_json", help="Path to dependency mapping JSON")
parser_opt.add_argument("changed_files", nargs="+", help="List of changed files")
@@ -73,9 +99,12 @@ def main():
elif args.command == "audit":
run_selective_test_filter([args.depmap_json, "--audit"])
elif args.command == "optimize":
run_selective_test_filter([args.depmap_json, "--optimize-build"] + args.changed_files)
run_selective_test_filter(
[args.depmap_json, "--optimize-build"] + args.changed_files
)
else:
parser.print_help()
if __name__ == "__main__":
main()

View File

@@ -14,96 +14,100 @@ import re
import os
import sys
import subprocess
from pathlib import Path
from collections import defaultdict
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
class EnhancedNinjaDependencyParser:
def __init__(self, build_file_path, ninja_executable="ninja"):
self.build_file_path = build_file_path
self.build_dir = os.path.dirname(build_file_path)
self.ninja_executable = ninja_executable
# Core data structures
self.executable_to_objects = {} # exe -> [object_files]
self.object_to_source = {} # object -> primary_source
self.object_to_all_deps = {} # object -> [all_dependencies]
self.object_to_source = {} # object -> primary_source
self.object_to_all_deps = {} # object -> [all_dependencies]
self.file_to_executables = defaultdict(set) # file -> {executables}
# Thread safety
self.lock = threading.Lock()
def parse_dependencies(self):
"""Main method to parse all dependencies."""
print(f"Parsing ninja dependencies from: {self.build_file_path}")
# Step 1: Parse build file for executable -> object mappings
self._parse_build_file()
# Step 2: Get all object files and their dependencies
print(f"Found {len(self.object_to_source)} object files")
print("Extracting detailed dependencies for all object files...")
self._extract_object_dependencies()
# Step 3: Build the final file -> executables mapping
self._build_file_to_executable_mapping()
def _parse_build_file(self):
"""Parse the ninja build file to extract executable -> object mappings."""
print("Parsing ninja build file...")
with open(self.build_file_path, 'r') as f:
with open(self.build_file_path, "r") as f:
content = f.read()
# Parse executable build rules
exe_pattern = r'^build (bin/[^:]+):\s+\S+\s+([^|]+)'
obj_pattern = r'^build ([^:]+\.(?:cpp|cu|hip)\.o):\s+\S+\s+([^\s|]+)'
lines = content.split('\n')
# Parse executable build rules
exe_pattern = r"^build (bin/[^:]+):\s+\S+\s+([^|]+)"
obj_pattern = r"^build ([^:]+\.(?:cpp|cu|hip)\.o):\s+\S+\s+([^\s|]+)"
lines = content.split("\n")
for line in lines:
# Match executable rules
exe_match = re.match(exe_pattern, line)
if exe_match and ('EXECUTABLE' in line or 'test_' in exe_match.group(1) or 'example_' in exe_match.group(1)):
if exe_match and (
"EXECUTABLE" in line
or "test_" in exe_match.group(1)
or "example_" in exe_match.group(1)
):
exe = exe_match.group(1)
deps_part = exe_match.group(2).strip()
object_files = []
for dep in deps_part.split():
if dep.endswith('.o') and not dep.startswith('/'):
if dep.endswith(".o") and not dep.startswith("/"):
object_files.append(dep)
self.executable_to_objects[exe] = object_files
continue
# Match object compilation rules
obj_match = re.match(obj_pattern, line)
if obj_match:
object_file = obj_match.group(1)
source_file = obj_match.group(2)
self.object_to_source[object_file] = source_file
print(f"Found {len(self.executable_to_objects)} executables")
print(f"Found {len(self.object_to_source)} object-to-source mappings")
def _extract_object_dependencies(self):
"""Extract detailed dependencies for all object files using ninja -t deps."""
object_files = list(self.object_to_source.keys())
# Process object files in parallel for better performance
# Process object files in parallel for better performance
if not object_files:
print("No object files found - skipping dependency extraction")
return
max_workers = min(16, len(object_files)) # Limit concurrent processes
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all object files for processing
future_to_obj = {
executor.submit(self._get_object_dependencies, obj): obj
executor.submit(self._get_object_dependencies, obj): obj
for obj in object_files
}
# Process completed futures
# Process completed futures
completed = 0
for future in as_completed(future_to_obj):
obj_file = future_to_obj[future]
@@ -113,52 +117,52 @@ class EnhancedNinjaDependencyParser:
self.object_to_all_deps[obj_file] = dependencies
completed += 1
if completed % 100 == 0:
print(f"Processed {completed}/{len(object_files)} object files...")
print(
f"Processed {completed}/{len(object_files)} object files..."
)
except Exception as e:
print(f"Error processing {obj_file}: {e}")
print(f"Completed dependency extraction for {len(self.object_to_all_deps)} object files")
print(
f"Completed dependency extraction for {len(self.object_to_all_deps)} object files"
)
def _get_object_dependencies(self, object_file):
"""Get all dependencies for a single object file using ninja -t deps."""
try:
# Run ninja -t deps for this object file
cmd = [self.ninja_executable, "-t", "deps", object_file]
result = subprocess.run(
cmd,
cwd=self.build_dir,
capture_output=True,
text=True,
timeout=30
cmd, cwd=self.build_dir, capture_output=True, text=True, timeout=30
)
if result.returncode != 0:
return []
dependencies = []
lines = result.stdout.strip().split('\n')
lines = result.stdout.strip().split("\n")
for line in lines[1:]: # Skip first line with metadata
line = line.strip()
if line and not line.startswith('#'):
if line and not line.startswith("#"):
# Convert absolute paths to relative paths from workspace root
dep_file = line
ws_root = getattr(self, "workspace_root", "..")
ws_prefix = ws_root.rstrip("/") + "/"
if dep_file.startswith(ws_prefix):
dep_file = dep_file[len(ws_prefix):]
dep_file = dep_file[len(ws_prefix) :]
dependencies.append(dep_file)
return dependencies
except Exception as e:
print(f"Error getting dependencies for {object_file}: {e}")
return []
def _build_file_to_executable_mapping(self):
"""Build the final mapping from files to executables."""
print("Building file-to-executable mapping...")
for exe, object_files in self.executable_to_objects.items():
for obj_file in object_files:
# Add all dependencies of this object file
@@ -167,106 +171,135 @@ class EnhancedNinjaDependencyParser:
# Filter out system files and focus on project files
if self._is_project_file(dep_file):
self.file_to_executables[dep_file].add(exe)
print(f"Built mapping for {len(self.file_to_executables)} files")
# Show statistics
multi_exe_files = {f: exes for f, exes in self.file_to_executables.items() if len(exes) > 1}
multi_exe_files = {
f: exes for f, exes in self.file_to_executables.items() if len(exes) > 1
}
print(f"Files used by multiple executables: {len(multi_exe_files)}")
if multi_exe_files:
print("Sample files with multiple dependencies:")
for f, exes in sorted(multi_exe_files.items())[:5]:
print(f" {f}: {len(exes)} executables")
def _is_project_file(self, file_path):
"""Determine if a file is part of the project (not system files)."""
# Include files that are clearly part of the project
if any(file_path.startswith(prefix) for prefix in [
'include/', 'library/', 'test/', 'example/', 'src/', 'profiler/',
'build/include/', 'build/_deps/gtest', 'client_example', 'codegen', 'tile_engine'
]):
if any(
file_path.startswith(prefix)
for prefix in [
"include/",
"library/",
"test/",
"example/",
"src/",
"profiler/",
"build/include/",
"build/_deps/gtest",
"client_example",
"codegen",
"tile_engine",
]
):
return True
# Exclude system files
if any(file_path.startswith(prefix) for prefix in [
'/usr/', '/opt/rocm', '/lib/', '/system/', '/local/'
]):
if any(
file_path.startswith(prefix)
for prefix in ["/usr/", "/opt/rocm", "/lib/", "/system/", "/local/"]
):
return False
# Include files with common source/header extensions
if file_path.endswith(('.cpp', '.hpp', '.h', '.c', '.cc', '.cxx', '.cu', '.hip', '.inc')):
if file_path.endswith(
(".cpp", ".hpp", ".h", ".c", ".cc", ".cxx", ".cu", ".hip", ".inc")
):
return True
return False
def export_to_csv(self, output_file):
"""Export the file-to-executable mapping to CSV with proper comma separation."""
print(f"Exporting mapping to {output_file}")
with open(output_file, 'w') as f:
with open(output_file, "w") as f:
f.write("source_file,executables\n")
for file_path in sorted(self.file_to_executables.keys()):
executables = sorted(self.file_to_executables[file_path])
# Use semicolon to separate multiple executables within the field
exe_list = ';'.join(executables)
exe_list = ";".join(executables)
f.write(f'"{file_path}","{exe_list}"\n')
def export_to_json(self, output_file):
"""Export the complete mapping to JSON."""
print(f"Exporting complete mapping to {output_file}")
# Build reverse mapping (executable -> files)
exe_to_files = defaultdict(set)
for file_path, exes in self.file_to_executables.items():
for exe in exes:
exe_to_files[exe].add(file_path)
mapping_data = {
'file_to_executables': {
file_path: list(exes) for file_path, exes in self.file_to_executables.items()
"file_to_executables": {
file_path: list(exes)
for file_path, exes in self.file_to_executables.items()
},
'executable_to_files': {
"executable_to_files": {
exe: sorted(files) for exe, files in exe_to_files.items()
},
'statistics': {
'total_files': len(self.file_to_executables),
'total_executables': len(self.executable_to_objects),
'total_object_files': len(self.object_to_source),
'files_with_multiple_executables': len([f for f, exes in self.file_to_executables.items() if len(exes) > 1])
}
"statistics": {
"total_files": len(self.file_to_executables),
"total_executables": len(self.executable_to_objects),
"total_object_files": len(self.object_to_source),
"files_with_multiple_executables": len(
[f for f, exes in self.file_to_executables.items() if len(exes) > 1]
),
},
}
with open(output_file, 'w') as f:
with open(output_file, "w") as f:
json.dump(mapping_data, f, indent=2)
def print_summary(self):
"""Print a summary of the parsed dependencies."""
"""Print a summary of the parsed dependencies."""
print("\n=== Enhanced Dependency Mapping Summary ===")
print(f"Total executables: {len(self.executable_to_objects)}")
print(f"Total files mapped: {len(self.file_to_executables)}")
print(f"Total object files processed: {len(self.object_to_all_deps)}")
# Files by type
cpp_files = sum(1 for f in self.file_to_executables.keys() if f.endswith('.cpp'))
hpp_files = sum(1 for f in self.file_to_executables.keys() if f.endswith('.hpp'))
h_files = sum(1 for f in self.file_to_executables.keys() if f.endswith('.h'))
print(f"\nFile types:")
cpp_files = sum(
1 for f in self.file_to_executables.keys() if f.endswith(".cpp")
)
hpp_files = sum(
1 for f in self.file_to_executables.keys() if f.endswith(".hpp")
)
h_files = sum(1 for f in self.file_to_executables.keys() if f.endswith(".h"))
print("\nFile types:")
print(f" .cpp files: {cpp_files}")
print(f" .hpp files: {hpp_files}")
print(f" .h files: {h_files}")
# Multi-executable files
multi_exe_files = {f: exes for f, exes in self.file_to_executables.items() if len(exes) > 1}
multi_exe_files = {
f: exes for f, exes in self.file_to_executables.items() if len(exes) > 1
}
print(f"\nFiles used by multiple executables: {len(multi_exe_files)}")
if multi_exe_files:
print("\nTop files with most dependencies:")
sorted_multi = sorted(multi_exe_files.items(), key=lambda x: len(x[1]), reverse=True)
sorted_multi = sorted(
multi_exe_files.items(), key=lambda x: len(x[1]), reverse=True
)
for file_path, exes in sorted_multi[:10]:
print(f" {file_path}: {len(exes)} executables")
def main():
# Accept: build_file, ninja_path, workspace_root
default_workspace_root = ".."
@@ -304,15 +337,16 @@ def main():
# Export results
output_dir = os.path.dirname(build_file)
csv_file = os.path.join(output_dir, 'enhanced_file_executable_mapping.csv')
json_file = os.path.join(output_dir, 'enhanced_dependency_mapping.json')
csv_file = os.path.join(output_dir, "enhanced_file_executable_mapping.csv")
json_file = os.path.join(output_dir, "enhanced_dependency_mapping.json")
parser.export_to_csv(csv_file)
parser.export_to_json(json_file)
print(f"\nResults exported to:")
print("\nResults exported to:")
print(f" CSV: {csv_file}")
print(f" JSON: {json_file}")
if __name__ == "__main__":
main()

View File

@@ -30,12 +30,15 @@ import subprocess
import json
import os
def get_changed_files(ref1, ref2):
"""Return a set of files changed between two git refs."""
try:
result = subprocess.run(
["git", "diff", "--name-only", ref1, ref2],
capture_output=True, text=True, check=True
capture_output=True,
text=True,
check=True,
)
files = set(line.strip() for line in result.stdout.splitlines() if line.strip())
return files
@@ -43,6 +46,7 @@ def get_changed_files(ref1, ref2):
print(f"Error running git diff: {e}")
sys.exit(1)
def load_depmap(depmap_json):
"""Load the dependency mapping JSON."""
with open(depmap_json, "r") as f:
@@ -52,6 +56,7 @@ def load_depmap(depmap_json):
return data["file_to_executables"]
return data
def select_tests(file_to_executables, changed_files, filter_mode):
"""Return a set of test executables affected by changed files."""
affected = set()
@@ -64,6 +69,7 @@ def select_tests(file_to_executables, changed_files, filter_mode):
affected.add(exe)
return sorted(affected)
def main():
if "--audit" in sys.argv:
if len(sys.argv) < 2:
@@ -81,7 +87,9 @@ def main():
if "--optimize-build" in sys.argv:
if len(sys.argv) < 3:
print("Usage: python selective_test_filter.py <depmap_json> --optimize-build <changed_file1> [<changed_file2> ...]")
print(
"Usage: python selective_test_filter.py <depmap_json> --optimize-build <changed_file1> [<changed_file2> ...]"
)
sys.exit(1)
depmap_json = sys.argv[1]
changed_files = set(sys.argv[sys.argv.index("--optimize-build") + 1 :])
@@ -100,7 +108,9 @@ def main():
sys.exit(0)
if len(sys.argv) < 4:
print("Usage: python selective_test_filter.py <depmap_json> <ref1> <ref2> [--all | --test-prefix] [--output <output_json>]")
print(
"Usage: python selective_test_filter.py <depmap_json> <ref1> <ref2> [--all | --test-prefix] [--output <output_json>]"
)
sys.exit(1)
depmap_json = sys.argv[1]
@@ -131,9 +141,12 @@ def main():
tests = select_tests(file_to_executables, changed_files, filter_mode)
with open(output_json, "w") as f:
json.dump({"tests_to_run": tests, "changed_files": sorted(changed_files)}, f, indent=2)
json.dump(
{"tests_to_run": tests, "changed_files": sorted(changed_files)}, f, indent=2
)
print(f"Exported {len(tests)} tests to run to {output_json}")
if __name__ == "__main__":
main()

View File

@@ -12,38 +12,38 @@ import os
import re
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Iterator
from typing import Dict, List, Optional, Iterator
class BuildTarget:
"""Represents a single build target with timing information."""
def __init__(self, start_time: int, end_time: int, output_name: str, cmd_hash: str):
self.start_time = int(start_time)
self.end_time = int(end_time)
self.cmd_hash = cmd_hash
self.duration = self.end_time - self.start_time
self.targets = [output_name] # List of target names for this command hash
@property
def category(self) -> str:
"""Categorize the build target based on file extension."""
# Use the first target for categorization
primary_target = self.targets[0] if self.targets else ""
ext = Path(primary_target).suffix.lower()
if ext in ['.o', '.obj']:
return 'compile'
elif ext in ['.a', '.lib']:
return 'archive'
elif ext in ['.so', '.dll', '.dylib']:
return 'link_shared'
elif ext in ['.exe', '.out']:
return 'link_executable'
elif 'test' in primary_target.lower():
return 'test'
if ext in [".o", ".obj"]:
return "compile"
elif ext in [".a", ".lib"]:
return "archive"
elif ext in [".so", ".dll", ".dylib"]:
return "link_shared"
elif ext in [".exe", ".out"]:
return "link_executable"
elif "test" in primary_target.lower():
return "test"
else:
return 'other'
return "other"
@property
def output_name(self) -> str:
"""Get the primary output name (for backward compatibility)."""
@@ -52,11 +52,11 @@ class BuildTarget:
class ThreadScheduler:
"""Simulates thread allocation for parallelism analysis."""
def __init__(self, legacy_mode: bool = False):
self.workers: List[int] = []
self.legacy_mode = legacy_mode
def allocate_thread(self, target: BuildTarget) -> int:
"""Allocate a thread for the given target."""
if self.legacy_mode:
@@ -73,7 +73,7 @@ class ThreadScheduler:
if worker_end_time <= target.start_time:
self.workers[i] = target.end_time
return i
# No available worker, create a new one
self.workers.append(target.end_time)
return len(self.workers) - 1
@@ -81,62 +81,67 @@ class ThreadScheduler:
class NinjaLogParser:
"""Parser for ninja build log files."""
def __init__(self, show_all_builds: bool = False):
self.show_all_builds = show_all_builds
def parse_log_file(self, log_path: str) -> List[BuildTarget]:
"""Parse the ninja log file and return build targets."""
if not os.path.exists(log_path):
raise FileNotFoundError(f"Ninja log file not found: {log_path}")
with open(log_path, 'r', encoding='utf-8') as file:
with open(log_path, "r", encoding="utf-8") as file:
lines = file.readlines()
if not lines:
raise ValueError("Empty ninja log file")
# Parse and validate header
header = lines[0].strip()
version_match = re.match(r'^# ninja log v(\d+)$', header)
version_match = re.match(r"^# ninja log v(\d+)$", header)
if not version_match:
raise ValueError(f"Invalid ninja log header: {header}")
version = int(version_match.group(1))
if version < 5:
raise ValueError(f"Unsupported ninja log version: {version}")
# Skip additional header line for version 6
start_line = 2 if version > 5 else 1
targets: Dict[str, BuildTarget] = {}
last_end_time = 0
for line_num, line in enumerate(lines[start_line:], start=start_line + 1):
line = line.strip()
# Skip empty lines and comments
if not line or line.startswith('#'):
if not line or line.startswith("#"):
continue
parts = line.split('\t')
parts = line.split("\t")
if len(parts) < 5:
print(f"Warning: Skipping malformed line {line_num}: {line}", file=sys.stderr)
print(
f"Warning: Skipping malformed line {line_num}: {line}",
file=sys.stderr,
)
continue
try:
start_time, end_time, _, output_name, cmd_hash = parts[:5]
start_time, end_time = int(start_time), int(end_time)
# Handle incremental builds
if not self.show_all_builds and end_time < last_end_time:
targets.clear()
last_end_time = end_time
# Group targets by command hash
if cmd_hash not in targets:
targets[cmd_hash] = BuildTarget(start_time, end_time, output_name, cmd_hash)
targets[cmd_hash] = BuildTarget(
start_time, end_time, output_name, cmd_hash
)
else:
# Update with the latest timing and add output
existing = targets[cmd_hash]
@@ -144,223 +149,260 @@ class NinjaLogParser:
existing.end_time = max(existing.end_time, end_time)
existing.duration = existing.end_time - existing.start_time
existing.targets.append(output_name)
except (ValueError, IndexError) as e:
print(f"Warning: Error parsing line {line_num}: {e}", file=sys.stderr)
continue
return sorted(targets.values(), key=lambda t: t.end_time, reverse=True)
class FTimeTraceReader:
"""Reads and processes Clang -ftime-trace JSON files."""
def __init__(self, granularity_us: int = 50000):
self.granularity_us = granularity_us
def read_trace_file(self, trace_path: str) -> Optional[Dict]:
"""Read and parse a Clang time trace file."""
try:
with open(trace_path, 'r', encoding='utf-8') as f:
with open(trace_path, "r", encoding="utf-8") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError, IOError):
return None
def filter_events(self, trace_data: Dict) -> List[Dict]:
"""Filter trace events based on criteria."""
if 'traceEvents' not in trace_data:
if "traceEvents" not in trace_data:
return []
filtered_events = []
for event in trace_data['traceEvents']:
for event in trace_data["traceEvents"]:
# Only include complete events (ph=X) that meet duration threshold
if (event.get('ph') == 'X' and
event.get('dur', 0) >= self.granularity_us and
not event.get('name', '').startswith('Total')):
if (
event.get("ph") == "X"
and event.get("dur", 0) >= self.granularity_us
and not event.get("name", "").startswith("Total")
):
filtered_events.append(event)
return filtered_events
def adjust_event_timing(self, event: Dict, target: BuildTarget, pid: int, tid: int) -> Dict:
def adjust_event_timing(
self, event: Dict, target: BuildTarget, pid: int, tid: int
) -> Dict:
"""Adjust event timing to align with ninja build timing."""
ninja_duration_us = target.duration * 1000
# Validate event duration against ninja timing
if event.get('dur', 0) > ninja_duration_us:
print(f"Warning: Clang trace event duration ({event['dur']}μs) exceeds "
f"ninja duration ({ninja_duration_us}μs) for {target.output_name}",
file=sys.stderr)
if event.get("dur", 0) > ninja_duration_us:
print(
f"Warning: Clang trace event duration ({event['dur']}μs) exceeds "
f"ninja duration ({ninja_duration_us}μs) for {target.output_name}",
file=sys.stderr,
)
return None
# Adjust event timing
adjusted_event = event.copy()
adjusted_event['pid'] = pid
adjusted_event['tid'] = tid
adjusted_event['ts'] += target.start_time * 1000 # Offset by ninja start time
adjusted_event["pid"] = pid
adjusted_event["tid"] = tid
adjusted_event["ts"] += target.start_time * 1000 # Offset by ninja start time
return adjusted_event
class ChromeTraceGenerator:
"""Generates Chrome tracing format from build targets."""
def __init__(self, process_id: int = 1, embed_ftime_traces: bool = False,
granularity_us: int = 50000, ninja_log_dir: Optional[str] = None,
legacy_format: bool = False):
def __init__(
self,
process_id: int = 1,
embed_ftime_traces: bool = False,
granularity_us: int = 50000,
ninja_log_dir: Optional[str] = None,
legacy_format: bool = False,
):
self.process_id = process_id
self.scheduler = ThreadScheduler(legacy_mode=legacy_format)
self.embed_ftime_traces = embed_ftime_traces
self.ninja_log_dir = ninja_log_dir
self.ftime_reader = FTimeTraceReader(granularity_us) if embed_ftime_traces else None
self.ftime_reader = (
FTimeTraceReader(granularity_us) if embed_ftime_traces else None
)
self.legacy_format = legacy_format
def find_ftime_trace_files(self, target: BuildTarget) -> List[str]:
"""Find Clang -ftime-trace files for a build target."""
if not self.ninja_log_dir:
return []
trace_files = []
# Look for .json files adjacent to object files
obj_path = Path(self.ninja_log_dir) / target.output_name
json_path = obj_path.with_suffix('.json')
json_path = obj_path.with_suffix(".json")
if json_path.exists():
trace_files.append(str(json_path))
return trace_files
def generate_ftime_events(self, target: BuildTarget, tid: int) -> Iterator[Dict]:
"""Generate Clang -ftime-trace events for a target."""
if not self.embed_ftime_traces or not self.ftime_reader:
return
trace_files = self.find_ftime_trace_files(target)
for trace_file in trace_files:
trace_data = self.ftime_reader.read_trace_file(trace_file)
if not trace_data:
continue
filtered_events = self.ftime_reader.filter_events(trace_data)
for event in filtered_events:
adjusted_event = self.ftime_reader.adjust_event_timing(
event, target, self.process_id, tid
)
if adjusted_event:
yield adjusted_event
def generate_trace_events(self, targets: List[BuildTarget]) -> List[Dict]:
"""Generate Chrome trace events from build targets."""
events = []
for target in targets:
thread_id = self.scheduler.allocate_thread(target)
# Add main ninja build event
if self.legacy_format:
# Legacy format: join multiple targets with commas, use "targets" category, empty args
target_name = ', '.join(target.targets) if len(target.targets) > 1 else target.output_name
target_name = (
", ".join(target.targets)
if len(target.targets) > 1
else target.output_name
)
ninja_event = {
'name': target_name,
'cat': 'targets',
'ph': 'X', # Complete event
'ts': target.start_time * 1000, # Convert to microseconds
'dur': target.duration * 1000, # Convert to microseconds
'pid': self.process_id,
'tid': thread_id,
'args': {}
"name": target_name,
"cat": "targets",
"ph": "X", # Complete event
"ts": target.start_time * 1000, # Convert to microseconds
"dur": target.duration * 1000, # Convert to microseconds
"pid": self.process_id,
"tid": thread_id,
"args": {},
}
else:
# New format: smart categorization, detailed args
ninja_event = {
'name': target.output_name,
'cat': target.category,
'ph': 'X', # Complete event
'ts': target.start_time * 1000, # Convert to microseconds
'dur': target.duration * 1000, # Convert to microseconds
'pid': self.process_id,
'tid': thread_id,
'args': {
'output': target.output_name,
'duration_ms': target.duration,
'cmd_hash': target.cmd_hash
}
"name": target.output_name,
"cat": target.category,
"ph": "X", # Complete event
"ts": target.start_time * 1000, # Convert to microseconds
"dur": target.duration * 1000, # Convert to microseconds
"pid": self.process_id,
"tid": thread_id,
"args": {
"output": target.output_name,
"duration_ms": target.duration,
"cmd_hash": target.cmd_hash,
},
}
events.append(ninja_event)
# Add embedded Clang -ftime-trace events
if self.embed_ftime_traces:
ftime_events = list(self.generate_ftime_events(target, thread_id))
events.extend(ftime_events)
if ftime_events:
print(f"Embedded {len(ftime_events)} -ftime-trace events for {target.output_name}",
file=sys.stderr)
print(
f"Embedded {len(ftime_events)} -ftime-trace events for {target.output_name}",
file=sys.stderr,
)
return events
class BuildAnalyzer:
"""Analyzes build performance and provides statistics."""
def __init__(self, targets: List[BuildTarget]):
self.targets = targets
def get_build_summary(self) -> Dict:
"""Generate build performance summary."""
if not self.targets:
return {}
total_duration = sum(t.duration for t in self.targets)
total_targets = len(self.targets)
# Category statistics
category_stats = {}
for target in self.targets:
cat = target.category
if cat not in category_stats:
category_stats[cat] = {'count': 0, 'total_time': 0}
category_stats[cat]['count'] += 1
category_stats[cat]['total_time'] += target.duration
category_stats[cat] = {"count": 0, "total_time": 0}
category_stats[cat]["count"] += 1
category_stats[cat]["total_time"] += target.duration
# Top slowest targets
slowest_targets = sorted(self.targets, key=lambda t: t.duration, reverse=True)[:10]
slowest_targets = sorted(self.targets, key=lambda t: t.duration, reverse=True)[
:10
]
return {
'total_targets': total_targets,
'total_duration_ms': total_duration,
'total_duration_sec': total_duration / 1000,
'average_duration_ms': total_duration / total_targets if total_targets > 0 else 0,
'category_stats': category_stats,
'slowest_targets': [
{'name': t.output_name, 'duration_ms': t.duration, 'category': t.category}
"total_targets": total_targets,
"total_duration_ms": total_duration,
"total_duration_sec": total_duration / 1000,
"average_duration_ms": total_duration / total_targets
if total_targets > 0
else 0,
"category_stats": category_stats,
"slowest_targets": [
{
"name": t.output_name,
"duration_ms": t.duration,
"category": t.category,
}
for t in slowest_targets
]
],
}
def print_summary(self):
"""Print build summary to stderr."""
summary = self.get_build_summary()
if not summary:
print("No build data available", file=sys.stderr)
return
print(f"\n=== Build Summary ===", file=sys.stderr)
print("\n=== Build Summary ===", file=sys.stderr)
print(f"Total targets: {summary['total_targets']}", file=sys.stderr)
print(f"Total time: {summary['total_duration_sec']:.2f}s", file=sys.stderr)
print(f"Average time per target: {summary['average_duration_ms']:.2f}ms", file=sys.stderr)
print(f"\nBy category:", file=sys.stderr)
for category, stats in summary['category_stats'].items():
avg_time = stats['total_time'] / stats['count'] if stats['count'] > 0 else 0
print(f" {category:15} {stats['count']:6} targets "
f"{stats['total_time']/1000:8.2f}s "
f"(avg: {avg_time/1000:.3f}s)", file=sys.stderr)
print(f"\nSlowest targets:", file=sys.stderr)
for i, target in enumerate(summary['slowest_targets'][:5], 1):
print(f" {i:2}. {target['name']} ({target['duration_ms']}ms, {target['category']})", file=sys.stderr)
print(
f"Average time per target: {summary['average_duration_ms']:.2f}ms",
file=sys.stderr,
)
print("\nBy category:", file=sys.stderr)
for category, stats in summary["category_stats"].items():
avg_time = stats["total_time"] / stats["count"] if stats["count"] > 0 else 0
print(
f" {category:15} {stats['count']:6} targets "
f"{stats['total_time'] / 1000:8.2f}s "
f"(avg: {avg_time / 1000:.3f}s)",
file=sys.stderr,
)
print("\nSlowest targets:", file=sys.stderr)
for i, target in enumerate(summary["slowest_targets"][:5], 1):
print(
f" {i:2}. {target['name']} ({target['duration_ms']}ms, {target['category']})",
file=sys.stderr,
)
def create_argument_parser() -> argparse.ArgumentParser:
@@ -376,57 +418,48 @@ Examples:
%(prog)s build/.ninja_log --show-all # Include all builds
%(prog)s build/.ninja_log --embed-ftime-trace # Include Clang timing data
%(prog)s build/.ninja_log --granularity 10000 # Custom granularity threshold
"""
""",
)
parser.add_argument(
'ninja_logs',
nargs='+', # Accept one or more ninja log files
help='Path(s) to the .ninja_log file(s)'
"ninja_logs",
nargs="+", # Accept one or more ninja log files
help="Path(s) to the .ninja_log file(s)",
)
parser.add_argument("-o", "--output", help="Output file (default: stdout)")
parser.add_argument(
'-o', '--output',
help='Output file (default: stdout)'
"--show-all", action="store_true", help="Show all builds, not just the last one"
)
parser.add_argument(
'--show-all',
action='store_true',
help='Show all builds, not just the last one'
"--summary", action="store_true", help="Print build summary to stderr"
)
parser.add_argument(
'--summary',
action='store_true',
help='Print build summary to stderr'
"--pretty", action="store_true", help="Pretty-print JSON output"
)
parser.add_argument(
'--pretty',
action='store_true',
help='Pretty-print JSON output'
"--embed-ftime-trace",
action="store_true",
help="Embed Clang -ftime-trace JSON files found adjacent to targets",
)
parser.add_argument(
'--embed-ftime-trace',
action='store_true',
help='Embed Clang -ftime-trace JSON files found adjacent to targets'
)
parser.add_argument(
'--granularity',
"--granularity",
type=int,
default=50000,
help='Minimum duration for -ftime-trace events in microseconds (default: 50000)'
help="Minimum duration for -ftime-trace events in microseconds (default: 50000)",
)
parser.add_argument(
'--legacy-format',
action='store_true',
help='Output in legacy format compatible with old ninjatracer (simple JSON array, all categories as "targets", empty args)'
"--legacy-format",
action="store_true",
help='Output in legacy format compatible with old ninjatracer (simple JSON array, all categories as "targets", empty args)',
)
return parser
@@ -434,75 +467,79 @@ def main():
"""Main entry point."""
parser = create_argument_parser()
args = parser.parse_args()
try:
# Process multiple ninja log files
all_events = []
for pid, ninja_log_path in enumerate(args.ninja_logs):
# Parse ninja log
log_parser = NinjaLogParser(show_all_builds=args.show_all)
targets = log_parser.parse_log_file(ninja_log_path)
if not targets:
print(f"No build targets found in ninja log: {ninja_log_path}", file=sys.stderr)
print(
f"No build targets found in ninja log: {ninja_log_path}",
file=sys.stderr,
)
continue
# Determine ninja log directory for -ftime-trace files
ninja_log_dir = os.path.dirname(os.path.abspath(ninja_log_path)) if args.embed_ftime_trace else None
ninja_log_dir = (
os.path.dirname(os.path.abspath(ninja_log_path))
if args.embed_ftime_trace
else None
)
# Generate trace events for this log file
trace_generator = ChromeTraceGenerator(
process_id=pid, # Use different PID for each log file
embed_ftime_traces=args.embed_ftime_trace,
granularity_us=args.granularity,
ninja_log_dir=ninja_log_dir,
legacy_format=args.legacy_format
legacy_format=args.legacy_format,
)
events = trace_generator.generate_trace_events(targets)
all_events.extend(events)
# Print summary if requested (for each log file)
if args.summary:
print(f"\n=== Summary for {ninja_log_path} ===", file=sys.stderr)
analyzer = BuildAnalyzer(targets)
analyzer.print_summary()
if not all_events:
print("No build targets found in any ninja log files", file=sys.stderr)
return 1
# Output format logic
if args.legacy_format:
# Legacy format: always output simple JSON array
json_kwargs = {'indent': 2} if args.pretty else {}
json_kwargs = {"indent": 2} if args.pretty else {}
json_output = json.dumps(all_events, **json_kwargs)
elif args.output or args.pretty:
# Enhanced format with metadata (when saving to file or pretty printing)
trace_data = {
'traceEvents': all_events,
'displayTimeUnit': 'ms',
'systemTraceEvents': 'SystemTraceData',
'otherData': {
'version': '1.0',
'generator': 'ninja_json_converter.py'
}
"traceEvents": all_events,
"displayTimeUnit": "ms",
"systemTraceEvents": "SystemTraceData",
"otherData": {"version": "1.0", "generator": "ninja_json_converter.py"},
}
json_kwargs = {'indent': 2} if args.pretty else {}
json_kwargs = {"indent": 2} if args.pretty else {}
json_output = json.dumps(trace_data, **json_kwargs)
else:
# Original format (simple JSON array to stdout)
json_output = json.dumps(all_events)
if args.output:
with open(args.output, 'w') as f:
with open(args.output, "w") as f:
f.write(json_output)
print(f"Trace written to {args.output}", file=sys.stderr)
else:
print(json_output)
return 0
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
return 1

View File

@@ -1,13 +1,16 @@
#!/usr/bin/env python3
import os, io, argparse, datetime
#import numpy as np
import os
import io
import argparse
import datetime
# import numpy as np
import sqlalchemy
from sqlalchemy.types import NVARCHAR, Float, Integer
from sqlalchemy import text
import pymysql
import pandas as pd
from sshtunnel import SSHTunnelForwarder
def print_to_string(*args, **kwargs):
output = io.StringIO()
print(*args, file=output, **kwargs)
@@ -15,15 +18,18 @@ def print_to_string(*args, **kwargs):
output.close()
return contents
def parse_args():
parser = argparse.ArgumentParser(description='Parse results from tf benchmark runs')
parser.add_argument('filename', type=str, help='Log file to prase or directory containing log files')
parser = argparse.ArgumentParser(description="Parse results from tf benchmark runs")
parser.add_argument(
"filename", type=str, help="Log file to prase or directory containing log files"
)
args = parser.parse_args()
files = []
if os.path.isdir(args.filename):
all_files = os.listdir(args.filename)
for name in all_files:
if not 'log' in name:
if "log" not in name:
continue
files.append(os.path.join(args.filename, name))
else:
@@ -31,62 +37,76 @@ def parse_args():
args.files = files
return args
def get_log_params(logfile):
print("logfile=",logfile)
branch_name=' '
node_id=' '
gpu_arch=' '
hip_vers=' '
compute_units=0
environment=' '
rocm_vers=' '
print("logfile=", logfile)
branch_name = " "
node_id = " "
gpu_arch = " "
hip_vers = " "
compute_units = 0
environment = " "
rocm_vers = " "
for line in open(logfile):
if 'Branch name' in line:
lst=line.split()
branch_name=lst[2]
if 'On branch' in line:
lst=line.split()
branch_name=lst[2]
if 'Node name' in line:
lst=line.split()
node_id=lst[2]
if 'GPU_arch' in line:
lst=line.split()
gpu_arch=lst[2]
if 'HIP version' in line:
lst=line.split()
hip_vers=lst[2]
if 'Compute Unit' in line:
lst=line.split()
compute_units=lst[2]
if 'Environment type' in line:
lst=line.split()
environment=lst[2]
if 'InstalledDir' in line:
lst=line.split()
rocm_vers=lst[1][lst[1].find('/opt/rocm-')+len('/opt/rocm-'):lst[1].rfind('/llvm/bin')]
return branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment
if "Branch name" in line:
lst = line.split()
branch_name = lst[2]
if "On branch" in line:
lst = line.split()
branch_name = lst[2]
if "Node name" in line:
lst = line.split()
node_id = lst[2]
if "GPU_arch" in line:
lst = line.split()
gpu_arch = lst[2]
if "HIP version" in line:
lst = line.split()
hip_vers = lst[2]
if "Compute Unit" in line:
lst = line.split()
compute_units = lst[2]
if "Environment type" in line:
lst = line.split()
environment = lst[2]
if "InstalledDir" in line:
lst = line.split()
rocm_vers = lst[1][
lst[1].find("/opt/rocm-") + len("/opt/rocm-") : lst[1].rfind(
"/llvm/bin"
)
]
return (
branch_name,
node_id,
gpu_arch,
compute_units,
rocm_vers,
hip_vers,
environment,
)
def parse_logfile(logfile):
glue=''
res=[]
tests=[]
kernels=[]
tflops=[]
dtype=[]
alayout=[]
blayout=[]
M=[]
N=[]
K=[]
StrideA=[]
StrideB=[]
StrideC=[]
if 'perf_gemm' in logfile and 'gemm_bilinear' not in logfile:
glue = ""
res = []
tests = []
kernels = []
tflops = []
dtype = []
alayout = []
blayout = []
M = []
N = []
K = []
StrideA = []
StrideB = []
StrideC = []
if "perf_gemm" in logfile and "gemm_bilinear" not in logfile:
for line in open(logfile):
if 'Best Perf' in line:
lst=line.split()
if len(lst)>=37: #the line is complete
if "Best Perf" in line:
lst = line.split()
if len(lst) >= 37: # the line is complete
tests.append(glue.join(lst[5:30]))
kernels.append(glue.join(lst[37:]))
tflops.append(lst[33])
@@ -99,7 +119,7 @@ def parse_logfile(logfile):
StrideA.append(lst[23])
StrideB.append(lst[26])
StrideC.append(lst[29])
elif len(lst)<37 and len(lst)>=33: #the tflops are available
elif len(lst) < 37 and len(lst) >= 33: # the tflops are available
tests.append(glue.join(lst[5:30]))
kernels.append("N/A")
tflops.append(lst[33])
@@ -112,87 +132,141 @@ def parse_logfile(logfile):
StrideA.append(lst[23])
StrideB.append(lst[26])
StrideC.append(lst[29])
print("warning: incomplete line:",lst)
elif len(lst)<33: #even the tflops are not available
print("warning: incomplete line:", lst)
elif len(lst) < 33: # even the tflops are not available
print("Error in ckProfiler output!")
print("warning: incomplete line=",lst)
#sort results
#sorted_tests = sorted(tests)
res = [x for _,x in sorted(zip(tests,tflops))]
#sorted_kernels = [x for _,x in sorted(zip(tests,kernels))]
test_list=list(range(1,len(tests)+1))
#parse conv_fwd and conv_bwd performance tests:
elif 'conv_fwd' in logfile or 'conv_bwd' in logfile:
print("warning: incomplete line=", lst)
# sort results
# sorted_tests = sorted(tests)
res = [x for _, x in sorted(zip(tests, tflops))]
# sorted_kernels = [x for _,x in sorted(zip(tests,kernels))]
# test_list = list(range(1, len(tests) + 1))
# parse conv_fwd and conv_bwd performance tests:
elif "conv_fwd" in logfile or "conv_bwd" in logfile:
for line in open(logfile):
if 'tflops:' in line:
lst=line.split()
if "tflops:" in line:
lst = line.split()
res.append(lst[1])
#parse all other performance tests:
elif 'resnet50' in logfile or 'batched_gemm' in logfile or 'grouped_gemm' in logfile or 'gemm_bilinear' in logfile or 'reduction' in logfile:
# parse all other performance tests:
elif (
"resnet50" in logfile
or "batched_gemm" in logfile
or "grouped_gemm" in logfile
or "gemm_bilinear" in logfile
or "reduction" in logfile
):
for line in open(logfile):
if 'Best Perf' in line:
lst=line.split()
if "Best Perf" in line:
lst = line.split()
res.append(lst[4])
elif 'onnx_gemm' in logfile:
elif "onnx_gemm" in logfile:
for line in open(logfile):
if 'Best Perf' in line:
lst=line.split()
if "Best Perf" in line:
lst = line.split()
res.append(lst[33])
elif 'splitK_gemm' in logfile or 'mixed_gemm' in logfile:
elif "splitK_gemm" in logfile or "mixed_gemm" in logfile:
for line in open(logfile):
if 'Best Perf' in line:
lst=line.split()
if "Best Perf" in line:
lst = line.split()
res.append(lst[36])
elif 'perf_fmha' in logfile:
elif "perf_fmha" in logfile:
for line in open(logfile):
if 'TFlops' in line:
lst=line.split()
line_dict=dict(zip(lst[1:],lst))
res.append(line_dict['TFlops,'])
elif 'perf_tile_gemm_basic' in logfile or 'perf_tile_gemm_mem_pipeline' in logfile:
if "TFlops" in line:
lst = line.split()
line_dict = dict(zip(lst[1:], lst))
res.append(line_dict["TFlops,"])
elif "perf_tile_gemm_basic" in logfile or "perf_tile_gemm_mem_pipeline" in logfile:
for line in open(logfile):
if 'TFlops' in line:
lst=line.split()
line_dict=dict(zip(lst[1:],lst))
res.append(line_dict['TFlops,'])
if "TFlops" in line:
lst = line.split()
line_dict = dict(zip(lst[1:], lst))
res.append(line_dict["TFlops,"])
return res
def get_baseline(table, connection):
query = text('''SELECT * from '''+table+''' WHERE Datetime = (SELECT MAX(Datetime) FROM '''+table+''' where Branch_ID='develop' );''')
query = text(
"""SELECT * from """
+ table
+ """ WHERE Datetime = (SELECT MAX(Datetime) FROM """
+ table
+ """ where Branch_ID='develop' );"""
)
return pd.read_sql(query, connection)
def store_new_test_result(table_name, test_results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, connection):
params=[str(branch_name),str(node_id),str(gpu_arch),compute_units,str(rocm_vers),str(hip_vers),str(environment),str(datetime.datetime.now())]
df=pd.DataFrame(data=[params],columns=['Branch_ID','Node_ID','GPU_arch','Compute Units','ROCM_version','HIP_version','Environment','Datetime'])
df_add=pd.DataFrame(data=[test_results],columns=testlist)
df=pd.concat([df,df_add],axis=1)
#print("new test results dataframe:",df)
df.to_sql(table_name,connection,if_exists='append',index=False)
def store_new_test_result(
table_name,
test_results,
testlist,
branch_name,
node_id,
gpu_arch,
compute_units,
rocm_vers,
hip_vers,
environment,
connection,
):
params = [
str(branch_name),
str(node_id),
str(gpu_arch),
compute_units,
str(rocm_vers),
str(hip_vers),
str(environment),
str(datetime.datetime.now()),
]
df = pd.DataFrame(
data=[params],
columns=[
"Branch_ID",
"Node_ID",
"GPU_arch",
"Compute Units",
"ROCM_version",
"HIP_version",
"Environment",
"Datetime",
],
)
df_add = pd.DataFrame(data=[test_results], columns=testlist)
df = pd.concat([df, df_add], axis=1)
# print("new test results dataframe:",df)
df.to_sql(table_name, connection, if_exists="append", index=False)
return 0
def compare_test_to_baseline(baseline,test,testlist):
regression=0
def compare_test_to_baseline(baseline, test, testlist):
regression = 0
if not baseline.empty:
base=baseline[testlist].to_numpy(dtype='float')
base_list=base[0]
ave_perf=0
base = baseline[testlist].to_numpy(dtype="float")
base_list = base[0]
ave_perf = 0
for i in range(len(base_list)):
# success criterion:
if base_list[i]>1.01*float(test[i]):
print("test # ",i,"shows regression by {:.3f}%".format(
(float(test[i])-base_list[i])/base_list[i]*100))
regression=1
if base_list[i]>0: ave_perf=ave_perf+float(test[i])/base_list[i]
if regression==0:
if base_list[i] > 1.01 * float(test[i]):
print(
"test # ",
i,
"shows regression by {:.3f}%".format(
(float(test[i]) - base_list[i]) / base_list[i] * 100
),
)
regression = 1
if base_list[i] > 0:
ave_perf = ave_perf + float(test[i]) / base_list[i]
if regression == 0:
print("no regressions found")
ave_perf=ave_perf/len(base_list)
print("average performance relative to baseline:",ave_perf)
ave_perf = ave_perf / len(base_list)
print("average performance relative to baseline:", ave_perf)
else:
print("could not find a baseline")
return regression
'''
"""
def post_test_params(tlist,connection):
sorted_dtypes = [x for _,x in sorted(zip(tests,dtype))]
sorted_alayout = [x for _,x in sorted(zip(tests,alayout))]
@@ -223,29 +297,38 @@ def post_test_params(tlist,connection):
'StrideC': Integer()
}
df.to_sql("ck_gemm_test_params",connection,if_exists='replace',index=False, dtype=dtypes)
'''
"""
def main():
args = parse_args()
results=[]
tflops_base=[]
testlist=[]
#parse the test parameters from the logfile
results = []
tflops_base = []
testlist = []
# parse the test parameters from the logfile
for filename in args.files:
branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment = get_log_params(filename)
(
branch_name,
node_id,
gpu_arch,
compute_units,
rocm_vers,
hip_vers,
environment,
) = get_log_params(filename)
print("Branch name:",branch_name)
print("Node name:",node_id)
print("GPU_arch:",gpu_arch)
print("Compute units:",compute_units)
print("ROCM_version:",rocm_vers)
print("HIP_version:",hip_vers)
print("Environment:",environment)
#parse results, get the Tflops value for "Best Perf" kernels
results=parse_logfile(filename)
print("Branch name:", branch_name)
print("Node name:", node_id)
print("GPU_arch:", gpu_arch)
print("Compute units:", compute_units)
print("ROCM_version:", rocm_vers)
print("HIP_version:", hip_vers)
print("Environment:", environment)
# parse results, get the Tflops value for "Best Perf" kernels
results = parse_logfile(filename)
print("Number of tests:",len(results))
sql_hostname = '127.0.0.1'
print("Number of tests:", len(results))
sql_hostname = "127.0.0.1"
sql_username = os.environ["dbuser"]
sql_password = os.environ["dbpassword"]
sql_main_database = os.environ["ck_perf_db"]
@@ -256,127 +339,147 @@ def main():
ssh_pass = os.environ["dbsshpassword"]
with SSHTunnelForwarder(
(ssh_host, ssh_port),
ssh_username=ssh_user,
ssh_password=ssh_pass,
remote_bind_address=(sql_hostname, sql_port)) as tunnel:
sqlEngine = sqlalchemy.create_engine('mysql+pymysql://{0}:{1}@{2}:{3}/{4}'.
format(sql_username, sql_password, sql_hostname, tunnel.local_bind_port, sql_main_database))
(ssh_host, ssh_port),
ssh_username=ssh_user,
ssh_password=ssh_pass,
remote_bind_address=(sql_hostname, sql_port),
) as tunnel:
sqlEngine = sqlalchemy.create_engine(
"mysql+pymysql://{0}:{1}@{2}:{3}/{4}".format(
sql_username,
sql_password,
sql_hostname,
tunnel.local_bind_port,
sql_main_database,
)
)
conn = sqlEngine.connect()
#save gemm performance tests:
if 'perf_gemm' in filename and 'gemm_bilinear' not in filename:
#write the ck_gemm_test_params table only needed once the test set changes
#post_test_params(test_list,conn)
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_gemm_tflops"
if 'batched_gemm' in filename:
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_batched_gemm_tflops"
if 'grouped_gemm' in filename:
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_grouped_gemm_tflops"
if 'perf_conv_fwd' in filename:
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_conv_fwd_tflops"
if 'perf_conv_bwd_data' in filename:
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_conv_bwd_data_tflops"
if 'grouped_conv_fwd' in filename:
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_grouped_conv_fwd_tflops"
if 'grouped_conv_bwd_data' in filename:
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_grouped_conv_bwd_data_tflops"
if 'grouped_conv_bwd_weight' in filename:
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_grouped_conv_bwd_weight_tflops"
if 'gemm_bilinear' in filename:
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_gemm_bilinear_tflops"
if 'reduction' in filename:
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_reduction_GBps"
if 'resnet50_N4' in filename:
for i in range(1,50):
testlist.append("Layer%i"%i)
table_name="ck_resnet50_N4_tflops"
if 'resnet50_N256' in filename:
for i in range(1,50):
testlist.append("Layer%i"%i)
table_name="ck_resnet50_N256_tflops"
if 'onnx_gemm' in filename:
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_onnx_gemm_tflops"
if 'splitK_gemm' in filename:
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_splitK_gemm_tflops"
if 'mixed_gemm' in filename:
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_mixed_gemm_tflops"
if 'fmha_fwd' in filename:
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_fmha_fwd_tflops"
if 'fmha_bwd' in filename:
for i in range(1,len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_fmha_bwd_tflops"
if 'gemm_basic_fp16' in filename:
for i in range(1, len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_tile_gemm_basic_fp16_tflops"
if 'gemm_mem_pipeline_fp16' in filename:
for i in range(1, len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_tile_gemm_mem_pipeline_fp16_tflops"
if 'gemm_basic_bf16' in filename:
for i in range(1, len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_tile_gemm_basic_bf16_tflops"
if 'gemm_mem_pipeline_bf16' in filename:
for i in range(1, len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_tile_gemm_mem_pipeline_bf16_tflops"
if 'gemm_basic_fp8' in filename:
for i in range(1, len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_tile_gemm_basic_fp8_tflops"
if 'gemm_mem_pipeline_fp8' in filename:
for i in range(1, len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_tile_gemm_mem_pipeline_fp8_tflops"
if 'gemm_basic_bf8' in filename:
for i in range(1, len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_tile_gemm_basic_bf8_tflops"
if 'gemm_mem_pipeline_bf8' in filename:
for i in range(1, len(results)+1):
testlist.append("Test%i"%i)
table_name="ck_tile_gemm_mem_pipeline_bf8_tflops"
# save gemm performance tests:
if "perf_gemm" in filename and "gemm_bilinear" not in filename:
# write the ck_gemm_test_params table only needed once the test set changes
# post_test_params(test_list,conn)
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_gemm_tflops"
if "batched_gemm" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_batched_gemm_tflops"
if "grouped_gemm" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_grouped_gemm_tflops"
if "perf_conv_fwd" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_conv_fwd_tflops"
if "perf_conv_bwd_data" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_conv_bwd_data_tflops"
if "grouped_conv_fwd" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_grouped_conv_fwd_tflops"
if "grouped_conv_bwd_data" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_grouped_conv_bwd_data_tflops"
if "grouped_conv_bwd_weight" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_grouped_conv_bwd_weight_tflops"
if "gemm_bilinear" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_gemm_bilinear_tflops"
if "reduction" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_reduction_GBps"
if "resnet50_N4" in filename:
for i in range(1, 50):
testlist.append("Layer%i" % i)
table_name = "ck_resnet50_N4_tflops"
if "resnet50_N256" in filename:
for i in range(1, 50):
testlist.append("Layer%i" % i)
table_name = "ck_resnet50_N256_tflops"
if "onnx_gemm" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_onnx_gemm_tflops"
if "splitK_gemm" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_splitK_gemm_tflops"
if "mixed_gemm" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_mixed_gemm_tflops"
if "fmha_fwd" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_fmha_fwd_tflops"
if "fmha_bwd" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_fmha_bwd_tflops"
if "gemm_basic_fp16" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_tile_gemm_basic_fp16_tflops"
if "gemm_mem_pipeline_fp16" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_tile_gemm_mem_pipeline_fp16_tflops"
if "gemm_basic_bf16" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_tile_gemm_basic_bf16_tflops"
if "gemm_mem_pipeline_bf16" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_tile_gemm_mem_pipeline_bf16_tflops"
if "gemm_basic_fp8" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_tile_gemm_basic_fp8_tflops"
if "gemm_mem_pipeline_fp8" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_tile_gemm_mem_pipeline_fp8_tflops"
if "gemm_basic_bf8" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_tile_gemm_basic_bf8_tflops"
if "gemm_mem_pipeline_bf8" in filename:
for i in range(1, len(results) + 1):
testlist.append("Test%i" % i)
table_name = "ck_tile_gemm_mem_pipeline_bf8_tflops"
tflops_base = get_baseline(table_name,conn)
store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, sqlEngine)
tflops_base = get_baseline(table_name, conn)
store_new_test_result(
table_name,
results,
testlist,
branch_name,
node_id,
gpu_arch,
compute_units,
rocm_vers,
hip_vers,
environment,
sqlEngine,
)
conn.close()
#compare the results to the baseline if baseline exists
regression=0
regression=compare_test_to_baseline(tflops_base,results,testlist)
# compare the results to the baseline if baseline exists
regression = 0
regression = compare_test_to_baseline(tflops_base, results, testlist)
return regression
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -2,18 +2,6 @@
# Copyright © Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Get list of staged files
STAGED_FILES=$(git diff --cached --name-only)
# Check if any staged file is under include/ck_tile/ or example/ck_tile/
if echo "$STAGED_FILES" | grep -qE '^(include/ck_tile/|example/ck_tile/)'; then
echo "Detected changes in ck_tile-related files. Running remod.py..."
# Run remod.py in both required locations
(cd include/ck_tile/ && python3 remod.py)
(cd example/ck_tile/ && python3 remod.py)
echo "remod.py completed."
else
echo "No changes in ck_tile-related files. Skipping remod.py."
fi
# Run remod.py in both required locations
(cd include/ck_tile/ && python3 remod.py)
(cd example/ck_tile/ && python3 remod.py)

View File

@@ -71,7 +71,7 @@ def tuples(filename):
try:
m, n, k = map(int, line)
lines.append((m, n, k))
except:
except Exception:
pass
return lines
@@ -163,19 +163,19 @@ def run_shape(shape, profiler_bin, op_name, dtype, layout):
m, n, k = shape
try:
op = OPs[op_name]
except:
except KeyError:
raise AssertionError(f"Invalid operator {op_name}")
name_arg = op.name
op_wrapper = op.value()
try:
dtype_arg = str(op_wrapper.dtype[dtype].value)
except:
except KeyError:
raise AssertionError(f"Invalid dtype for {op_name}: {dtype}")
try:
layout_wrapper = op_wrapper.layout[layout]
except:
except KeyError:
raise AssertionError(f"Invalid layout for {op_name}: {layout}")
layout_arg = str(layout_wrapper.value)
# verification: no, initialization: decimal, print tensor: no, time kernel: yes
@@ -286,7 +286,9 @@ def main():
try:
from tqdm import tqdm as iterate
except ImportError:
iterate = lambda x: x
def iterate(x):
return x
for s in iterate(shapes):
run_shape_stdout_lines = run_shape(

File diff suppressed because it is too large Load Diff

View File

@@ -9,7 +9,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/pool.hpp"
#include "ck_tile/ops/pooling.hpp"
#include "ck_tile/host/reference/reference_pool.hpp"
#include "ck_tile/host/kernel_launch.hpp"

File diff suppressed because it is too large Load Diff

View File

@@ -10,28 +10,37 @@ and saves them as CSV files that can be read by the shell script.
"""
import csv
import itertools
import argparse
def generate_2d_configs(mode='full'):
def generate_2d_configs(mode="full"):
"""Generate all 2D model configuration combinations
Args:
mode: 'small' for minimal set (~50 configs), 'half' for reduced set (~250 configs), 'full' for comprehensive set (~500 configs)
"""
# Define parameter ranges
models_2d = [
'resnet18', 'resnet34', 'resnet50',
'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small',
'vgg11', 'vgg16', 'vgg19',
'alexnet', 'googlenet',
'densenet121', 'densenet161',
'squeezenet1_0', 'squeezenet1_1',
'shufflenet_v2_x1_0'
"resnet18",
"resnet34",
"resnet50",
"mobilenet_v2",
"mobilenet_v3_large",
"mobilenet_v3_small",
"vgg11",
"vgg16",
"vgg19",
"alexnet",
"googlenet",
"densenet121",
"densenet161",
"squeezenet1_0",
"squeezenet1_1",
"shufflenet_v2_x1_0",
]
if mode == 'small':
if mode == "small":
# Minimal set for quick testing
batch_sizes = [1, 8] # Just two batch sizes
# Very limited input dimensions - only 2 key sizes
@@ -41,12 +50,12 @@ def generate_2d_configs(mode='full'):
]
# Use only first 3 models for minimal testing
models_2d = models_2d[:3] # Only resnet18, resnet34, resnet50
elif mode == 'half':
elif mode == "half":
# Reduced set for faster testing
batch_sizes = [1, 8, 32] # Small, medium, large
# Reduced input dimensions - 5 key sizes
input_dims = [
(64, 64), # Small
(64, 64), # Small
(224, 224), # Standard (most common)
(512, 512), # Large
(224, 320), # Rectangular
@@ -57,18 +66,23 @@ def generate_2d_configs(mode='full'):
batch_sizes = [1, 4, 8, 16, 32]
# More dimensions but skip some redundant ones
input_dims = [
(64, 64), (128, 128), (224, 224), (256, 256), (512, 512), # Square
(224, 320), (320, 224), # Rectangular (reduced from 4)
(64, 64),
(128, 128),
(224, 224),
(256, 256),
(512, 512), # Square
(224, 320),
(320, 224), # Rectangular (reduced from 4)
(227, 227), # AlexNet preferred
(299, 299) # Inception preferred
(299, 299), # Inception preferred
]
precisions = ['fp32'] #, 'fp16', 'bf16']
precisions = ["fp32"] # , 'fp16', 'bf16']
channels = [3] # Most models expect RGB
configs = []
config_id = 1
# Generate all combinations (but limit to reasonable subset)
for model in models_2d:
for batch_size in batch_sizes:
@@ -77,36 +91,37 @@ def generate_2d_configs(mode='full'):
# Skip some combinations to keep dataset manageable
if batch_size > 16 and height > 256:
continue # Skip large batch + large image combinations
if precision != 'fp32' and batch_size < 8:
if precision != "fp32" and batch_size < 8:
continue # Skip mixed precision with tiny batches
config_name = f"{model}_b{batch_size}_{height}x{width}_{precision}"
config = {
'config_name': config_name,
'model': model,
'batch_size': batch_size,
'channels': channels[0],
'height': height,
'width': width,
'precision': precision
"config_name": config_name,
"model": model,
"batch_size": batch_size,
"channels": channels[0],
"height": height,
"width": width,
"precision": precision,
}
configs.append(config)
config_id += 1
return configs
def generate_3d_configs(mode='full'):
def generate_3d_configs(mode="full"):
"""Generate all 3D model configuration combinations
Args:
mode: 'small' for minimal set (~10 configs), 'half' for reduced set (~50 configs), 'full' for comprehensive set (~100 configs)
"""
models_3d = ['r3d_18', 'mc3_18', 'r2plus1d_18']
if mode == 'small':
models_3d = ["r3d_18", "mc3_18", "r2plus1d_18"]
if mode == "small":
# Minimal set for quick testing
batch_sizes = [1, 4] # Just two batch sizes
temporal_sizes = [8] # Only smallest temporal size
@@ -116,7 +131,7 @@ def generate_3d_configs(mode='full'):
]
# Use only first model for minimal testing
models_3d = models_3d[:1] # Only r3d_18
elif mode == 'half':
elif mode == "half":
# Reduced set for faster testing
batch_sizes = [1, 4, 8] # Skip batch_size=2
temporal_sizes = [8, 16] # Skip 32 (most expensive)
@@ -124,7 +139,7 @@ def generate_3d_configs(mode='full'):
input_dims = [
(112, 112), # Small (common for video)
(224, 224), # Standard
(224, 320) # Rectangular
(224, 320), # Rectangular
]
else: # full mode
# More comprehensive but still reasonable
@@ -132,15 +147,18 @@ def generate_3d_configs(mode='full'):
temporal_sizes = [8, 16, 32]
# More dimensions
input_dims = [
(112, 112), (224, 224), (256, 256), # Standard sizes
(224, 320), (320, 224) # Rectangular
(112, 112),
(224, 224),
(256, 256), # Standard sizes
(224, 320),
(320, 224), # Rectangular
]
precisions = ['fp32'] #, 'fp16'] # Skip bf16 for 3D to reduce combinations
precisions = ["fp32"] # , 'fp16'] # Skip bf16 for 3D to reduce combinations
channels = [3]
configs = []
for model in models_3d:
for batch_size in batch_sizes:
for temporal_size in temporal_sizes:
@@ -151,75 +169,97 @@ def generate_3d_configs(mode='full'):
continue
if batch_size > 2 and height > 224:
continue
config_name = f"{model}_b{batch_size}_t{temporal_size}_{height}x{width}_{precision}"
config = {
'config_name': config_name,
'model': model,
'batch_size': batch_size,
'channels': channels[0],
'temporal_size': temporal_size,
'height': height,
'width': width,
'precision': precision
}
"config_name": config_name,
"model": model,
"batch_size": batch_size,
"channels": channels[0],
"temporal_size": temporal_size,
"height": height,
"width": width,
"precision": precision,
}
configs.append(config)
return configs
def save_configs_to_csv(configs, filename, config_type):
"""Save configurations to CSV file"""
if not configs:
print(f"No {config_type} configurations generated")
return
fieldnames = list(configs[0].keys())
with open(filename, 'w', newline='\n', encoding='utf-8') as csvfile:
with open(filename, "w", newline="\n", encoding="utf-8") as csvfile:
csvfile.write(f"# {config_type} Model Configurations\n")
csvfile.write(f"# Generated {len(configs)} configurations\n")
writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator='\n')
writer = csv.DictWriter(csvfile, fieldnames=fieldnames, lineterminator="\n")
writer.writeheader()
for config in configs:
writer.writerow(config)
print(f"Generated {len(configs)} {config_type} configurations → {filename}")
def main():
parser = argparse.ArgumentParser(description='Generate model configuration combinations')
parser.add_argument('--output-2d', type=str, default='model_configs_2d.csv',
help='Output file for 2D configurations')
parser.add_argument('--output-3d', type=str, default='model_configs_3d.csv',
help='Output file for 3D configurations')
parser.add_argument('--mode', choices=['small', 'half', 'full'], default='full',
help='Configuration mode: small (~60 total), half (~300 total) or full (~600 total) (default: half)')
parser.add_argument('--limit', type=int,
help='Limit number of configurations per type (for testing)')
parser = argparse.ArgumentParser(
description="Generate model configuration combinations"
)
parser.add_argument(
"--output-2d",
type=str,
default="model_configs_2d.csv",
help="Output file for 2D configurations",
)
parser.add_argument(
"--output-3d",
type=str,
default="model_configs_3d.csv",
help="Output file for 3D configurations",
)
parser.add_argument(
"--mode",
choices=["small", "half", "full"],
default="full",
help="Configuration mode: small (~60 total), half (~300 total) or full (~600 total) (default: half)",
)
parser.add_argument(
"--limit",
type=int,
help="Limit number of configurations per type (for testing)",
)
args = parser.parse_args()
print(f"Generating {args.mode} model configurations...")
print("Generating 2D model configurations...")
configs_2d = generate_2d_configs(mode=args.mode)
if args.limit:
configs_2d = configs_2d[:args.limit]
configs_2d = configs_2d[: args.limit]
save_configs_to_csv(configs_2d, args.output_2d, "2D")
print("Generating 3D model configurations...")
configs_3d = generate_3d_configs(mode=args.mode)
if args.limit:
configs_3d = configs_3d[:args.limit]
configs_3d = configs_3d[: args.limit]
save_configs_to_csv(configs_3d, args.output_3d, "3D")
print(f"\nTotal configurations: {len(configs_2d)} 2D + {len(configs_3d)} 3D = {len(configs_2d) + len(configs_3d)}")
print(
f"\nTotal configurations: {len(configs_2d)} 2D + {len(configs_3d)} 3D = {len(configs_2d) + len(configs_3d)}"
)
print("\nTo use these configurations:")
print(" Update generate_test_dataset.sh to read from these CSV files")
if __name__ == "__main__":
main()

View File

@@ -18,301 +18,428 @@ import csv
import re
import os
def parse_miopen_command(command_line):
"""
Parse MIOpen driver command line into parameter dictionary
Example input:
./bin/MIOpenDriver conv -n 4 -c 3 -H 224 -W 224 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g 1 -F 1 -t 1
Returns dict with parsed parameters or None if parsing fails
"""
if not command_line.strip().startswith('./bin/MIOpenDriver conv'):
if not command_line.strip().startswith("./bin/MIOpenDriver conv"):
return None
# Extract parameters using regex
params = {}
# Parameter mapping: flag -> description
# Support both short (-D) and long (--in_d) parameter formats
param_patterns = {
'n': r'-n\s+(\d+)', # batch size
'c': r'-c\s+(\d+)', # input channels
'k': r'-k\s+(\d+)', # output channels
'H': r'-H\s+(\d+)', # input height
'W': r'-W\s+(\d+)', # input width
'D': r'(?:-D|--in_d)\s+(\d+)', # input depth (3D only) - supports both -D and --in_d
'y': r'-y\s+(\d+)', # kernel height
'x': r'-x\s+(\d+)', # kernel width
'z': r'(?:-z|--fil_d)\s+(\d+)', # kernel depth (3D only) - supports both -z and --fil_d
'u': r'-u\s+(\d+)', # stride height
'v': r'-v\s+(\d+)', # stride width
'w': r'(?:-w|--conv_stride_d)\s+(\d+)', # stride depth (3D only) - supports both -w and --conv_stride_d
'p': r'-p\s+(\d+)', # pad height
'q': r'-q\s+(\d+)', # pad width
's': r'(?:-s|--pad_d)\s+(\d+)', # pad depth (3D only) - supports both -s and --pad_d
'l': r'-l\s+(\d+)', # dilation height
'j': r'-j\s+(\d+)', # dilation width
'r': r'(?:-r|--dilation_d)\s+(\d+)', # dilation depth (3D only) - supports both -r and --dilation_d
'g': r'-g\s+(\d+)', # groups
'F': r'-F\s+(\d+)', # direction (1=fwd, 2=bwd_weight, 4=bwd_data)
"n": r"-n\s+(\d+)", # batch size
"c": r"-c\s+(\d+)", # input channels
"k": r"-k\s+(\d+)", # output channels
"H": r"-H\s+(\d+)", # input height
"W": r"-W\s+(\d+)", # input width
"D": r"(?:-D|--in_d)\s+(\d+)", # input depth (3D only) - supports both -D and --in_d
"y": r"-y\s+(\d+)", # kernel height
"x": r"-x\s+(\d+)", # kernel width
"z": r"(?:-z|--fil_d)\s+(\d+)", # kernel depth (3D only) - supports both -z and --fil_d
"u": r"-u\s+(\d+)", # stride height
"v": r"-v\s+(\d+)", # stride width
"w": r"(?:-w|--conv_stride_d)\s+(\d+)", # stride depth (3D only) - supports both -w and --conv_stride_d
"p": r"-p\s+(\d+)", # pad height
"q": r"-q\s+(\d+)", # pad width
"s": r"(?:-s|--pad_d)\s+(\d+)", # pad depth (3D only) - supports both -s and --pad_d
"l": r"-l\s+(\d+)", # dilation height
"j": r"-j\s+(\d+)", # dilation width
"r": r"(?:-r|--dilation_d)\s+(\d+)", # dilation depth (3D only) - supports both -r and --dilation_d
"g": r"-g\s+(\d+)", # groups
"F": r"-F\s+(\d+)", # direction (1=fwd, 2=bwd_weight, 4=bwd_data)
}
for param, pattern in param_patterns.items():
match = re.search(pattern, command_line)
if match:
params[param] = int(match.group(1))
return params if params else None
def miopen_to_conv_param(miopen_params):
"""
Convert MIOpen parameters to CK ConvParam format
Returns dictionary in CSV format or None if conversion fails
"""
if not miopen_params:
return None
# Determine if 2D or 3D convolution
is_3d = 'D' in miopen_params or 'z' in miopen_params or 'w' in miopen_params or 'r' in miopen_params or 's' in miopen_params
is_3d = (
"D" in miopen_params
or "z" in miopen_params
or "w" in miopen_params
or "r" in miopen_params
or "s" in miopen_params
)
# Extract basic parameters with defaults
ndim = 3 if is_3d else 2
groups = miopen_params.get('g', 1)
batch_size = miopen_params.get('n', 1)
groups = miopen_params.get("g", 1)
batch_size = miopen_params.get("n", 1)
# MIOpen uses total channels (C*G), CK uses channels per group
out_channels_total = miopen_params.get('k', 64)
in_channels_total = miopen_params.get('c', 3)
out_channels_total = miopen_params.get("k", 64)
in_channels_total = miopen_params.get("c", 3)
out_channels = out_channels_total // groups # CK format: channels per group
in_channels = in_channels_total // groups # CK format: channels per group
in_channels = in_channels_total // groups # CK format: channels per group
if is_3d:
# 3D convolution
kernel_d = miopen_params.get('z', 3)
kernel_h = miopen_params.get('y', 3)
kernel_w = miopen_params.get('x', 3)
input_d = miopen_params.get('D', 16)
input_h = miopen_params.get('H', 32)
input_w = miopen_params.get('W', 32)
stride_d = miopen_params.get('w', 1)
stride_h = miopen_params.get('u', 1)
stride_w = miopen_params.get('v', 1)
dilation_d = miopen_params.get('r', 1)
dilation_h = miopen_params.get('l', 1)
dilation_w = miopen_params.get('j', 1)
pad_d = miopen_params.get('s', 0)
pad_h = miopen_params.get('p', 0)
pad_w = miopen_params.get('q', 0)
kernel_d = miopen_params.get("z", 3)
kernel_h = miopen_params.get("y", 3)
kernel_w = miopen_params.get("x", 3)
input_d = miopen_params.get("D", 16)
input_h = miopen_params.get("H", 32)
input_w = miopen_params.get("W", 32)
stride_d = miopen_params.get("w", 1)
stride_h = miopen_params.get("u", 1)
stride_w = miopen_params.get("v", 1)
dilation_d = miopen_params.get("r", 1)
dilation_h = miopen_params.get("l", 1)
dilation_w = miopen_params.get("j", 1)
pad_d = miopen_params.get("s", 0)
pad_h = miopen_params.get("p", 0)
pad_w = miopen_params.get("q", 0)
# Calculate output dimensions
output_d = (input_d + 2 * pad_d - dilation_d * (kernel_d - 1) - 1) // stride_d + 1
output_h = (input_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) // stride_h + 1
output_w = (input_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) // stride_w + 1
output_d = (
input_d + 2 * pad_d - dilation_d * (kernel_d - 1) - 1
) // stride_d + 1
output_h = (
input_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1
) // stride_h + 1
output_w = (
input_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1
) // stride_w + 1
# Skip invalid configurations
if output_d <= 0 or output_h <= 0 or output_w <= 0:
return None
direction = miopen_params.get('F', 1) # 1=fwd, 2=bwd_weight, 4=bwd_data
direction_name = {1: 'fwd', 2: 'bwd_weight', 4: 'bwd_data'}.get(direction, 'fwd')
direction = miopen_params.get("F", 1) # 1=fwd, 2=bwd_weight, 4=bwd_data
direction_name = {1: "fwd", 2: "bwd_weight", 4: "bwd_data"}.get(
direction, "fwd"
)
return {
'NDim': ndim,
'Groups': groups,
'BatchSize': batch_size,
'OutChannels': out_channels,
'InChannels': in_channels,
'KernelD': kernel_d, 'KernelH': kernel_h, 'KernelW': kernel_w,
'InputD': input_d, 'InputH': input_h, 'InputW': input_w,
'OutputD': output_d, 'OutputH': output_h, 'OutputW': output_w,
'StrideD': stride_d, 'StrideH': stride_h, 'StrideW': stride_w,
'DilationD': dilation_d, 'DilationH': dilation_h, 'DilationW': dilation_w,
'LeftPadD': pad_d, 'LeftPadH': pad_h, 'LeftPadW': pad_w,
'RightPadD': pad_d, 'RightPadH': pad_h, 'RightPadW': pad_w,
'TestName': f'MIOpen_3D_{direction_name}'
"NDim": ndim,
"Groups": groups,
"BatchSize": batch_size,
"OutChannels": out_channels,
"InChannels": in_channels,
"KernelD": kernel_d,
"KernelH": kernel_h,
"KernelW": kernel_w,
"InputD": input_d,
"InputH": input_h,
"InputW": input_w,
"OutputD": output_d,
"OutputH": output_h,
"OutputW": output_w,
"StrideD": stride_d,
"StrideH": stride_h,
"StrideW": stride_w,
"DilationD": dilation_d,
"DilationH": dilation_h,
"DilationW": dilation_w,
"LeftPadD": pad_d,
"LeftPadH": pad_h,
"LeftPadW": pad_w,
"RightPadD": pad_d,
"RightPadH": pad_h,
"RightPadW": pad_w,
"TestName": f"MIOpen_3D_{direction_name}",
}
else:
# 2D convolution
kernel_h = miopen_params.get('y', 3)
kernel_w = miopen_params.get('x', 3)
input_h = miopen_params.get('H', 32)
input_w = miopen_params.get('W', 32)
stride_h = miopen_params.get('u', 1)
stride_w = miopen_params.get('v', 1)
dilation_h = miopen_params.get('l', 1)
dilation_w = miopen_params.get('j', 1)
pad_h = miopen_params.get('p', 0)
pad_w = miopen_params.get('q', 0)
kernel_h = miopen_params.get("y", 3)
kernel_w = miopen_params.get("x", 3)
input_h = miopen_params.get("H", 32)
input_w = miopen_params.get("W", 32)
stride_h = miopen_params.get("u", 1)
stride_w = miopen_params.get("v", 1)
dilation_h = miopen_params.get("l", 1)
dilation_w = miopen_params.get("j", 1)
pad_h = miopen_params.get("p", 0)
pad_w = miopen_params.get("q", 0)
# Calculate output dimensions
output_h = (input_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) // stride_h + 1
output_w = (input_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) // stride_w + 1
output_h = (
input_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1
) // stride_h + 1
output_w = (
input_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1
) // stride_w + 1
# Skip invalid configurations
if output_h <= 0 or output_w <= 0:
return None
direction = miopen_params.get('F', 1)
direction_name = {1: 'fwd', 2: 'bwd_weight', 4: 'bwd_data'}.get(direction, 'fwd')
direction = miopen_params.get("F", 1)
direction_name = {1: "fwd", 2: "bwd_weight", 4: "bwd_data"}.get(
direction, "fwd"
)
return {
'NDim': ndim,
'Groups': groups,
'BatchSize': batch_size,
'OutChannels': out_channels,
'InChannels': in_channels,
'KernelH': kernel_h, 'KernelW': kernel_w,
'InputH': input_h, 'InputW': input_w,
'OutputH': output_h, 'OutputW': output_w,
'StrideH': stride_h, 'StrideW': stride_w,
'DilationH': dilation_h, 'DilationW': dilation_w,
'LeftPadH': pad_h, 'LeftPadW': pad_w,
'RightPadH': pad_h, 'RightPadW': pad_w,
'TestName': f'MIOpen_2D_{direction_name}'
"NDim": ndim,
"Groups": groups,
"BatchSize": batch_size,
"OutChannels": out_channels,
"InChannels": in_channels,
"KernelH": kernel_h,
"KernelW": kernel_w,
"InputH": input_h,
"InputW": input_w,
"OutputH": output_h,
"OutputW": output_w,
"StrideH": stride_h,
"StrideW": stride_w,
"DilationH": dilation_h,
"DilationW": dilation_w,
"LeftPadH": pad_h,
"LeftPadW": pad_w,
"RightPadH": pad_h,
"RightPadW": pad_w,
"TestName": f"MIOpen_2D_{direction_name}",
}
def write_csv_cases(test_cases, output_file, ndim):
"""Write test cases to CSV file"""
if not test_cases:
print(f"No {ndim}D test cases to write")
return
print(f"Writing {len(test_cases)} {ndim}D test cases to {output_file}")
# Define CSV headers based on dimension
if ndim == 2:
headers = ['NDim', 'Groups', 'BatchSize', 'OutChannels', 'InChannels',
'KernelH', 'KernelW', 'InputH', 'InputW', 'OutputH', 'OutputW',
'StrideH', 'StrideW', 'DilationH', 'DilationW',
'LeftPadH', 'LeftPadW', 'RightPadH', 'RightPadW', 'TestName']
headers = [
"NDim",
"Groups",
"BatchSize",
"OutChannels",
"InChannels",
"KernelH",
"KernelW",
"InputH",
"InputW",
"OutputH",
"OutputW",
"StrideH",
"StrideW",
"DilationH",
"DilationW",
"LeftPadH",
"LeftPadW",
"RightPadH",
"RightPadW",
"TestName",
]
else: # 3D
headers = ['NDim', 'Groups', 'BatchSize', 'OutChannels', 'InChannels',
'KernelD', 'KernelH', 'KernelW', 'InputD', 'InputH', 'InputW',
'OutputD', 'OutputH', 'OutputW', 'StrideD', 'StrideH', 'StrideW',
'DilationD', 'DilationH', 'DilationW',
'LeftPadD', 'LeftPadH', 'LeftPadW', 'RightPadD', 'RightPadH', 'RightPadW', 'TestName']
with open(output_file, 'w', newline='') as csvfile:
headers = [
"NDim",
"Groups",
"BatchSize",
"OutChannels",
"InChannels",
"KernelD",
"KernelH",
"KernelW",
"InputD",
"InputH",
"InputW",
"OutputD",
"OutputH",
"OutputW",
"StrideD",
"StrideH",
"StrideW",
"DilationD",
"DilationH",
"DilationW",
"LeftPadD",
"LeftPadH",
"LeftPadW",
"RightPadD",
"RightPadH",
"RightPadW",
"TestName",
]
with open(output_file, "w", newline="") as csvfile:
# Write header comment
csvfile.write(f"# {ndim}D Convolution Test Cases from MIOpen Commands\n")
csvfile.write(f"# Generated {len(test_cases)} test cases\n")
writer = csv.DictWriter(csvfile, fieldnames=headers)
writer.writeheader()
for test_case in test_cases:
# Only write fields that exist in headers
filtered_case = {k: v for k, v in test_case.items() if k in headers}
writer.writerow(filtered_case)
def main():
parser = argparse.ArgumentParser(description='Convert MIOpen commands to CSV test cases')
parser.add_argument('--input', type=str, required=True,
help='Input file with MIOpen driver commands')
parser.add_argument('--output', type=str,
help='Output CSV file (for mixed 2D/3D cases)')
parser.add_argument('--output-2d', type=str, default='miopen_conv_2d.csv',
help='Output CSV file for 2D cases')
parser.add_argument('--output-3d', type=str, default='miopen_conv_3d.csv',
help='Output CSV file for 3D cases')
parser.add_argument('--filter-duplicates', action='store_true',
help='Remove duplicate test cases')
parser.add_argument('--model-name', type=str, default='MIOpen',
help='Model name to use in test case names (default: MIOpen)')
parser = argparse.ArgumentParser(
description="Convert MIOpen commands to CSV test cases"
)
parser.add_argument(
"--input",
type=str,
required=True,
help="Input file with MIOpen driver commands",
)
parser.add_argument(
"--output", type=str, help="Output CSV file (for mixed 2D/3D cases)"
)
parser.add_argument(
"--output-2d",
type=str,
default="miopen_conv_2d.csv",
help="Output CSV file for 2D cases",
)
parser.add_argument(
"--output-3d",
type=str,
default="miopen_conv_3d.csv",
help="Output CSV file for 3D cases",
)
parser.add_argument(
"--filter-duplicates", action="store_true", help="Remove duplicate test cases"
)
parser.add_argument(
"--model-name",
type=str,
default="MIOpen",
help="Model name to use in test case names (default: MIOpen)",
)
args = parser.parse_args()
if not os.path.exists(args.input):
print(f"ERROR: Input file not found: {args.input}")
return 1
print(f"Parsing MIOpen commands from {args.input}...")
test_cases_2d = []
test_cases_3d = []
total_lines = 0
parsed_lines = 0
with open(args.input, 'r') as f:
with open(args.input, "r") as f:
for line_num, line in enumerate(f, 1):
total_lines += 1
line = line.strip()
# Skip empty lines and non-MIOpen commands
# Handle both direct commands and logged commands with MIOpen prefix
if not line:
continue
# Extract the actual MIOpenDriver command from logged format
if 'MIOpenDriver conv' in line:
if "MIOpenDriver conv" in line:
# Extract command after finding MIOpenDriver
command_start = line.find('./bin/MIOpenDriver conv')
command_start = line.find("./bin/MIOpenDriver conv")
if command_start != -1:
line = line[command_start:]
else:
# Handle cases where path might be different - create standard format
driver_start = line.find('MIOpenDriver conv')
driver_start = line.find("MIOpenDriver conv")
if driver_start != -1:
line = './bin/' + line[driver_start:]
line = "./bin/" + line[driver_start:]
else:
continue
elif not line.startswith('./bin/MIOpenDriver conv'):
elif not line.startswith("./bin/MIOpenDriver conv"):
continue
try:
# Parse MIOpen command
miopen_params = parse_miopen_command(line)
if not miopen_params:
continue
# Convert to ConvParam format
conv_param = miopen_to_conv_param(miopen_params)
if not conv_param:
continue
# Add model name to test name
conv_param['TestName'] = f"{args.model_name}_{conv_param['NDim']}D_fwd"
conv_param["TestName"] = f"{args.model_name}_{conv_param['NDim']}D_fwd"
# Separate 2D and 3D cases
if conv_param['NDim'] == 2:
if conv_param["NDim"] == 2:
test_cases_2d.append(conv_param)
else:
test_cases_3d.append(conv_param)
parsed_lines += 1
except Exception as e:
print(f"WARNING: Failed to parse line {line_num}: {e}")
continue
print(f"Processed {total_lines} lines, parsed {parsed_lines} commands")
print(f"Found {len(test_cases_2d)} 2D cases, {len(test_cases_3d)} 3D cases")
# Remove duplicates if requested
if args.filter_duplicates:
# Simple duplicate removal based on key parameters
def make_key(case):
if case['NDim'] == 2:
return (case['Groups'], case['BatchSize'], case['OutChannels'], case['InChannels'],
case['KernelH'], case['KernelW'], case['InputH'], case['InputW'],
case['StrideH'], case['StrideW'])
if case["NDim"] == 2:
return (
case["Groups"],
case["BatchSize"],
case["OutChannels"],
case["InChannels"],
case["KernelH"],
case["KernelW"],
case["InputH"],
case["InputW"],
case["StrideH"],
case["StrideW"],
)
else:
return (case['Groups'], case['BatchSize'], case['OutChannels'], case['InChannels'],
case['KernelD'], case['KernelH'], case['KernelW'],
case['InputD'], case['InputH'], case['InputW'],
case['StrideD'], case['StrideH'], case['StrideW'])
return (
case["Groups"],
case["BatchSize"],
case["OutChannels"],
case["InChannels"],
case["KernelD"],
case["KernelH"],
case["KernelW"],
case["InputD"],
case["InputH"],
case["InputW"],
case["StrideD"],
case["StrideH"],
case["StrideW"],
)
seen_2d = set()
unique_2d = []
for case in test_cases_2d:
@@ -320,7 +447,7 @@ def main():
if key not in seen_2d:
seen_2d.add(key)
unique_2d.append(case)
seen_3d = set()
unique_3d = []
for case in test_cases_3d:
@@ -328,11 +455,13 @@ def main():
if key not in seen_3d:
seen_3d.add(key)
unique_3d.append(case)
print(f"After deduplication: {len(unique_2d)} 2D cases, {len(unique_3d)} 3D cases")
print(
f"After deduplication: {len(unique_2d)} 2D cases, {len(unique_3d)} 3D cases"
)
test_cases_2d = unique_2d
test_cases_3d = unique_3d
# Write output files
if args.output:
# Write mixed cases to single file
@@ -340,14 +469,36 @@ def main():
if all_cases:
print(f"Writing {len(all_cases)} total cases to {args.output}")
# Use 2D headers for mixed file, extend as needed
mixed_headers = ['NDim', 'Groups', 'BatchSize', 'OutChannels', 'InChannels',
'KernelH', 'KernelW', 'InputH', 'InputW', 'OutputH', 'OutputW',
'StrideH', 'StrideW', 'DilationH', 'DilationW',
'LeftPadH', 'LeftPadW', 'RightPadH', 'RightPadW', 'TestName']
with open(args.output, 'w', newline='') as csvfile:
csvfile.write(f"# Mixed 2D/3D Convolution Test Cases from MIOpen Commands\n")
writer = csv.DictWriter(csvfile, fieldnames=mixed_headers, extrasaction='ignore')
mixed_headers = [
"NDim",
"Groups",
"BatchSize",
"OutChannels",
"InChannels",
"KernelH",
"KernelW",
"InputH",
"InputW",
"OutputH",
"OutputW",
"StrideH",
"StrideW",
"DilationH",
"DilationW",
"LeftPadH",
"LeftPadW",
"RightPadH",
"RightPadW",
"TestName",
]
with open(args.output, "w", newline="") as csvfile:
csvfile.write(
"# Mixed 2D/3D Convolution Test Cases from MIOpen Commands\n"
)
writer = csv.DictWriter(
csvfile, fieldnames=mixed_headers, extrasaction="ignore"
)
writer.writeheader()
for case in all_cases:
writer.writerow(case)
@@ -355,12 +506,13 @@ def main():
# Write separate files for 2D and 3D
if test_cases_2d:
write_csv_cases(test_cases_2d, args.output_2d, 2)
if test_cases_3d:
write_csv_cases(test_cases_3d, args.output_3d, 3)
print("Conversion completed!")
return 0
if __name__ == "__main__":
exit(main())

View File

@@ -7,13 +7,12 @@ PyTorch Model Runner with MIOpen Command Logging using torchvision models
Usage:
MIOPEN_ENABLE_LOGGING_CMD=1 python3 run_model_with_miopen.py --model resnet18 2> miopen_commands.txt
Available 2D models: alexnet, vgg11, vgg16, resnet18, resnet50, mobilenet_v2, etc.
Available 3D models: r3d_18, mc3_18, r2plus1d_18
"""
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.models.video as video_models
import argparse
@@ -21,94 +20,145 @@ import os
# Define available models
MODELS_2D = [
'alexnet', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn',
'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
'resnext50_32x4d', 'resnext101_32x8d', 'resnext101_64x4d',
'wide_resnet50_2', 'wide_resnet101_2',
'densenet121', 'densenet161', 'densenet169', 'densenet201',
'inception_v3', 'googlenet',
'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0',
'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small',
'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3',
'squeezenet1_0', 'squeezenet1_1'
"alexnet",
"vgg11",
"vgg11_bn",
"vgg13",
"vgg13_bn",
"vgg16",
"vgg16_bn",
"vgg19",
"vgg19_bn",
"resnet18",
"resnet34",
"resnet50",
"resnet101",
"resnet152",
"resnext50_32x4d",
"resnext101_32x8d",
"resnext101_64x4d",
"wide_resnet50_2",
"wide_resnet101_2",
"densenet121",
"densenet161",
"densenet169",
"densenet201",
"inception_v3",
"googlenet",
"shufflenet_v2_x0_5",
"shufflenet_v2_x1_0",
"shufflenet_v2_x1_5",
"shufflenet_v2_x2_0",
"mobilenet_v2",
"mobilenet_v3_large",
"mobilenet_v3_small",
"mnasnet0_5",
"mnasnet0_75",
"mnasnet1_0",
"mnasnet1_3",
"squeezenet1_0",
"squeezenet1_1",
]
MODELS_3D = [
'r3d_18', 'mc3_18', 'r2plus1d_18'
]
MODELS_3D = ["r3d_18", "mc3_18", "r2plus1d_18"]
ALL_MODELS = MODELS_2D + MODELS_3D
def main():
parser = argparse.ArgumentParser(description='PyTorch Model Runner with MIOpen Command Logging')
parser = argparse.ArgumentParser(
description="PyTorch Model Runner with MIOpen Command Logging"
)
# Model selection
parser.add_argument('--model', choices=ALL_MODELS, default='resnet18',
help='Model to run')
parser.add_argument(
"--model", choices=ALL_MODELS, default="resnet18", help="Model to run"
)
# Input tensor dimensions
parser.add_argument('--batch-size', type=int, default=4,
help='Batch size')
parser.add_argument('--channels', type=int, default=3,
help='Input channels (e.g., 3 for RGB, 1 for grayscale)')
parser.add_argument('--height', type=int, default=224,
help='Input height')
parser.add_argument('--width', type=int, default=224,
help='Input width')
parser.add_argument('--input-size', type=int,
help='Input size (sets both height and width to same value)')
parser.add_argument('--temporal-size', type=int, default=16,
help='Temporal dimension for 3D models')
parser.add_argument("--batch-size", type=int, default=4, help="Batch size")
parser.add_argument(
"--channels",
type=int,
default=3,
help="Input channels (e.g., 3 for RGB, 1 for grayscale)",
)
parser.add_argument("--height", type=int, default=224, help="Input height")
parser.add_argument("--width", type=int, default=224, help="Input width")
parser.add_argument(
"--input-size",
type=int,
help="Input size (sets both height and width to same value)",
)
parser.add_argument(
"--temporal-size", type=int, default=16, help="Temporal dimension for 3D models"
)
# Device and precision
parser.add_argument('--device', choices=['cuda', 'cpu', 'auto'], default='auto',
help='Device to run on')
parser.add_argument('--precision', choices=['fp32', 'fp16', 'bf16'], default='fp32',
help='Floating point precision')
parser.add_argument(
"--device",
choices=["cuda", "cpu", "auto"],
default="auto",
help="Device to run on",
)
parser.add_argument(
"--precision",
choices=["fp32", "fp16", "bf16"],
default="fp32",
help="Floating point precision",
)
# Output control
parser.add_argument('--quiet', action='store_true',
help='Suppress output except errors')
parser.add_argument('--verbose', action='store_true',
help='Verbose output')
parser.add_argument(
"--quiet", action="store_true", help="Suppress output except errors"
)
parser.add_argument("--verbose", action="store_true", help="Verbose output")
args = parser.parse_args()
# Handle input-size override
if args.input_size:
args.height = args.input_size
args.width = args.input_size
# Check MIOpen logging
if not os.environ.get('MIOPEN_ENABLE_LOGGING_CMD') and not args.quiet:
if not os.environ.get("MIOPEN_ENABLE_LOGGING_CMD") and not args.quiet:
print("WARNING: Set MIOPEN_ENABLE_LOGGING_CMD=1 to capture commands")
# Device selection
if args.device == 'auto':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if args.device == "auto":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device(args.device)
# Check if actually running on GPU
if device.type == 'cpu':
if device.type == "cpu":
import sys
print(f"WARNING: Running on CPU, MIOpen commands will not be generated!", file=sys.stderr)
print(
"WARNING: Running on CPU, MIOpen commands will not be generated!",
file=sys.stderr,
)
print(f"CUDA/ROCm available: {torch.cuda.is_available()}", file=sys.stderr)
if torch.cuda.is_available():
print(f"GPU device count: {torch.cuda.device_count()}", file=sys.stderr)
print(f"GPU name: {torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else 'N/A'}", file=sys.stderr)
print(
f"GPU name: {torch.cuda.get_device_name(0) if torch.cuda.device_count() > 0 else 'N/A'}",
file=sys.stderr,
)
# Continue anyway for testing purposes
if not args.quiet:
print(f"Using device: {device}")
# Create model using torchvision
if args.model in MODELS_3D:
# 3D Video models
model = getattr(video_models, args.model)(weights=None)
# 3D input: (batch, channels, temporal, height, width)
input_tensor = torch.randn(args.batch_size, args.channels, args.temporal_size, args.height, args.width)
input_tensor = torch.randn(
args.batch_size, args.channels, args.temporal_size, args.height, args.width
)
if not args.quiet:
print(f"3D model: {args.model}")
print(f"Input shape: {input_tensor.shape} (B, C, T, H, W)")
@@ -116,34 +166,37 @@ def main():
# 2D Image models
model = getattr(models, args.model)(weights=None)
# 2D input: (batch, channels, height, width)
input_tensor = torch.randn(args.batch_size, args.channels, args.height, args.width)
input_tensor = torch.randn(
args.batch_size, args.channels, args.height, args.width
)
if not args.quiet:
print(f"2D model: {args.model}")
print(f"Input shape: {input_tensor.shape} (B, C, H, W)")
# Set precision
if args.precision == 'fp16':
if args.precision == "fp16":
model = model.half()
input_tensor = input_tensor.half()
elif args.precision == 'bf16':
elif args.precision == "bf16":
model = model.bfloat16()
input_tensor = input_tensor.bfloat16()
model = model.to(device)
input_tensor = input_tensor.to(device)
if not args.quiet:
print(f"Running {args.model} model...")
# Run inference
model.eval()
with torch.no_grad():
output = model(input_tensor)
if not args.quiet:
print(f"Output shape: {output.shape}")
if not args.quiet:
print("Done! MIOpen commands logged to stderr")
if __name__ == "__main__":
main()

View File

@@ -170,11 +170,11 @@ warp_tile_supported_combinations = {
[16, 16, 128],
[32, 32, 64],
],
"fp8_bf8_fp16": [
"fp8_bf8_fp16": [
[16, 16, 128],
[32, 32, 64],
],
"bf8_fp8_fp16": [
"bf8_fp8_fp16": [
[16, 16, 128],
[32, 32, 64],
],

View File

@@ -107,32 +107,32 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
"fp16_fp16_fp16": [
[16, 16, 16],
],
},
},
}
# Supported warp tile combinations for different GPU architectures and data types
WARP_SUPPORTED_COMBINATIONS = {
"gfx90a": [
[1, 4, 1],
[2, 2, 1],
[1, 4, 1],
[2, 2, 1],
[4, 1, 1],
],
"gfx942": [
[1, 4, 1],
[2, 2, 1],
[1, 4, 1],
[2, 2, 1],
[4, 1, 1],
],
"gfx950": [
[1, 4, 1],
[2, 2, 1],
[1, 4, 1],
[2, 2, 1],
[4, 1, 1],
],
"gfx1201": [
[2, 4, 1],
[1, 8, 1],
[8, 1, 1],
[2, 4, 1],
[1, 8, 1],
[8, 1, 1],
[4, 2, 1],
],
],
}
# Unsupported trait combinations
@@ -186,14 +186,14 @@ def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) ->
def validate_warp_configuration(
warp_m: int,
warp_n: int,
warp_m: int,
warp_n: int,
warp_k: int,
gpu_name: str = None,
) -> bool:
"""Validate warp configuration."""
if gpu_name is None:
gpu_name = get_gpu_name_by_id(0)
gpu_name = get_gpu_name_by_id(0)
current_combination = [warp_m, warp_n, warp_k]
@@ -205,11 +205,8 @@ def validate_warp_configuration(
# Check if current combination is in the allowed list
if current_combination not in allowed_combinations:
error_msg = (
f"Invalid warp tile combination: {current_combination} not in allowed list. "
)
return False
return True