diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index a77d7e6be3..0b526f4e9f 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -17,12 +17,12 @@ The executables reside in `bin` subdirectory of the build directory. This example provides recipes for `tile_example_fmha_fwd`, `tile_example_fmha_bwd`, `tile_example_fmha_fwd_v3`. -> [!NOTE] -> `cmake-ck-dev.sh` is a CMake wrapper. +> [!NOTE] +> `cmake-ck-dev.sh` is a CMake wrapper. > > The first argument is the path to composable_kernel sources. > -> The second argument is the gfx architectures string (e.g. "gfx950" or "gfx90a;gfx942"). +> The second argument is the gfx architectures string (e.g. "gfx950" or "gfx90a;gfx942"). > > The remaining arguments are optional and are passed through to CMake. > E.g. `-G Ninja` specifies ninja as the build system. @@ -61,15 +61,8 @@ args: -d head dim for q, k (default:128) -d_v head dim for v, -1 means equal to d (default:-1) -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) - note when squant=1, this value will be modified by range_q/k - -range_q per-tensor quantization range of q. used if squant=1. (default:16) - -range_k per-tensor quantization range of k. used if squant=1. (default:16) - -range_v per-tensor quantization range of v. used if squant=1. (default:16) - -range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1) - -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16) - -squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto) - 0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O. - calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o + -qscale n or 0, no scaling (default:n) + 1: per-tensor quantization. -iperm permute input (default:1) if true, will be b*h*s*d, else b*s*h*d -operm permute output (default:1) @@ -104,7 +97,7 @@ args: Comma-separated list of length 'b'. If empty, no override ``` Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. 
-Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with +Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case ## Padding Examples diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 4098eb67c2..312ea26e11 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -63,6 +63,16 @@ def get_mask_check_map(mask: str): return None +QSCALE_MAP = { + "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", + "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", +} + +QSCALE_CHECK_MAP = { + "no": "quant_scale_enum::no_scale", + "pertensor": "quant_scale_enum::pertensor", +} + BIAS_MAP = { "no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS", "bias": "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 2acc467410..db817c94ae 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -24,6 +24,8 @@ from codegen.cpp_symbol_map import ( FWD_DTYPE_MAP, BIAS_MAP, get_mask_map, + QSCALE_CHECK_MAP, + QSCALE_MAP, ) from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file @@ -64,7 +66,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, false, {F_lse}, {F_dropout}, - {F_squant}, + {F_qscale}, {F_occupancy}, {F_skip}>; @@ -103,7 +105,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; template<> float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) @@ -190,9 +192,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hd }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + 
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; return fmha_fwd_(s, a); }} """ @@ -232,7 +234,7 @@ class FmhaFwdApiTrait: bias: str # lse: str # dropout: str - squant: str # + qscale: str # spad: str skpad: str dpad: str @@ -245,7 +247,7 @@ class FmhaFwdApiTrait: def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" ) @property @@ -341,7 +343,7 @@ class FmhaFwdPipeline: F_bias: str # true/false F_lse: str # F_dropout: str # - F_squant: str # + F_qscale: str # no/pertensor F_mask: str # value from MASK_MAP F_skip: str # true/false F_trload: str # true/false @@ -406,10 +408,10 @@ class FmhaFwdPipeline: else: n += "_nskip" - if self.F_squant == "t": - n += "_squant" + if self.F_qscale != "no": + n += f"_{self.F_qscale}" else: - n += "_nsquant" + n += "_nqscale" if self.F_trload == "t": n += "_trload" @@ -462,7 +464,8 @@ class FmhaFwdApiPool: F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_squant=BOOL_MAP[trait.squant], + F_qscale_check=QSCALE_CHECK_MAP[trait.qscale], + F_qscale=QSCALE_MAP[trait.qscale], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, @@ -580,7 +583,7 @@ class FmhaFwdKernel: F_bias=BIAS_MAP[self.F_pipeline.F_bias], F_lse=BOOL_MAP[self.F_pipeline.F_lse], F_dropout=BOOL_MAP[self.F_pipeline.F_dropout], - F_squant=BOOL_MAP[self.F_pipeline.F_squant], + F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale], F_skip=BOOL_MAP[self.F_pipeline.F_skip], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], @@ -623,7 +626,7 @@ class FmhaFwdKernel: bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, dropout=self.F_pipeline.F_dropout, - squant=self.F_pipeline.F_squant, + qscale=self.F_pipeline.F_qscale, spad=self.F_pipeline.F_spad, skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, @@ -695,7 +698,7 @@ class KernelComponentFactoryGfx9: # TODO: how to design this more generic? 
pipelines = [] if dtype in ["fp32"]: - squant = "f" + qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), @@ -704,11 +707,11 @@ class KernelComponentFactoryGfx9: ["t", "f"], ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip elif dtype in ["fp16", "bf16"]: - squant = "f" + qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), @@ -718,28 +721,31 @@ class KernelComponentFactoryGfx9: ["t", "f"], ): if hdim == 256 and hdim_v == 256: - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip else: - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: 
skip # TODO: cover arbitraty hdim# fmt: skip - elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip + elif dtype in ["fp8bf16", "fp8fp32"]: # no need lse/dropout kernels - for logits, squant, mask, bias in itertools.product( - ["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() + for logits, qscale, mask, bias in itertools.product( + ["f"], + ["no", "pertensor"], + get_mask_map(mask_impl).keys(), + ["no"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip - elif dtype in ["fp8fp16", "bf8"]: + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip + elif dtype in ["fp8", "fp8fp16", "bf8"]: # TODO None else: @@ -756,7 +762,7 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): dtype, hdim, hdim_v, receipt, mask_impl ) if dtype in ["fp16", "bf16"]: - squant = "f" + qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), @@ -772,8 +778,8 @@ class KernelComponentFactoryGfx950(KernelComponentFactoryGfx9): and dropout == "f" and skip == "f" ): - pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t")) # fmt: skip return pipelines @@ -810,7 +816,7 @@ class KernelComponentFactoryGfx12: def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: pipelines = [] if dtype in ["fp16", "bf16"]: - squant = "f" + qscale = "no" for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), @@ -819,15 +825,15 @@ class KernelComponentFactoryGfx12: ["t", "f"], ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f")) # fmt: skip elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: # no need lse/dropout kernels - for logits, squant, mask, bias in itertools.product( - ["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() + for logits, qscale, mask, bias in itertools.product( + ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", 
"f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip else: assert False return pipelines @@ -932,7 +938,7 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" cond &= pipeline.F_bias in ["no", "alibi"] - cond &= pipeline.F_squant == "f" + cond &= pipeline.F_qscale == "no" cond &= pipeline.F_skip == "f" if not cond: continue @@ -941,7 +947,7 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" cond &= pipeline.F_bias in ["no", "bias"] - cond &= pipeline.F_squant == "f" + cond &= pipeline.F_qscale == "no" cond &= mode == "batch" cond &= pipeline.F_skip == "f" cond &= pipeline.F_logits == "f" @@ -953,7 +959,7 @@ def get_fwd_blobs( cond &= mode == "batch" cond &= pipeline.F_vlayout == "row" if dtype == "fp8bf16": - cond &= hdim == 128 + cond &= hdim == 128 or hdim == 256 if not cond: continue # Aiter(mha_varlen_fwd) integration @@ -962,7 +968,7 @@ def get_fwd_blobs( cond &= mode == "group" cond &= pipeline.F_vlayout == "row" if dtype == "fp8bf16": - cond &= hdim == 128 + cond &= hdim == 128 or hdim == 256 if not cond: continue # aiter::mha_fwd C++ api integration @@ -970,13 +976,13 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16", "fp8bf16"] cond &= pipeline.F_vlayout == "row" if dtype == "fp8bf16": - cond &= hdim == 128 + cond &= hdim == 128 or hdim == 256 if not cond: continue elif receipt == 888: - cond = dtype in ["fp8", "fp8bf16", "fp8fp32"] + cond = dtype in ["fp8bf16", "fp8fp32"] cond &= pipeline.F_vlayout == "row" - cond &= hdim == 128 + cond &= hdim == 128 or hdim == 256 if not cond: continue diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index c27a5ce1ae..3d8aa29131 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -45,18 +45,12 @@ auto create_args(int argc, char* argv[]) "must be greater than or equal to s_k") .insert("d", "128", "head dim for q, k") .insert("d_v", "-1", "head dim for v, -1 means equal to d") - .insert("scale_s", - "0", - "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" - "note when squant=1, this value will be modified") + .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)") + .insert("qscale", + "n", + "n or 0, no scale\n" + "pt or 1, per-tensor scale\n") .insert("logits_soft_cap", "0", "attention logits soft capping value.") - .insert("squant", - "auto", - "if using static quantization fusion or not. 
auto: fp8 will default use squant, " - "other will not\n" - "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to " - "P and O.\n" - "calculate scale_s, scale_p, scale_o auto") .insert("iperm", "1", "permute input\n" @@ -87,7 +81,8 @@ auto create_args(int argc, char* argv[]) "uf", "init method:\n ui or 0 - uniform random int\n ni - normalized random int" "\n uf or 1 - uniform random float\n nf - normalized random float" - "\n tf or 2 - trig float\n") + "\n tf or 2 - trig float" + "\n tf or 3 - uniform random float, min max is the max of the type\n") .insert("seed", "11939", "random seed used for initializing input tensors. 0 for " @@ -152,6 +147,7 @@ auto run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size"); bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx"); std::string bias_str = arg_parser.get_str("bias"); + std::string qscale_str = arg_parser.get_str("qscale"); float p_drop = arg_parser.get_float("p_drop"); uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); @@ -162,13 +158,6 @@ auto run(const ck_tile::ArgParser& arg_parser) std::string init_method = arg_parser.get_str("init"); uint32_t seed = arg_parser.get_uint32("seed"); - bool squant = [&]() { - if(arg_parser.get_str("squant") == "auto") - return std::is_same_v; - else - return arg_parser.get_bool("squant"); - }(); - ck_tile::stream_config stream_config{nullptr, true, /* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0), @@ -208,7 +197,7 @@ auto run(const ck_tile::ArgParser& arg_parser) drop_offset, drop_prefs, mask_str, - squant, + qscale_str, is_rotary_interleaved, num_splits, init_method, @@ -239,10 +228,6 @@ int main(int argc, char* argv[]) { return run(arg_parser) == fwd_result::success ? 0 : -2; } - else if(data_type == "fp8") - { - return run(arg_parser) == fwd_result::success ? 0 : -2; - } else if(data_type == "fp8bf16") { return run(arg_parser) == fwd_result::success ? 
0 : -2; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index a952800806..b628fa1d87 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -11,6 +11,7 @@ #include "bias.hpp" #include "mask.hpp" +#include "quant.hpp" #include "rotary.hpp" #include @@ -178,6 +179,9 @@ struct fmha_fwd_args const void* k_ptr; const void* v_ptr; const void* bias_ptr; // bias or alibi_slope pointer + const void* q_descale_ptr; + const void* k_descale_ptr; + const void* v_descale_ptr; void* rand_val_ptr; void* lse_ptr; void* o_ptr; @@ -237,9 +241,6 @@ struct fmha_fwd_args ck_tile::index_t nhead_k; float scale_s; - float scale_p; - float scale_o; - float logits_soft_cap; ck_tile::index_t stride_q; @@ -581,6 +582,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.k_ptr, args.v_ptr, args.bias_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, args.rand_val_ptr, args.lse_ptr, args.o_ptr, @@ -593,8 +597,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, - args.scale_p, - args.scale_o, args.logits_soft_cap, args.stride_q, args.stride_k, @@ -625,6 +627,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.k_ptr, args.v_ptr, args.bias_ptr, + args.q_descale_ptr, + args.k_descale_ptr, + args.v_descale_ptr, args.rand_val_ptr, args.lse_ptr, args.o_ptr, @@ -635,8 +640,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, - args.scale_p, - args.scale_o, args.logits_soft_cap, args.stride_q, args.stride_k, @@ -1125,7 +1128,7 @@ template 1.0f) { std::cerr << "The value of p_drop should be 0~1" << std::endl; @@ -572,6 +574,11 @@ fwd_result fmha_fwd_run(mode_enum mode, hdim_v} : std::array{1, 1, 1, 1, 1}); + // TODO - change the tensor length for different quant scale + ck_tile::HostTensor q_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); + ck_tile::HostTensor k_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); + ck_tile::HostTensor v_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); + // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] ck_tile::HostTensor lse_host( @@ -592,7 +599,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::HostTensor cache_batch_idx_host(use_cache_batch_idx ? 
std::array{batch} : std::array{1}); - float max_o = 5.0; if(init_method == "ui" || init_method == "0") { ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); @@ -640,6 +646,23 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::FillTrigValue{}(vnew_host); ck_tile::FillTrigValue{}(bias_host); } + else if(init_method == "3") + { + float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float bias_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + + ck_tile::FillUniformDistribution{-q_dtype_max, q_dtype_max, next_seed()}(q_host); + ck_tile::FillUniformDistribution{-k_dtype_max, k_dtype_max, next_seed()}(k_host); + ck_tile::FillUniformDistribution{-k_dtype_max, k_dtype_max, next_seed()}( + knew_host); + ck_tile::FillUniformDistribution{-v_dtype_max, v_dtype_max, next_seed()}(v_host); + ck_tile::FillUniformDistribution{-v_dtype_max, v_dtype_max, next_seed()}( + vnew_host); + ck_tile::FillUniformDistribution{ + -bias_dtype_max, bias_dtype_max, next_seed()}(bias_host); + } if(bias.type == bias_enum::alibi) { auto slopes = ck_tile::get_alibi_slopes(nhead); @@ -658,6 +681,18 @@ fwd_result fmha_fwd_run(mode_enum mode, } } } + if(qscale.type == quant_scale_enum::pertensor) + { + float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + + float qkv_max = 3.f; + q_descale_host(0) = qkv_max / q_dtype_max; + k_descale_host(0) = qkv_max / k_dtype_max; + v_descale_host(0) = qkv_max / v_dtype_max; + } + iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); @@ -667,6 +702,9 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem q_descale_buf(q_descale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_descale_buf(k_descale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_descale_buf(v_descale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); @@ -702,81 +740,15 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes()); - float scale_p = 1.f; - float scale_o = 1.f; - if(squant) - { - float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float p_dtype_max = v_dtype_max; // assume p and v is the same type - // Q tensor - { - float max_value = ck_tile::type_convert(ck_tile::numeric::min()); - q_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - if(val > max_value) - max_value = val; - }); - - float scale = q_dtype_max / max_value; - - 
q_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - self(idx) = ck_tile::type_convert(val * scale); - }); - scale_s = scale_s / scale; - } - - // K tensor - { - float max_value = ck_tile::type_convert(ck_tile::numeric::min()); - k_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - if(val > max_value) - max_value = val; - }); - float scale = k_dtype_max / max_value; - k_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - self(idx) = ck_tile::type_convert(val * scale); - }); - scale_s = scale_s / scale; - } - - // V tensor - { - float max_value = ck_tile::type_convert(ck_tile::numeric::min()); - v_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - if(val > max_value) - max_value = val; - }); - - float scale = k_dtype_max / max_value; - v_host.ForEach([&](auto& self, auto idx) { - float val = ck_tile::type_convert(self(idx)); - self(idx) = ck_tile::type_convert(val * scale); - }); - - scale_o = (1.0 / p_dtype_max) / scale; - } - - scale_p = p_dtype_max; - - if constexpr(std::is_same_v) - { - float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - scale_o = scale_o * o_dtype_max / max_o; - } - } - q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); v_buf.ToDevice(v_host.data()); knew_buf.ToDevice(knew_host.data()); vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); + q_descale_buf.ToDevice(q_descale_host.data()); + k_descale_buf.ToDevice(k_descale_host.data()); + v_descale_buf.ToDevice(v_descale_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); // Keep logical starts in seqstart_k; pass padded K via separate pointer seqstart_k.ToDevice(seqstart_k_host.data()); @@ -816,7 +788,7 @@ fwd_result fmha_fwd_run(mode_enum mode, << (seqlen_kpads[0] < 0 ? "" : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias - << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant + << ", p_drop:" << p_drop << ", lse:" << lse << ", qscale:" << qscale << ", mask:" << mask << ", v:" << (is_v_rowmajor ? 
"r" : "c"); #if CK_TILE_FMHA_FWD_APPENDKV_API if(0 < rotary_dim) @@ -908,11 +880,11 @@ fwd_result fmha_fwd_run(mode_enum mode, traits.mask_type = mask.type; traits.bias_type = bias.type; traits.has_lse = lse; - traits.do_fp8_static_quant = squant; if constexpr(std::is_same_v>) { traits.has_dropout = (p_drop > 0.0f); + traits.qscale_type = qscale.type; } else if constexpr(std::is_same_v>) @@ -1055,8 +1027,6 @@ fwd_result fmha_fwd_run(mode_enum mode, args.max_seqlen_q = max_seqlen_q; args.scale_s = scale_s; - args.scale_p = scale_p; - args.scale_o = scale_o; args.logits_soft_cap = logits_soft_cap; @@ -1076,6 +1046,10 @@ fwd_result fmha_fwd_run(mode_enum mode, if constexpr(std::is_same_v>) { + args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); + args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); + args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); + args.rand_val_ptr = randval_buf.GetDeviceBuffer(); args.stride_randval = stride_randval; @@ -1351,23 +1325,34 @@ fwd_result fmha_fwd_run(mode_enum mode, lse_buf.FromDevice(lse_host.data()); randval_buf.FromDevice(randval_host.data()); - constexpr bool supports_squant = std::is_same_v || + constexpr bool supports_qscale = std::is_same_v || std::is_same_v || std::is_same_v; + float scale_s_host = scale_s; + float scale_p_host = 1.0f; + float scale_o_host = 1.0f; + + if(qscale.type == quant_scale_enum::pertensor) + { + scale_s_host = scale_s * q_descale_host(0) * k_descale_host(0); + scale_p_host = ck_tile::type_convert(ck_tile::numeric::max()); + scale_o_host = v_descale_host(0) / scale_p_host; + } + auto p_compute_element_func = [&]() { - if constexpr(supports_squant) - return ck_tile::scales{scale_p}; + if constexpr(supports_qscale) + return ck_tile::scales{scale_p_host}; else return ck_tile::identity{}; }(); auto oacc_element_func = [&]() { - if constexpr(std::is_same_v && supports_squant) + if constexpr(std::is_same_v && supports_qscale) return ck_tile::composes(ck_tile::saturates{}, - ck_tile::scales{scale_o}); - else if constexpr(supports_squant) - return ck_tile::scales{scale_o}; + ck_tile::scales{scale_o_host}); + else if constexpr(supports_qscale) + return ck_tile::scales{scale_o_host}; else return ck_tile::identity{}; }(); @@ -1573,7 +1558,7 @@ fwd_result fmha_fwd_run(mode_enum mode, s_host_ref, ck_tile::identity{}, ck_tile::identity{}, - ck_tile::scales(scale_s)); + ck_tile::scales(scale_s_host)); if(0.f < logits_soft_cap) { @@ -1818,7 +1803,8 @@ fwd_result fmha_fwd_run(mode_enum mode, scale_s, p_drop, lse, - squant, + qscale.type == quant_scale_enum::no_scale ? "no_scale" + : "pertensor", bias.type == bias_enum::elementwise_bias ? "elementwise_bias" : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp new file mode 100644 index 0000000000..35461cc53d --- /dev/null +++ b/example/ck_tile/01_fmha/quant.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" + +// keep sync with BlockAttentionQuantScaleEnum +enum class quant_scale_enum +{ + no_scale = 0, + pertensor = 1, +}; + +struct quant_scale_info +{ + quant_scale_enum type; + + void serialize(std::ostream& os) const + { + if(type == quant_scale_enum::no_scale) + os << "n"; + else if(type == quant_scale_enum::pertensor) + os << "pt"; + } + + static quant_scale_info decode(std::string str) + { + quant_scale_info info{quant_scale_enum::no_scale}; + if(str == "n" || str == "0") + { + info.type = quant_scale_enum::no_scale; + } + else if(str == "pt" || str == "1") + { + info.type = quant_scale_enum::pertensor; + } + else + { + throw std::invalid_argument("invalid quant scale value: " + str); + } + return info; + } + + friend std::ostream& operator<<(std::ostream& os, const quant_scale_info& qsi) + { + qsi.serialize(os); + return os; + } +}; diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 02bc5476fa..2330803bea 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -73,52 +73,39 @@ run_fp16_bf16_tests() { for page_block_size in $PAGE_BLOCK_SIZE ; do for cache_batch_idx in $CACHE_BATCH_IDX ; do - # run_exe -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - run_exe -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 
-num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + # run_exe -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS done ; done ; done ; done ; done done ; done ; done ; done ; done } -run_fp8_tests() { - for perm in 0 1 ; do - for bias in "n" "e" "a" ; do - for b in 1 2 ; do - for hdim in 64 128 256 ; do - - $EXE -prec=fp8 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS - - done ; done ; done ; done -} - run_fp8bf16_tests() { for perm in 0 1 ; do - for bias in "n" "e" "a" ; do for b in 1 2 ; do for hdim in 64 128 256 ; do - $EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias 
-iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8bf16 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS - done ; done ; done ; done + done ; done ; done } run_fp8fp32_tests() { for perm in 0 1 ; do - for bias in "n" "e" "a" ; do for b in 1 2 ; do for hdim in 128 ; do - $EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8fp32 -init=3 -b=$b -h=1 -d=$hdim -s=128 -iperm=$perm -operm=$perm -vlayout=r -qscale=1 -kname=$KNAME $COMMON_ARGS - done ; done ; done ; done + done ; done ; done } run_fp16_appendkv_tests() { @@ -133,7 +120,7 @@ run_fp16_appendkv_tests() { run_exe -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS - done ; done ; done ; done ; done + done ; done ; done ; done ; done done ; done ; done } @@ -249,7 +236,6 @@ set -x run_fp16_bf16_tests run_padding_smoke_tests run_padding_basic_boundary_tests -run_fp8_tests run_fp8bf16_tests run_fp8fp32_tests diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index 9b87518161..59510c8b93 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -598,6 +598,8 @@ struct HostTensor typename Data::size_type size() const { return mData.size(); } + T max() const { return *std::max_element(mData.begin(), mData.end()); } + // return a slice of this tensor // for simplicity we just copy the data and return a new tensor auto slice(std::vector s_begin, std::vector s_end) const diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 6b25c089bd..5b87a821c9 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" #include "ck_tile/ops/fmha/block/block_position_encoding.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp new file mode 100644 index 0000000000..4d80443f35 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockAttentionQuantScaleEnum +{ + NO_SCALE = 0, + PERTENSOR = 1, +}; + +template +struct BlockAttentionQuantScaleEnumToStr; + +template <> +struct BlockAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = ""; +}; +template <> +struct BlockAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = "pertensor"; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index fba3065842..cd1bfb031b 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" #include @@ -36,6 +37,7 @@ struct FmhaFwdKernel using QDataType = ck_tile::remove_cvref_t; using KDataType = ck_tile::remove_cvref_t; using VDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; using BiasDataType = ck_tile::remove_cvref_t; using RandValOutputDataType = ck_tile::remove_cvref_t; @@ -54,7 +56,7 @@ struct FmhaFwdKernel static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; - static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; using AttentionVariant = ck_tile::remove_cvref_t; @@ -112,7 +114,8 @@ struct FmhaFwdKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload"); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + + (QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? _SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr::name)) + (kUseTrLoad ? 
"_trload" : "_ntrload"); #undef _SS_ #undef _TS_ // clang-format on @@ -204,10 +207,11 @@ struct FmhaFwdKernel ck_tile::GenericAttentionMaskEnum mask_type; }; - struct FmhaFwdFp8StaticQuantKargs + struct FmhaFwdCommonQScaleKargs { - float scale_p; - float scale_o; + const void* q_descale_ptr = nullptr; + const void* k_descale_ptr = nullptr; + const void* v_descale_ptr = nullptr; }; struct FmhaFwdCommonLSEKargs @@ -285,7 +289,9 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + std::conditional_t>, std::conditional_t>, std::conditional_t> { @@ -309,7 +315,9 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + std::conditional_t>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -339,6 +347,9 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -349,8 +360,6 @@ struct FmhaFwdKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, - float scale_p, - float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, @@ -408,7 +417,7 @@ struct FmhaFwdKernel {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for lse - {}, // placeholder for fp8_static_quant args + {}, // placeholder for qscale {}, // placeholder for dropout {}, // placeholder for logits_soft_cap batch_stride_q, @@ -440,10 +449,11 @@ struct FmhaFwdKernel kargs.nhead_stride_lse = nhead_stride_lse; kargs.batch_stride_lse = batch_stride_lse; } - if constexpr(kDoFp8StaticQuant) + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { - kargs.scale_p = scale_p; - kargs.scale_o = scale_o; + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; } if constexpr(kHasDropout) { @@ -483,6 +493,9 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -493,8 +506,6 @@ struct FmhaFwdKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, - float scale_p, - float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, @@ -530,6 +541,9 @@ struct FmhaFwdKernel k_ptr, v_ptr, bias_ptr, + q_descale_ptr, + k_descale_ptr, + v_descale_ptr, rand_val_ptr, lse_ptr, o_ptr, @@ -540,8 +554,6 @@ struct FmhaFwdKernel num_head_q, nhead_ratio_qk, scale_s, - scale_p, - scale_o, logits_soft_cap, stride_q, stride_k, @@ -580,6 +592,9 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -590,8 +605,6 @@ struct FmhaFwdKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, - float scale_p, - float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, @@ -627,6 +640,9 @@ struct FmhaFwdKernel k_ptr, v_ptr, bias_ptr, + q_descale_ptr, + k_descale_ptr, + v_descale_ptr, rand_val_ptr, lse_ptr, o_ptr, @@ -637,8 +653,6 @@ struct FmhaFwdKernel num_head_q, nhead_ratio_qk, scale_s, - scale_p, - scale_o, logits_soft_cap, stride_q, stride_k, @@ -676,6 +690,9 @@ struct FmhaFwdKernel 
const void* k_ptr, const void* v_ptr, const void* bias_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -688,8 +705,6 @@ struct FmhaFwdKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, - float scale_p, - float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, @@ -741,7 +756,7 @@ struct FmhaFwdKernel {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for lse - {}, // placeholder for fp8_static_quant args + {}, // placeholder for qscale {}, // placeholder for dropout {}, // placeholder for logits_soft_cap {}, // placeholder for min_seqlen_q @@ -772,10 +787,11 @@ struct FmhaFwdKernel kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; } - if constexpr(kDoFp8StaticQuant) + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { - kargs.scale_p = scale_p; - kargs.scale_o = scale_o; + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; } if constexpr(kHasDropout) { @@ -818,6 +834,9 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -830,8 +849,6 @@ struct FmhaFwdKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, - float scale_p, - float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, @@ -861,6 +878,9 @@ struct FmhaFwdKernel k_ptr, v_ptr, bias_ptr, + q_descale_ptr, + k_descale_ptr, + v_descale_ptr, rand_val_ptr, lse_ptr, o_ptr, @@ -873,8 +893,6 @@ struct FmhaFwdKernel num_head_q, nhead_ratio_qk, scale_s, - scale_p, - scale_o, logits_soft_cap, stride_q, stride_k, @@ -907,6 +925,9 @@ struct FmhaFwdKernel const void* k_ptr, const void* v_ptr, const void* bias_ptr, + const void* q_descale_ptr, + const void* k_descale_ptr, + const void* v_descale_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -919,8 +940,6 @@ struct FmhaFwdKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, - float scale_p, - float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, @@ -950,6 +969,9 @@ struct FmhaFwdKernel k_ptr, v_ptr, bias_ptr, + q_descale_ptr, + k_descale_ptr, + v_descale_ptr, rand_val_ptr, lse_ptr, o_ptr, @@ -962,8 +984,6 @@ struct FmhaFwdKernel num_head_q, nhead_ratio_qk, scale_s, - scale_p, - scale_o, logits_soft_cap, stride_q, stride_k, @@ -1527,14 +1547,24 @@ struct FmhaFwdKernel BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; auto o_acc_tile = [&]() { - if constexpr(kDoFp8StaticQuant) + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { + // TODO - move global load of descale to pipeline + float q_descale = *(reinterpret_cast(kargs.q_descale_ptr)); + float k_descale = *(reinterpret_cast(kargs.k_descale_ptr)); + float v_descale = *(reinterpret_cast(kargs.v_descale_ptr)); + + float scale_s = kargs.scale_s * q_descale * k_descale; + float scale_p = + ck_tile::type_convert(ck_tile::numeric::max()); + float scale_o = v_descale / scale_p; + auto o_acc_element_func = [&]() { if constexpr(std::is_same_v) return ck_tile::composes(ck_tile::saturates{}, - ck_tile::scales{kargs.scale_o}); + ck_tile::scales{scale_o}); else - return ck_tile::scales{kargs.scale_o}; + return 
ck_tile::scales{scale_o}; }(); return FmhaPipeline{}(q_dram_window, identity{}, // q_element_func @@ -1546,13 +1576,13 @@ struct FmhaFwdKernel identity{}, // bias_element_func randval_dram_window, lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales{kargs.scale_p}, // p_compute_element_func - o_acc_element_func, // o_acc_element_func + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{scale_p}, // p_compute_element_func + o_acc_element_func, // o_acc_element_func mask, position_encoding, - kargs.scale_s, + scale_s, variant, variant_params, block_indices, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index cc0851efb3..864d155750 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -60,7 +60,7 @@ struct BlockFmhaPipelineProblem static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kHasDropout = Traits::kHasDropout; - static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; + static constexpr auto QScaleEnum = Traits::QScaleEnum; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 59267fa3b1..1cc41deeb7 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" namespace ck_tile { @@ -18,7 +19,7 @@ template struct TileFmhaTraits @@ -32,7 +33,7 @@ struct TileFmhaTraits static constexpr bool kHasBiasGrad = kHasBiasGrad_; static constexpr bool kStoreLSE = kStoreLSE_; static constexpr bool kHasDropout = kHasDropout_; - static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr auto QScaleEnum = QScaleEnum_; static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; }; diff --git a/include/ck_tile/utility/json_dump.hpp b/include/ck_tile/utility/json_dump.hpp index ece2a4ce1a..ed6373ae66 100644 --- a/include/ck_tile/utility/json_dump.hpp +++ b/include/ck_tile/utility/json_dump.hpp @@ -625,7 +625,7 @@ void dump_fmha_fwd_json_results(const std::string& json_filename, float scale_s, float p_drop, bool lse, - bool squant, + const std::string& qscale, const std::string& bias, const std::string& vlayout, bool pass, @@ -650,7 +650,7 @@ void dump_fmha_fwd_json_results(const std::string& json_filename, ADD_KEY_VALUE("scale_s", scale_s); ADD_KEY_VALUE("p_drop", p_drop); ADD_KEY_VALUE("lse", lse); - ADD_KEY_VALUE("squant", squant); + ADD_KEY_VALUE("qscale", qscale); ADD_KEY_VALUE("bias", bias); ADD_KEY_VALUE("vlayout", vlayout); ADD_KEY_VALUE("verification", pass ? 
"pass" : "fail"); diff --git a/test/ck_tile/fmha/CMakeLists.txt b/test/ck_tile/fmha/CMakeLists.txt index 6592fe4a9a..52accaf812 100644 --- a/test/ck_tile/fmha/CMakeLists.txt +++ b/test/ck_tile/fmha/CMakeLists.txt @@ -9,10 +9,10 @@ set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") set(TEST_NAME "test_ck_tile_fmha") function(add_gtest_fwd test_group) - set(V_TYPES "fp16" "bf16" "fp8" "fp32") + set(V_TYPES "fp16" "bf16" "fp8bf16" "fp32") set(CPP_TYPE_fp16 "FmhaFwdFp16") set(CPP_TYPE_bf16 "FmhaFwdBf16") - set(CPP_TYPE_fp8 "FmhaFwdFp8") + set(CPP_TYPE_fp8bf16 "FmhaFwdFp8Bf16") set(CPP_TYPE_fp32 "FmhaFwdFp32") set(all_tests) diff --git a/test/ck_tile/fmha/test_fmha_fwd.cpp b/test/ck_tile/fmha/test_fmha_fwd.cpp index 15382e8072..b81fa88aa2 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.cpp +++ b/test/ck_tile/fmha/test_fmha_fwd.cpp @@ -7,7 +7,7 @@ #include "gtest/gtest.h" #ifndef DataTypeConfig -#define DataTypeConfig FmhaFwdFp16 // or FmhaFwdBf16 / FmhaFwdFp8 / FmhaFwdFp32 +#define DataTypeConfig FmhaFwdFp16 // or FmhaFwdBf16 / FmhaFwdFp8Bf16 / FmhaFwdFp32 #endif using ::testing::Bool; @@ -39,13 +39,14 @@ struct TestConfigs std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}}; static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group}; static constexpr auto IsVRowmajorValues = std::array{true}; - static constexpr bool squant = false; + static constexpr auto qscale_str = "n"; static constexpr bool def_lse = true; static constexpr bool def_is_v_rowmajor = true; static int adjust_seqlen(int seqlen) { return seqlen; } }; + template <> -struct TestConfigs +struct TestConfigs { static constexpr auto HDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}}; @@ -53,13 +54,14 @@ struct TestConfigs static constexpr auto AppendKVHDimValues = std::array{std::tuple{64, -1}, std::tuple{128, -1}}; static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group}; static constexpr auto IsVRowmajorValues = std::array{true}; - static constexpr bool squant = true; + static constexpr auto qscale_str = "pt"; static constexpr bool def_lse = false; static constexpr bool def_is_v_rowmajor = true; // When there are no fp8 instances with padding, pad seqlen to avoid skipping most of the tests: // return ck_tile::integer_least_multiple(seqlen, 128); static int adjust_seqlen(int seqlen) { return seqlen; } }; + template <> struct TestConfigs { @@ -76,7 +78,7 @@ struct TestConfigs static constexpr auto AppendKVHDimValues = std::array, 0>{}; static constexpr auto ModeValues = std::array{mode_enum::batch, mode_enum::group}; static constexpr auto IsVRowmajorValues = std::array{true}; - static constexpr bool squant = false; + static constexpr auto qscale_str = "n"; static constexpr bool def_lse = true; static constexpr bool def_is_v_rowmajor = true; static int adjust_seqlen(int seqlen) { return seqlen; } @@ -87,7 +89,7 @@ static auto SplitKVHDimValues = ValuesIn(TestConfigs::SplitKV static auto AppendKVHDimValues = ValuesIn(TestConfigs::AppendKVHDimValues); static auto ModeValues = ValuesIn(TestConfigs::ModeValues); static auto IsVRowmajorValues = ValuesIn(TestConfigs::IsVRowmajorValues); -constexpr bool squant = TestConfigs::squant; +constexpr static auto qscale_str = TestConfigs::qscale_str; constexpr bool def_lse = TestConfigs::def_lse; constexpr bool def_is_v_rowmajor = TestConfigs::def_is_v_rowmajor; int adjust_seqlen(int seqlen) { return TestConfigs::adjust_seqlen(seqlen); } @@ -203,7 +205,7 @@ TEST_P(AllLong, DataTypeConfig) 
1024, // drop_offset false, // drop_prefs mask_str, // mask_str - squant, + qscale_str, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -247,7 +249,7 @@ TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail) 0, // drop_offset false, // drop_prefs "0", // mask - squant, + qscale_str, true, // is_rotary_interleaved 1, // num_splits init_method, @@ -291,7 +293,7 @@ TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail) 0, false, "0", - squant, + qscale_str, true, 2, // num_splits (>1 triggers splitkv) init_method, @@ -334,7 +336,7 @@ TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail) 0, false, "0", - squant, + qscale_str, true, 1, init_method, @@ -403,7 +405,7 @@ TEST_P(HDimPadding, DataTypeConfig) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - squant, + qscale_str, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -463,7 +465,7 @@ TEST_P(ElementwiseBias, DataTypeConfig) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - squant, + qscale_str, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -522,7 +524,7 @@ TEST_P(Alibi, DataTypeConfig) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - squant, + qscale_str, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -583,7 +585,7 @@ TEST_P(Dropout, DataTypeConfig) drop_offset, // drop_offset drop_prefs, // drop_prefs mask_str, // mask_str - squant, + qscale_str, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -648,7 +650,7 @@ TEST_P(PagedKV, DataTypeConfig) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - squant, + qscale_str, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -719,7 +721,7 @@ TEST_P(SplitKV, DataTypeConfig) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - squant, + qscale_str, true, // is_rotary_interleaved num_splits, // num_splits COMMON_ARGS); @@ -796,7 +798,7 @@ TEST_P(AppendKV, DataTypeConfig) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - squant, + qscale_str, false, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -818,7 +820,7 @@ GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKVRoPE); INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd, AppendKVRoPE, - Combine(EnableTestIf(!std::is_same_v), + Combine(EnableTestIf(!std::is_same_v), AppendKVHDimValues, Bool(), // layouts of k and v are controlled by i_perm IsVRowmajorValues, // layout of v is controlled by is_v_rowmajor @@ -869,7 +871,7 @@ TEST_P(AppendKVRoPE, DataTypeConfig) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - squant, + qscale_str, is_rotary_interleaved, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -1105,7 +1107,7 @@ INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd_Padding, PaddingCases, ValuesIn(kPadd TEST_P(PaddingCases, DataTypeConfig) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { GTEST_SKIP() << "Skip for fp8"; } @@ -1162,7 +1164,7 @@ TEST_P(PaddingCases, DataTypeConfig) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - squant, + qscale_str, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS);
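
The PERTENSOR path introduced above folds the per-tensor de-scale factors into the existing softmax and output scales inside the kernel: `scale_s` absorbs the Q and K de-scales, P = exp(S - m) is re-quantized into the P data type's range via `scale_p = numeric<PDataType>::max()`, and `scale_o = v_descale / scale_p` undoes that re-quantization while applying the V de-scale before the store (with an extra saturate for fp8 output types). The host reference mirrors the same math via `scale_s_host`, `scale_p_host`, and `scale_o_host`. A minimal standalone sketch of that folding is shown below; the function name, the fp8 e4m3 maximum of 448, and the single-float de-scale inputs are illustrative assumptions, not part of the patch (the real maximum comes from `ck_tile::numeric<PDataType>::max()` and depends on the fp8 format in use).

```cpp
// Sketch (not part of the patch) of the per-tensor de-scale folding used by
// the PERTENSOR quant-scale path: scale_s absorbs the Q/K de-scales, P is
// re-quantized to the P-type range via scale_p, and scale_o undoes that while
// applying the V de-scale.
#include <cmath>
#include <cstdio>

struct QuantScales
{
    float scale_s; // applied to S = Q * K^T
    float scale_p; // applied to P = exp(S - m) before the P * V gemm
    float scale_o; // applied to O = P * V before the store
};

// p_dtype_max stands in for ck_tile::numeric<PDataType>::max(); 448.f assumes
// OCP fp8 e4m3 and is only an illustrative value here.
QuantScales fold_pertensor_descale(float scale_s,
                                   float q_descale,
                                   float k_descale,
                                   float v_descale,
                                   float p_dtype_max = 448.f)
{
    QuantScales r{};
    r.scale_s = scale_s * q_descale * k_descale;
    r.scale_p = p_dtype_max;
    r.scale_o = v_descale / p_dtype_max;
    return r;
}

int main()
{
    const int   hdim      = 128;
    const float qkv_max   = 3.f;   // matches the example's per-tensor init range
    const float dtype_max = 448.f; // assumed fp8 e4m3 max
    const float descale   = qkv_max / dtype_max;

    const QuantScales s = fold_pertensor_descale(
        1.f / std::sqrt(static_cast<float>(hdim)), descale, descale, descale);
    std::printf("scale_s=%g scale_p=%g scale_o=%g\n", s.scale_s, s.scale_p, s.scale_o);
    return 0;
}
```

In the example driver this corresponds to running with `-qscale=pt` (or `-qscale=1`) and `-init=3`, which fills Q/K/V over the full range of the data type and sets each single-element de-scale tensor to `qkv_max / dtype_max`, so the device kernel and the host reference apply identical effective scales.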