diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index 0baa503334..cc6178b08c 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -54,7 +54,7 @@ jobs: with: repository: "ROCm/TheRock" path: "TheRock" - ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit + ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit - name: Setup ccache run: | @@ -78,8 +78,9 @@ jobs: run: | git config --global --add safe.directory '*' # Remove patches here if they cannot be applied cleanly, and they have not been deleted from TheRock repo - rm -f ./TheRock/patches/amd-mainline/rocm-libraries/0008-Revert-remove-options-no-enumerate-966.patch - git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch + rm ./TheRock/patches/amd-mainline/rocm-libraries/0003-Find-rocm_smi-via-config-files.patch + rm ./TheRock/patches/amd-mainline/rocm-libraries/0007-Remove-Windows-third_party_dlls-copying-code.patch + # git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch - name: Install python deps run: | diff --git a/.github/workflows/therock-test-component.yml b/.github/workflows/therock-test-component.yml index 565d1d3e54..74f3bb0017 100644 --- a/.github/workflows/therock-test-component.yml +++ b/.github/workflows/therock-test-component.yml @@ -51,7 +51,7 @@ jobs: uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: repository: "ROCm/TheRock" - ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit + ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit - name: Run setup test environment workflow uses: './.github/actions/setup_test_environment' diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml index 
cd255a40b6..e4bd295c95 100644 --- a/.github/workflows/therock-test-packages.yml +++ b/.github/workflows/therock-test-packages.yml @@ -27,7 +27,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit + ref: e4d4316c3c20819045722f60fc63928944ebc397 # 2026-01-01 commit - name: "Configuring CI options" env: diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a9b25b062..3280ad07dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,12 +5,15 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## (Unreleased) Composable Kernel 1.3.0 ### Added +* Added preshuffleB support for abquant mode in blockscale GEMM. * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. * Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". * Added attention sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines. * Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline. * Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel. * Added FP8 KV cache support for FMHA batch prefill. +* Added support for gfx1153 target. +* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations. 
### Changed diff --git a/Dockerfile.aiter b/Dockerfile.aiter index 94591f9012..020afeccf4 100644 --- a/Dockerfile.aiter +++ b/Dockerfile.aiter @@ -2,7 +2,7 @@ ARG BASE_DOCKER="rocm/pytorch:latest" FROM $BASE_DOCKER ARG AITER_BRANCH="main" ARG CK_AITER_BRANCH="develop" -RUN pip install pandas zmq einops ninja && \ +RUN pip install pandas zmq einops ninja tabulate && \ pip install numpy==1.26.2 && \ sudo mkdir /home/jenkins && \ sudo mkdir /home/jenkins/workspace && \ diff --git a/Jenkinsfile b/Jenkinsfile index cb2f8631c5..7292d9b70c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1046,7 +1046,7 @@ def run_aiter_tests(Map conf=[:]){ sh "rocminfo" sh "python3 --version" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py" - //sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" //temporarily disable + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py" @@ -1469,8 +1469,8 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 test_grouped_convnd_fwd_large_cases test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \ - ./bin/test_grouped_convnd_fwd_large_cases && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases""" + make -j64 test_grouped_convnd_fwd_large_cases test_grouped_convnd_bwd_data_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \ + ./bin/test_grouped_convnd_fwd_large_cases && ./bin/test_grouped_convnd_bwd_data_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, build_type: 'Release', execute_cmd: execute_args) diff --git 
a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 95e8379769..c4c70009d5 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -36,6 +36,19 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} +SUPPORTED_PAGE_SIZE = [128, 256, 1024] +SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"] +SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"] +KV_MEMORY_LAYOUT_ENUM_MAP = { + "vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT", + "linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT", +} +KV_LOOKUP_TABLE_ENUM_MAP = { + "vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D", + "sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D", +} + + FMHA_BATCH_PREFILL_PIPELINE_MAP = { "qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync", } @@ -59,7 +72,7 @@ using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, {F_vlayout}>; -using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, +using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, @@ -69,13 +82,17 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_lse}, {F_dropout}, {F_qscale}, - {F_occupancy}>; + {F_occupancy}, + false, + {F_page_size}, + {F_kv_memory_layout}, + {F_kv_lookup_table}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; using fmha_mask_{F_idx} = {F_mask}; -using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBatchPrefillPipelineProblem< typename FmhaFwdTypeConfig::QDataType, typename FmhaFwdTypeConfig::KDataType, typename FmhaFwdTypeConfig::VDataType, @@ -92,6 +109,7 @@ using fmha_pipeline_problem_{F_idx} = 
ck_tile::BlockFmhaPipelineProblem< fmha_variant_{F_idx}, fmha_mask_{F_idx}, false, + {F_page_size}, fmha_trait_{F_idx}>; using fmha_pipeline_{F_idx} = {F_pipeline}< @@ -105,8 +123,8 @@ using fmha_epilogue_{F_idx} = using fmha_kernel_{F_idx} = ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; +using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; #include @@ -184,8 +202,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v """ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{ + using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, 
{F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; return fmha_batch_prefill_(s, a); }} """ @@ -230,12 +248,15 @@ class FmhaFwdApiTrait: dpad: str dvpad: str constraint: CppConstraint + kv_memory_layout: str + kv_lookup_table: str + page_size: int = 1 # page block size @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}" ) @property @@ -322,6 +343,8 @@ class FmhaFwdPipeline: F_dropout: str # F_qscale: str # no/pertensor F_mask: str # value from MASK_MAP + F_kv_memory_layout: str # + F_kv_lookup_table: str # F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -382,6 +405,8 @@ class FmhaFwdPipeline: n += f"_{self.F_qscale}" else: n += "_nqscale" + + n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table return n @@ -440,6 +465,13 @@ class FmhaFwdApiPool: F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype], + F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[ + trait.kv_memory_layout + ], + F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[ + trait.kv_lookup_table + ], + F_page_size=trait.page_size, ) if_j = "if" if j == 0 else "else if" per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( @@ -497,6 +529,7 @@ class FmhaFwdKernel: F_tile: FmhaFwdTileSize F_pipeline: FmhaFwdPipeline mask_impl: str + F_page_size: int = 1 # page 
block size @property def template(self) -> str: @@ -534,17 +567,24 @@ class FmhaFwdKernel: F_dropout=BOOL_MAP[self.F_pipeline.F_dropout], F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale], F_occupancy=self.F_tile.F_occupancy, + F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[ + self.F_pipeline.F_kv_memory_layout + ], + F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[ + self.F_pipeline.F_kv_lookup_table + ], F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode=MODE_MAP[self.F_mode], F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag], + F_page_size=self.F_page_size, ) @property def name(self) -> str: # TODO: we don't encode idx here return ( - f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_" + self.F_tile.name + "_" + self.F_pipeline.name @@ -578,6 +618,9 @@ class FmhaFwdKernel: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + kv_memory_layout=self.F_pipeline.F_kv_memory_layout, + kv_lookup_table=self.F_pipeline.F_kv_lookup_table, + page_size=self.F_page_size, ) @@ -604,23 +647,42 @@ class KernelComponentFactory: pipelines = [] if dtype in ["fp16", "bf16"]: qscale = "no" - for logits, mask, bias, lse, dropout in itertools.product( + for ( + logits, + mask, + bias, + lse, + dropout, + kv_memory_layout, + kv_lookup_table, + ) in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], + SUPPORTED_KV_MEMORY_LAYOUT, + SUPPORTED_KV_LOOKUP_TABLE, ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip elif dtype in ["fp8bf16"]: # no 
need lse/dropout kernels - for logits, qscale, mask, bias in itertools.product( + for ( + logits, + qscale, + mask, + bias, + kv_memory_layout, + kv_lookup_table, + ) in itertools.product( ["t", "f"], ["pertensor"], get_mask_map(mask_impl).keys(), ["no"], + SUPPORTED_KV_MEMORY_LAYOUT, + SUPPORTED_KV_LOOKUP_TABLE, ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip else: assert False return pipelines @@ -672,69 +734,73 @@ def get_fwd_blobs( or pipeline.F_logits == "f" ): continue - k = FmhaFwdKernel( - F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl, - ) - if kernel_filter != "": - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # 2 - Flash attention integration - if receipt in (2, 3): - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "alibi"] - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "bias"] - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # Aiter(mha_fwd) integration - elif receipt == 100: - cond = dtype in ["fp16", "bf16"] - cond &= mode == "batch" - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # Aiter(mha_batch_prefill) integration - elif receipt == 200: - cond = dtype in ["fp16", "bf16", "fp8bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - if not cond: - continue - # aiter::mha_batch_prefill C++ api integration - elif receipt == 600: - cond = dtype in ["fp16", 
"bf16", "fp8bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_qscale == "no" - if not cond: - continue - # fp32 only - if receipt == 800 or receipt == 801: - cond = dtype == "fp32" - if not cond: - continue + # Generate kernels for both page_size=16 and page_size=1024 + for page_size in SUPPORTED_PAGE_SIZE: + k = FmhaFwdKernel( + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + F_page_size=page_size, + ) + if kernel_filter != "": + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "alibi"] + cond &= pipeline.F_qscale == "no" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "bias"] + cond &= pipeline.F_qscale == "no" + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_qscale == "no" + if not cond: + continue + # Aiter(mha_batch_prefill) integration + elif receipt == 200: + cond = dtype in ["fp16", "bf16", "fp8bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + if not cond: + continue + # aiter::mha_batch_prefill C++ api integration + elif receipt == 600: + cond = dtype in ["fp16", "bf16", "fp8bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_qscale == "no" + if not cond: + continue - api_pool.register_traits(k.api_trait()) - gen.append(k) + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == "fp32" + if not cond: + continue + + 
api_pool.register_traits(k.api_trait()) + gen.append(k) return (api_pool, gen) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index ba55d6d722..3ff4acfc15 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -529,14 +529,25 @@ struct fmha_batch_prefill_args ck_tile::index_t nhead_q; ck_tile::index_t nhead_k; - // SGLang-style page table - int32_t num_total_pages; - void* kv_indptr; - void* kv_page_indices; -#if 0 // we assume page_block_size=1 for now - void* kv_last_page_lens; - ck_tile::index_t page_block_size; -#endif + // KV cache page table fields (kv_lookup_table selects interpretation): + // - SGLANG_PAGE_TABLE_1D: + // kv_indptr: prefix-sum [batch+1] into kv_page_indices + // kv_page_indices: 1D list of physical page ids, length = num_total_pages + // kv_last_page_lens: per-batch last page lengths [batch] + // - VLLM_BLOCK_TABLE_2D: + // kv_page_indices: block_table [batch, max_blocks_per_seq] (2D) + // batch_stride_block_table: row stride for block_table + // seqlen_k_ptr: per-batch seqlen_k [batch] + int32_t num_total_pages; // total physical pages in KV cache (SGLang/vLLM) + ck_tile::index_t page_block_size; // tokens per page (SGLang/vLLM) + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum + kv_memory_layout; // KV memory layout (SGLang/vLLM) + ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table; // lookup table layout selector + void* kv_indptr; // SGLang: prefix-sum; vLLM: unused + void* kv_page_indices; // SGLang: 1D page list; vLLM: block_table 2D + void* kv_last_page_lens; // SGLang: last page lengths; vLLM: unused + void* seqlen_k_ptr; // vLLM: per-batch seqlen_k; SGLang: unused + ck_tile::index_t batch_stride_block_table; // vLLM: row stride; SGLang: unused float scale_s; float scale_p; @@ -1113,6 +1124,22 @@ template auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) { assert(args.nhead_q % args.nhead_k == 0); + using 
PageTableKargs = typename FmhaKernel::PageBlockTableKargs; + const PageTableKargs page_table = [&]() { + if constexpr(FmhaKernel::kKVLookupTable == + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D) + { + return PageTableKargs{reinterpret_cast(args.kv_indptr), + reinterpret_cast(args.kv_page_indices), + reinterpret_cast(args.kv_last_page_lens)}; + } + else + { + return PageTableKargs{reinterpret_cast(args.kv_page_indices), + args.batch_stride_block_table, + reinterpret_cast(args.seqlen_k_ptr)}; + } + }(); auto kargs = [&] { // create group mode kernel arguments if constexpr(FmhaKernel::kIsGroupMode) @@ -1133,12 +1160,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.num_total_pages, - args.kv_indptr, - args.kv_page_indices, -#if 0 // we assume page_block_size=1 for now - args.kv_last_page_lens, args.page_block_size, -#endif + page_table, args.scale_s, args.scale_p, args.scale_o, @@ -1184,12 +1207,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.num_total_pages, - args.kv_indptr, - args.kv_page_indices, -#if 0 // we assume page_block_size=1 for now - args.kv_last_page_lens, args.page_block_size, -#endif + page_table, args.scale_s, args.scale_p, args.scale_o, @@ -1281,6 +1300,65 @@ struct fmha_fwd_traits_ static constexpr bool kHasSink = kHasSink_; }; +template +struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_ +{ + static constexpr auto kKVMemoryLayout = kKVMemoryLayout_; + static constexpr auto kKVLookupTable = kKVLookupTable_; + static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_; + static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout"); +}; + template float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); @@ -1527,7 +1605,15 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, fmha_fwd_appendkv_args, const 
ck_tile::stream_config&); -using fmha_batch_prefill_traits = fmha_fwd_traits; +struct fmha_batch_prefill_traits : public fmha_fwd_traits +{ + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D; + int page_size = 1; +}; + float fmha_batch_prefill(fmha_batch_prefill_traits, fmha_batch_prefill_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp index 77a9fe4271..df8351602b 100644 --- a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp @@ -69,107 +69,88 @@ struct BasicInvoker using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + CodegenPipelineProblem::TransposeC>>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - CodegenPipelineProblem::TransposeC, - memory_operation>>; + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. 
+ using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; - - auto clear_gemm_output = 
[&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - rotating_mem_ptr = - std::make_unique>( - kargs.as_ptr[0], - kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem_ptr->Print(); - - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - if(args.k_batch == 1) + if(s.flush_cache_) { - return Run(MemoryOpSet{}); + std::cout << "Flushing cache..." 
<< std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = std::make_unique>( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; } else { - return Run(MemoryOpAtomicAdd{}); + preprocess = clear_gemm_output; } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp index c312a53c2a..d2460193d8 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp @@ -72,160 +72,144 @@ struct SplitKTwoStageInvoker using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmKernel = ck_tile::GemmKernel; - using GemmKernel = ck_tile::GemmKernel; + ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType)); + ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); + auto c_ptr = ws_args.c_ptr; + ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); + auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); - ck_tile::DeviceMem ws_m_n_dev_buf(args.M * 
args.N * sizeof(WorkspaceType)); - ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); - auto c_ptr = ws_args.c_ptr; - ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); + const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s) + : GemmKernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = GemmKernel::BlockSize(); - const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s) - : GemmKernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = GemmKernel::BlockSize(); + if(!GemmKernel::IsSupportedArgument(gemm_kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(!GemmKernel::IsSupportedArgument(gemm_kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; + using BlockTile = ck_tile::sequence<2048>; + using BlockWarps = ck_tile::sequence<8>; + using WarpTile = ck_tile::sequence<64>; - using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; - using BlockTile = ck_tile::sequence<2048>; - using BlockWarps = ck_tile::sequence<8>; - using WarpTile = ck_tile::sequence<64>; + using ElementwiseShape = + ck_tile::ElementWiseShape; + using Problem = ck_tile::ElementWisePipelineProblem; + using ElementwiseKernel = + ck_tile::ElementWiseKernel; - using ElementwiseShape = - ck_tile::ElementWiseShape; - using Problem = ck_tile::ElementWisePipelineProblem; - using ElementwiseKernel = - ck_tile::ElementWiseKernel; + ck_tile::index_t total_elements = 1; + std::vector shape = {args.M, args.N}; - ck_tile::index_t total_elements = 1; - std::vector shape = {args.M, args.N}; + for(auto d : shape) + total_elements *= d; - for(auto d : shape) - total_elements *= d; + const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; - const ck_tile::index_t kBlockSize = 
ElementwiseKernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = 1; + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; - constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); - ck_tile::index_t kGridSize = - (total_elements + elements_per_block - 1) / elements_per_block; + auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.c_ptr)); + auto input_size = ck_tile::make_tuple(args.M, args.N); - auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.c_ptr)); - auto input_size = ck_tile::make_tuple(args.M, args.N); + // Check if the kernel configuration is supported + if(!ElementwiseKernel::IsSupportedArgument(input_size)) + { + throw std::runtime_error( + "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); + } - // Check if the kernel configuration is supported - if(!ElementwiseKernel::IsSupportedArgument(input_size)) - { - throw std::runtime_error( - "Wrong! Elementwise arguments not supported! 
Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; - - auto clear_gemm_output = [&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_)); - }; - - if(s.flush_cache_) - { - std::cout << "Flushing cache..." 
<< std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - rotating_mem_ptr = - std::make_unique>( - gemm_kargs.as_ptr[0], - gemm_kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem_ptr->Print(); - - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel( - GemmKernel{}, grids, blocks, 0, gemm_kargs), - ck_tile::make_kernel(ElementwiseKernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(args.N, 1), // Input Stride - ck_tile::make_tuple(args.N, 1), // Output Stride - input_tensors, - static_cast(c_ptr))); + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_)); }; - if(args.k_batch == 1) + if(s.flush_cache_) { - return Run(MemoryOpSet{}); + std::cout << "Flushing cache..." 
<< std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = std::make_unique>( + gemm_kargs.as_ptr[0], + gemm_kargs.bs_ptr[0], + s.rotating_count_, + size_a_buffer, + size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; } else { - return Run(MemoryOpAtomicAdd{}); + preprocess = clear_gemm_output; } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel( + GemmKernel{}, grids, blocks, 0, gemm_kargs), + ck_tile::make_kernel(ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(args.N, 1), // Input Stride + ck_tile::make_tuple(args.N, 1), // Output Stride + input_tensors, + static_cast(c_ptr))); } }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index c06dc457c9..64305b85cf 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -160,110 +160,101 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& args.stride_E); constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&]() { - // use SET operation since each K-split writes to separate memory - constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; - using 
GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = + ck_tile::CShuffleEpilogue>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(base_args); - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(base_args); + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + const dim3 blocks = Kernel::BlockSize(); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; - if(s.flush_cache_) - { - std::cout << "Flushing cache..." 
<< std::endl; + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); - ck_tile::RotatingMemWrapper rotating_mem( - kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - return ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - else - { - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - }; - - return Run(); + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + return ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + else + { + 
return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } } /** diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index f79494a478..8eff0e7469 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -460,12 +460,6 @@ inline auto create_args() return arg_parser; } -// Type aliases for memory operation integral constants -using MemoryOpSet = - std::integral_constant; -using MemoryOpAtomicAdd = std::integral_constant; - // host API template ::template GemmPipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" - << std::endl; - } - float ave_time = 0.f; - if(s.flush_cache_) - { - std::cout << "Flushing cache..." 
<< std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - ck_tile::RotatingMemWrapper rotating_mem(kargs.as_ptr[0], - kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - ave_time = - ck_tile::launch_kernel_time_mask(s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - return ave_time; - }; - - if(args.k_batch == 1) + dim3 grids; + if constexpr(Persistent) { - return Run(ck_tile::integral_constant{}); + grids = Kernel::MaxOccupancyGridSize(s); } else { - throw std::runtime_error("split-k is not supported yet!"); + grids = Kernel::GridSize(args.M, args.N, args.k_batch); } + dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" << std::endl; + } + float ave_time = 0.f; + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; } }; diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index 4a83a2c4ab..fb89e6b4cc 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -60,112 +60,94 @@ struct UniversalInvoker using 
GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GemmKernel; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s) - : Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s) + : Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; - auto clear_gemm_output = [&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - - if(s.flush_cache_) - { - std::cout << "Flushing cache..." 
<< std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - rotating_mem_ptr = - std::make_unique>( - kargs.as_ptr[0], - kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem_ptr->Print(); - - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - if(args.k_batch == 1) + if(s.flush_cache_) { - return Run(MemoryOpSet{}); + std::cout << "Flushing cache..." 
<< std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = std::make_unique>( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; } else { - return Run(MemoryOpAtomicAdd{}); + preprocess = clear_gemm_output; } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index d9cb54cf74..a98faf5840 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -334,13 +334,13 @@ bool test_moe_sorting(ck_tile::ArgParser args) if(moe_buf_bytes > 0) { #if MOE_SORTING_FMOE_2D_BUF - printf("moe_buf:%lu(%d,%d), ", + printf("moe_buf:%" PRIu64 "(%d,%d), ", static_cast(moe_buf_bytes), moe_buf_interm_dim, moe_buf_elem_bytes); #else - printf("moe_buf:%lu, ", static_cast(moe_buf_bytes)); + printf("moe_buf:%" PRIu64 ", ", static_cast(moe_buf_bytes)); #endif } diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index c7e37bc8a7..b68c30351d 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -78,63 +78,48 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto 
memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::BatchedGemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } #include "run_batched_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 9b51af22fe..0f0a0d8ba7 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -14,7 +14,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95") quant_grouped_gemm_bf8_rowcol.cpp quant_grouped_gemm_bf8_tensor.cpp ) - + add_executable(tile_example_abquant_grouped_gemm abquant_grouped_gemm.cpp) add_executable(tile_example_grouped_gemm_preshuffle grouped_gemm_preshuffle.cpp) add_executable(tile_example_grouped_gemm_multi_d grouped_gemm_multi_d.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) @@ -25,4 +25,5 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95") target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_abquant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp new file mode 100644 index 0000000000..84da1e26da --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp @@ -0,0 +1,278 @@ +// Copyright (c) Advanced Micro 
Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm_quant.hpp" +#include "ck_tile/host.hpp" +#include "abquant_grouped_gemm.hpp" + +// Non-persistent grouped gemm for ABQuant +template +float grouped_gemm_abquant(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + GemmQuantConfig::template BaseGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = + GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + 
ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); +} + +// Persistent grouped gemm tileloop for ABQuant +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr) +{ + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = 
GemmQuantConfig::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); +} + +#include "run_grouped_gemm_abquant_example.inc" + +int main(int argc, char* argv[]) +{ + int result1 = run_abquant_grouped_gemm_example(argc, argv); + return result1; +} diff --git a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp new file mode 100644 index 0000000000..da8bd5514c --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.hpp @@ -0,0 +1,171 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/utility/json_dump.hpp" + +template +struct GemmTypeConfig; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool DoubleSmemBuffer = false; + static constexpr bool PreshuffleB = false; + static constexpr bool Persistent = Persistent_; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); +}; + +template 
+struct GemmQuantConfig; + +// ABQuant specialization for GemmQuantConfig +template <> +struct GemmQuantConfig +{ + template + using GemmConfig = GemmConfigComputeV3_2; + + template + using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3; + + template + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("Ms", "", "M dimensions - empty by default.") + .insert("Ns", "", "N dimensions - empty by default.") + .insert("Ks", "", "K dimensions - empty by default.") + .insert( + "stride_As", + "", + "Tensor A strides - it is empty by default.") // stride_As/stride_Bs/stride_Cs/stride_AQs/stride_BQs + // can be set to zero if + // Ms/Ns/Ks is not empty + .insert("stride_Bs", "", "Tensor B strides - it is empty by default.") + .insert("stride_Cs", "", "Tensor C strides - it is empty by default.") + .insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.") + .insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.") + .insert("a_layout", "R", "A tensor data layout - Row by default.") + .insert("b_layout", "C", "B tensor data layout - Row by default.") + .insert("c_layout", "R", "C tensor data layout - Row by default.") + .insert("validate", "1", "0. No validation, 1. Validation on CPU.") + .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "10", "number of iterations before benchmark the kernel.") + .insert("repeat", "100", "number of iterations to benchmark the kernel.") + .insert("group_count", "8", "group count.") + .insert("kbatch", "1", "kbatch for SplitK") + .insert("init", "0", "0. Random, 2. One(s) (Constant)") + .insert("persistent", "0", "Kernel persistency. 0: non-persistent. 1: persistent.") + .insert("bquant_group_size", "1x1x128", "BQuant group size. 
1x1x128 (default) or 1x128x128") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "abquant_grouped_gemm.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +inline std::size_t get_workspace_size(const std::vector& gemm_descs) +{ + return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg); +} + +// Forward declaration of the non-persistent version +template +float grouped_gemm_abquant(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr); + +// Forward declaration of the tileloop version for persistent kernels +template +float grouped_gemm_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 3ff3f2f10e..a24e4bc8ab 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -62,71 +62,55 @@ float grouped_gemm(const std::vector& gemm_descs, using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " 
grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -161,74 +144,55 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, BLayout, CLayout>; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = 
memory_operation_.value; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC>>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - UniversalGemmProblem::TransposeC, - memory_operation>>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << 
blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - }; - - if(!splitk) + if(s.log_level_ > 0) { - return ave_time = Run(ck_tile::integral_constant{}); - } - else - { - return ave_time = - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } #include "run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 67b411c1f0..462f11e405 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -328,5 +328,4 @@ template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk = false); + void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp index 060dd311b5..e5aefad8d1 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -61,72 +61,56 @@ float grouped_gemm_multi_d(const std::vector& gemm_d using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - 
ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { " - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { " + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + 
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -163,76 +146,55 @@ float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s, BLayout, ELayout>; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. 
- using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; - }; - if(!splitk) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); - } - else - { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - return ave_time; + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } #include "run_grouped_gemm_multi_d_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp index 4a5be996c0..b4c10900d6 100644 --- 
a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp @@ -65,70 +65,54 @@ float grouped_gemm(const std::vector& gemm_descs, using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + 
s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -167,75 +150,53 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, GemmConfig::NumWaveGroups, GemmConfig::Preshuffle>; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, // DsDataType (empty for no D tensors) + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout (empty for no D tensors) + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC>>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmPipeline = 
typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue, // DsDataType (empty for no D tensors) - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout (empty for no D tensors) - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - UniversalGemmProblem::TransposeC, - memory_operation>>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; - }; - - if(splitk) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); - } - else - { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - return ave_time; + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } #include "run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp b/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp index 16352722e1..ea71abb213 100644 --- 
a/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp @@ -72,10 +72,9 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::BQuantGrouped; @@ -137,8 +136,7 @@ float grouped_gemm(const std::vector& gemm_descs, GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile, - QuantGemmProblem::TransposeC, - memory_operation>>; + QuantGemmProblem::TransposeC>>; using Kernel = ck_tile::QuantGroupedGemmKernel; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; - constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; - using QuantGemmProblem = std::conditional_t< - UseGroupedQuant, - std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>, - ck_tile::GemmRowColTensorQuantPipelineProblem>; + using GemmPipeline 
= GemmQuantConfig::template GemmPipeline; - using GemmPipeline = - GemmQuantConfig::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - QuantGemmProblem::TransposeC, - memory_operation>>; - using Kernel = ck_tile::QuantGroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - }; - - return ave_time = 
Run(ck_tile::integral_constant{}); + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc new file mode 100644 index 0000000000..bc5167439d --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_abquant_example.inc @@ -0,0 +1,604 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_abquant_gemm(int n_warmup, + int n_repeat, + int group_count, + const std::vector& args) +{ + // Workspace memory allocated to hold the gemm descriptions. 
+ ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(args)); + + float ave_time = 0; + + if constexpr(!GemmConfig::Persistent) + { + ave_time = grouped_gemm_abquant( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + // NOTE: With the persistent TileLoop kernel, we do not necessarily need to have + // the gemm problems known on the host. Instead, we can just pass the pointer + // to the kernel and let the workgroups figure out which tiles to work on. + // This is useful when the gemm problems are generated dynamically. + // In this example however, we generate the `kargs` using the known gemm_descs, + // and copy the gemm descriptions to the device memory. + // The contents of the memory pointed to by `kargs_ptr` pointer could be + // written by e.g. another kernel from earlier stage. + std::vector kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + if(args[0].k_batch != 1) + { + throw std::runtime_error("Split-K not supported yet for persistent kernel"); + } + + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.aq_ptr, + arg.bq_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.QK_A, + arg.QK_B, + arg.stride_A, + arg.stride_B, + arg.stride_E, + arg.stride_AQ, + arg.stride_BQ, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = grouped_gemm_tileloop(stream, group_count, kargs_ptr); + } + + return ave_time; +} + +template +int run_abquant_grouped_gemm_example_with_layouts( + int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const AQLayout aq_layout = AQLayout{}, + const BLayout b_layout = BLayout{}, + const BQLayout bq_layout = 
BQLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + + auto [result, arg_parser] = create_args(argc, argv); + + auto valid_input_data = [&](int group_count, const auto&... args) { + return group_count != 0 && ((args.size() == static_cast(group_count)) && ...); + }; + + const int group_count = arg_parser.get_int("group_count"); + const int repeat = arg_parser.get_int("repeat"); + const int warmup = arg_parser.get_int("warmup"); + const int kbatch = arg_parser.get_int("kbatch"); + const int init_method = arg_parser.get_int("init"); + bool validate = arg_parser.get_bool("validate"); + + if(kbatch > 1 && validate && warmup + repeat > 1) + { + std::cout << "WARNING: Data validation enabled with SplitK and more than" + << "1 warmup/repeat. Disabling validation." << std::endl; + validate = false; + } + + std::vector Ms = arg_parser.get_int_vec("Ms"); + std::vector Ns = arg_parser.get_int_vec("Ns"); + std::vector Ks = arg_parser.get_int_vec("Ks"); + std::vector AQs; // dimension of AQ tensor is calculated from A tensor + std::vector BQs; // dimension of BQ tensor is calculated from B tensor + std::vector stride_As = arg_parser.get_int_vec("stride_As"); + std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); + std::vector stride_Cs = arg_parser.get_int_vec("stride_Cs"); + std::vector stride_AQs = arg_parser.get_int_vec("stride_AQs"); + std::vector stride_BQs = arg_parser.get_int_vec("stride_BQs"); + + ck_tile::index_t AQK, BQK; + + if(!valid_input_data( + group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs)) + { + std::cout << "Please check the input data. Default values will be used." 
<< std::endl; + + // Clear existing (invalid) data before adding defaults + Ms.clear(); + Ns.clear(); + Ks.clear(); + stride_As.clear(); + stride_Bs.clear(); + stride_Cs.clear(); + stride_AQs.clear(); + stride_BQs.clear(); + + for(int i = 0; i < group_count; i++) + { + + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 128 * i); + + // Let get_default_stride calculate based on layout + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); + stride_AQs.push_back(0); + stride_BQs.push_back(0); + } + } + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> c_m_n_tensors; + std::vector> aq_tensors; + std::vector> bq_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + c_m_n_tensors.reserve(group_count); + aq_tensors.reserve(group_count); + bq_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> c_m_n_dev_buf; + std::vector> aq_dev_buf; + std::vector> bq_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + c_m_n_dev_buf.reserve(group_count); + aq_dev_buf.reserve(group_count); + bq_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + // For ABQuantGrouped, both A and B need quantization + static_assert(QuantMode == ck_tile::QuantType::ABQuantGrouped, + "This file only supports ABQuantGrouped mode"); + + AQK = K / AQuantGroupSize::kK; // Group quantization: AQK = K / AQuantGroupSize + BQK = K / BQuantGroupSize::kK; // Group quantization: BQK = K / BQuantGroupSize + if(K % AQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by AQuantGroupSize::kK for ABQuantGrouped mode"); + } + if(K % BQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K 
must be divisible by BQuantGroupSize::kK for ABQuantGrouped mode"); + } + + stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout)); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); + stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{})); + stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout)); + stride_BQs[i] = ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout)); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout)))); + c_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout)))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(bq_layout)))); + + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc + << " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl; + + if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{1.f, 1.f}(bq_tensors[i]); + } + else + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(bq_tensors[i]); + } + + 
a_m_k_dev_buf.push_back(std::make_unique( + a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back(std::make_unique( + b_k_n_tensors[i].get_element_space_size_in_bytes())); + c_m_n_dev_buf.push_back(std::make_unique( + c_m_n_tensors[i].get_element_space_size_in_bytes())); + aq_dev_buf.push_back( + std::make_unique(aq_tensors[i].get_element_space_size_in_bytes())); + bq_dev_buf.push_back( + std::make_unique(bq_tensors[i].get_element_space_size_in_bytes())); + + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + aq_dev_buf[i]->ToDevice(aq_tensors[i].data()); + bq_dev_buf[i]->ToDevice(bq_tensors[i].data()); + c_m_n_dev_buf[i]->SetZero(); + c_m_n_tensors[i].SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer(); + const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer(); + + gemm_descs.push_back({p_a, + p_b, + p_c, + p_aq, + p_bq, + kbatch, + M, + N, + K, + AQK, + BQK, + stride_As[i], + stride_Bs[i], + stride_Cs[i], + stride_AQs[i], + stride_BQs[i]}); + } + + float ave_time = invoke_abquant_gemm(warmup, repeat, group_count, gemm_descs); + + std::string op_name = "ABQuant Grouped Gemm (" + ck_tile::quant_type_to_string(QuantMode) + ")"; + + std::size_t flop = 0, num_btype = 0; + for(int j = 0; j < group_count; ++j) + { + flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K; + + num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K + + sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N + + sizeof(CDataType) * gemm_descs[j].M * gemm_descs[j].N; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name 
<< std::endl; + + for(int i = 0; i < group_count; i++) + { + c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data()); + } + + bool pass{true}; + if(validate) + { + for(int i = 0; i < group_count; ++i) + { + ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( + Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + // Reference implementation for ABQuantGrouped + ck_tile::reference_gemm_abquant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + Ks[i], kbatch, max_accumulated_value); + pass &= + ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: Incorrect results! in group [" + std::to_string(i) + "]", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "gemm[" << i + << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "The CPU verification result is:" << (pass ? 
"correct" : "fail") << std::endl; + } + + if(arg_parser.get_int("json") == 1) + { + dump_grouped_gemm_json_results(arg_parser.get_str("jsonfile"), + op_name, + group_count, + pass, + ave_time, + tflops, + gb_per_sec); + } + + return pass; +} + +template +int run_abquant_grouped_gemm_example_prec_type_with_bquant( + std::string a_layout, std::string b_layout, std::string c_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using Types = GemmTypeConfig; + // Specific type aliases for easy access + using ADataType = typename Types::ADataType; + using BDataType = typename Types::BDataType; + using AccDataType = typename Types::AccDataType; + using CDataType = typename Types::CDataType; + using AQDataType = typename Types::AccDataType; + using BQDataType = typename Types::AccDataType; + using AQuantGroupSize = ck_tile::QuantGroupShape>; + + constexpr auto QuantMode = ck_tile::QuantType::ABQuantGrouped; + + if(a_layout == "R" && b_layout == "C" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}); + } + else if(a_layout == "R" && b_layout == "R" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R" && c_layout == "R") + { + return run_abquant_grouped_gemm_example_with_layouts( + argc, argv, Col{}, Row{}, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} + +template +int run_abquant_grouped_gemm_example_prec_type(std::string a_layout, + std::string b_layout, + std::string c_layout, + std::string bquant_group_size, + int argc, + char* argv[]) +{ + if(bquant_group_size == "1x1x128") + { + using BQuantGroupSize = ck_tile::QuantGroupShape>; + return 
run_abquant_grouped_gemm_example_prec_type_with_bquant( + a_layout, b_layout, c_layout, argc, argv); + } + else if(bquant_group_size == "1x128x128") + { + using BQuantGroupSize = ck_tile::QuantGroupShape>; + return run_abquant_grouped_gemm_example_prec_type_with_bquant( + a_layout, b_layout, c_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported BQuantGroupSize! Use 1x1x128 or 1x128x128."); + } +} + +template +int run_abquant_gemm_example_persistency(std::string a_layout, + std::string b_layout, + std::string c_layout, + bool persistent, + std::string bquant_group_size, + int argc, + char* argv[]) +{ + if(persistent) + { + using GemmConfig = typename GemmQuantConfig< + ck_tile::QuantType::ABQuantGrouped>::template GemmConfig; + return run_abquant_grouped_gemm_example_prec_type( + a_layout, b_layout, c_layout, bquant_group_size, argc, argv); + } + else + { + using GemmConfig = typename GemmQuantConfig< + ck_tile::QuantType::ABQuantGrouped>::template GemmConfig; + return run_abquant_grouped_gemm_example_prec_type( + a_layout, b_layout, c_layout, bquant_group_size, argc, argv); + } +} + +int run_abquant_grouped_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string c_layout = arg_parser.get_str("c_layout"); + const std::string data_type = arg_parser.get_str("prec"); + bool persistent = arg_parser.get_bool("persistent"); + const std::string bquant_group_size = arg_parser.get_str("bquant_group_size"); + + if(data_type == "fp8") + { + return run_abquant_gemm_example_persistency( + a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv); + } + else if(data_type == "bf8") + { + return run_abquant_gemm_example_persistency( + a_layout, b_layout, c_layout, persistent, bquant_group_size, argc, argv); + } + else + { + 
throw std::runtime_error("Unsupported data type configuration."); + } +} diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 390a54644b..7a01b1dcea 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -79,8 +79,7 @@ float invoke_gemm(int n_warmup, // earlier stage. std::vector> kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - const bool splitk = args[0].k_batch > 1; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); for(const auto& arg : args) { kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr}, @@ -109,7 +108,7 @@ float invoke_gemm(int n_warmup, ADataType, BDataType, AccDataType, - CDataType>(stream, group_count, kargs_ptr, splitk); + CDataType>(stream, group_count, kargs_ptr); } return ave_time; diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc index ac6ea99db3..4f2bebdf17 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc @@ -95,8 +95,7 @@ float invoke_gemm(int n_warmup, else { std::vector> kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - const bool splitk = args[0].k_batch > 1; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); for(const auto& arg : args) { kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, NumDTensor>{{arg.a_ptr}, @@ -119,18 +118,17 @@ float invoke_gemm(int n_warmup, kargs.size() * sizeof(ck_tile::GemmTransKernelArg), hipMemcpyHostToDevice, stream.stream_id_)); - ave_time = - grouped_gemm_multi_d_tileloop(stream, group_count, kargs_ptr, splitk); + ave_time = grouped_gemm_multi_d_tileloop(stream, group_count, kargs_ptr); } return ave_time; } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp 
b/example/ck_tile/18_flatmm/flatmm_basic.cpp index cd241a2be0..af46884a90 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -170,13 +170,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem& args, FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, - memory_operation, FlatmmConfig::NumWaveGroups, false, 1, @@ -282,23 +278,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/grouped_flatmm.cpp b/example/ck_tile/18_flatmm/grouped_flatmm.cpp index da85c95dae..780a21ba14 100644 --- a/example/ck_tile/18_flatmm/grouped_flatmm.cpp +++ b/example/ck_tile/18_flatmm/grouped_flatmm.cpp @@ -113,13 +113,10 @@ float grouped_flatmm(const KernelArguments& args, const 
ck_tile::stream_config& const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. @@ -216,23 +212,7 @@ float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config& return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp index fe7fe4c5d1..708e8a683e 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp @@ -113,13 +113,10 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) 
{ - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = std::conditional_t{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp index 2b6dbace36..f9f8c0cec7 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp @@ -89,13 +89,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern @@ -128,7 +125,6 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& 
FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, - memory_operation, FlatmmConfig::NumWaveGroups, false, // FixedVectorSize 1, // VectorSizeC @@ -201,23 +197,7 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/moe_flatmm.cpp b/example/ck_tile/18_flatmm/moe_flatmm.cpp index 96b9ae29a4..4cca953066 100644 --- a/example/ck_tile/18_flatmm/moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/moe_flatmm.cpp @@ -144,15 +144,11 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs& args, const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem& args, FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, - 
memory_operation, FlatmmConfig::NumWaveGroups, false, 1, @@ -261,37 +256,20 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs& args, args.NumTokens * args.TopK * outputN * sizeof(CDataType), s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( + return ck_tile::launch_kernel_time_mask( s, run_flush_cache, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = ck_tile::launch_kernel( + return ck_tile::launch_kernel( s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + float ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp index f177ef04ca..01128f8fe8 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp @@ -61,8 +61,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, "mixed_prec_flatmm requires ADataType is a wider type than BDataType"); constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = - Splitk ? 
ck_tile::memory_operation_enum::atomic_add : ck_tile::memory_operation_enum::set; + ck_tile::ignore = Splitk; constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern @@ -98,7 +97,6 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, MXPipelineProblem::TransposeC, - memory_operation, FlatmmConfig::NumWaveGroups, false, // FixedVectorSize 1, // VectorSizeC diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp index 9e2bc3e3fb..1c56295f9f 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp @@ -81,87 +81,45 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - // Epilogue selection: set to true for chainer-based, false for standard - // CShuffleEpilogue - constexpr bool UseChainerEpilogue = true; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = std::conditional_t< - UseChainerEpilogue, - // Chainer-based epilogue - ck_tile::EpilogueChainer, - ck_tile::DefaultScheduleTag>>, - // Standard CShuffleEpilogue - ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>>; + using Kernel = ck_tile::GemmKernelMultiD; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernelMultiD; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw 
std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y - << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y - << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } #include "run_gemm_multi_d_fp16_example.inc" diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp index d2663b033c..ca8573d6d2 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp @@ -59,94 +59,80 @@ struct GroupedConvolutionBackwardDataInvoker ConvConfig::NumWaveGroups>; constexpr auto scheduler = ConvConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + InDataType, + 
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - OutDataType, - WeiDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - InDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - auto preprocess = [&]() { - ck_tile::hip_check_error(hipMemsetAsync( - kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); - }; - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + auto preprocess = [&]() { + ck_tile::hip_check_error(hipMemsetAsync( + kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); }; - if(args.k_batch == 1) - { - return Run(MemoryOpSet{}); - } - else - { - return Run(MemoryOpAtomicAdd{}); - } + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp 
b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp index afe43cd1c0..90874e6018 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp @@ -59,104 +59,85 @@ struct GroupedConvolutionBackwardWeightInvoker ConvConfig::NumWaveGroups>; constexpr auto scheduler = ConvConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - OutDataType, - InDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - WeiDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; - const auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = 
Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + auto preprocess = [&]() { + if(args.k_batch > 1) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + ck_tile::hip_check_error(hipMemsetAsync( + kargs.wei_ptr, 0, args.template GetWeightByte(), s.stream_id_)); } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - auto preprocess = [&]() { - if(kargs.k_batch > 1) - { - ck_tile::hip_check_error( - hipMemsetAsync(kargs.wei_ptr, - 0, - args.template GetWeightByte(), - s.stream_id_)); - } - }; - - const auto ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - 
ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - const auto split_k = kargs.k_batch; - - return InvokerResult{ave_time, split_k}; }; - if(args.k_batch == 1) - { - return Run(MemoryOpSet{}); - } - else - { - return Run(MemoryOpAtomicAdd{}); - } + float ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return InvokerResult{ave_time, args.k_batch}; } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp index ad5e8ae70f..c4d618a0bf 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp @@ -65,163 +65,143 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker constexpr auto scheduler = ConvConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - OutDataType, - InDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - WeiDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using GemmPipeline = typename PipelineTypeTraits< + 
ConvConfig::Pipeline>::template GemmPipeline; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; - using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; + const ck_tile::index_t spatial_lengths_accum = + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum * + sizeof(WorkspaceDataType)); + ck_tile::GroupedConvBwdWeightHostArgs ws_args = ck_tile::GroupedConvBwdWeightHostArgs(args); + auto c_ptr = ws_args.wei_ptr; + ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - const ck_tile::index_t spatial_lengths_accum = - std::accumulate(args.filter_spatial_lengths_.begin(), - args.filter_spatial_lengths_.end(), - 1, - std::multiplies()); - ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum * - sizeof(WorkspaceDataType)); - ck_tile::GroupedConvBwdWeightHostArgs ws_args = - ck_tile::GroupedConvBwdWeightHostArgs(args); - auto c_ptr = ws_args.wei_ptr; - ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - const auto kargs = Kernel::MakeKernelArgs(ws_args); + const auto kargs = Kernel::MakeKernelArgs(ws_args); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); - } + using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; + using BlockTile = ck_tile::sequence<2048>; + using BlockWarps = ck_tile::sequence<8>; + using WarpTile = ck_tile::sequence<64>; - using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; - using BlockTile = ck_tile::sequence<2048>; - using BlockWarps = ck_tile::sequence<8>; - using WarpTile = ck_tile::sequence<64>; + using ElementwiseShape = + ck_tile::ElementWiseShape; + using Problem = ck_tile::ElementWisePipelineProblem; + using ElementwiseKernel = + ck_tile::ElementWiseKernel; - using ElementwiseShape = - ck_tile::ElementWiseShape; - using Problem = ck_tile::ElementWisePipelineProblem; - using ElementwiseKernel = - ck_tile::ElementWiseKernel; + ck_tile::index_t total_elements = 1; + std::vector shape = { + static_cast(args.G_ * args.K_), + static_cast(args.C_ * spatial_lengths_accum)}; - ck_tile::index_t total_elements = 1; - std::vector shape = { - static_cast(args.G_ * args.K_), - static_cast(args.C_ * spatial_lengths_accum)}; + for(auto d : shape) + total_elements *= d; - for(auto d : shape) - total_elements *= d; + const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); - const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; - constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); - ck_tile::index_t kGridSize = - (total_elements + elements_per_block - 1) / elements_per_block; + auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.wei_ptr)); + auto input_size = ck_tile::make_tuple(shape[0], shape[1]); - auto input_tensors = - ck_tile::make_tuple(static_cast(ws_args.wei_ptr)); - auto input_size = ck_tile::make_tuple(shape[0], shape[1]); + // Check if the kernel configuration is supported + 
if(!ElementwiseKernel::IsSupportedArgument(input_size)) + { + throw std::runtime_error( + "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); + } - // Check if the kernel configuration is supported - if(!ElementwiseKernel::IsSupportedArgument(input_size)) - { - throw std::runtime_error( - "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - auto preprocess = [&]() { - if(kargs.k_batch > 1) - ck_tile::hip_check_error( - hipMemsetAsync(ws_args.wei_ptr, - 0, - shape[0] * shape[1] * sizeof(WorkspaceDataType), - s.stream_id_)); - }; - - const auto ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), - ck_tile::make_kernel( - ElementwiseKernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(shape[1], 1), // Input Stride - ck_tile::make_tuple(shape[1], 1), // 
Output Stride - input_tensors, - static_cast(c_ptr))); - - const auto split_k = kargs.k_batch; - - return InvokerResult{ave_time, split_k}; + auto preprocess = [&]() { + if(args.k_batch > 1) + ck_tile::hip_check_error( + hipMemsetAsync(ws_args.wei_ptr, + 0, + shape[0] * shape[1] * sizeof(WorkspaceDataType), + s.stream_id_)); }; - if(args.k_batch == 1) - { - return Run(MemoryOpSet{}); - } - else - { - return Run(MemoryOpAtomicAdd{}); - } + float ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), + ck_tile::make_kernel( + ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(shape[1], 1), // Input Stride + ck_tile::make_tuple(shape[1], 1), // Output Stride + input_tensors, + static_cast(c_ptr))); + return InvokerResult{ave_time, kargs.k_batch}; } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp index 82541bb593..c94466aeb2 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -70,91 +70,74 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== // Regular Convolution: Simple, no split-image // ===================================================================== - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - InDataType, - WeiDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - OutDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - 
GroupedConvTraitsType::VectorSizeB>; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + OutDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using Kernel = ck_tile::GroupedConvolutionForwardKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GroupedConvolutionForwardKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - // ===================================================================== - // Split-K dispatch - // ===================================================================== - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(MemoryOpSet{}); + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); } - else + + if(s.log_level_ > 0) { - return Run(MemoryOpAtomicAdd{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index 4261385a84..5dec340668 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -213,8 +213,7 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== // Kernel launch lambda: Uses EnableSplitImage based on layout support // ===================================================================== - const auto Run = [&](const auto memory_operation_, const auto enable_split_image_) { - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto enable_split_image_) { constexpr bool EnableSplitImage = enable_split_image_.value; using GroupedConvTraitsType = std::conditional_t>; @@ -332,17 +330,11 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== if(use_split_image) { - if(args.k_batch == 1) - return 
Run(MemoryOpSet{}, ck_tile::bool_constant{}); - else - return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant{}); + return Run(ck_tile::bool_constant{}); } else { - if(args.k_batch == 1) - return Run(MemoryOpSet{}, ck_tile::bool_constant{}); - else - return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant{}); + return Run(ck_tile::bool_constant{}); } } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp index 63dd54dcae..a78a880815 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp @@ -13,11 +13,6 @@ #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "conv_configs.hpp" -using MemoryOpSet = - std::integral_constant; -using MemoryOpAtomicAdd = std::integral_constant; - template auto calculate_rtol_atol(const ck_tile::index_t GemmK, const ck_tile::index_t kbatch, diff --git a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp index acb9126d65..9202bf9d98 100644 --- a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp +++ b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp @@ -85,60 +85,44 @@ auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_conf using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernelMultiABD; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = 
Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y - << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y - << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } #include "run_gemm_multi_abd_fp16_example.inc" diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index 4a90c07e05..155f19881e 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -69,4 +69,64 @@ void abquant_quantgrouped_instance_factory( BQuantGroupSize, ck_tile::QuantType::ABQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize 
= ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "abquant", + "preshuffleb", + "non-preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp index e0e0a64416..62ca34b057 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant.cpp @@ 
-9,36 +9,194 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill; void bquant_quantgrouped_preshufflequant_instance_factory( std::unordered_map>& lut) { - using QuantGroupSize = ck_tile::QuantGroupShape>; lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; + + lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = 
ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, ck_tile::QuantType::BQuantGrouped>(arg_parser); }; + lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x16x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x32x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"bf8", + "bquant", + "non-preshuffleb", + "preshufflequant", + "1x64x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return 
run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; lut[hash_multiple_strings( {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"fp8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, 
@@ -47,10 +205,63 @@ void bquant_quantgrouped_preshufflequant_instance_factory( lut[hash_multiple_strings( {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { - using TypeConfig = decltype(GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x8x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x16x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x32x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; + return run_gemm_example_prec_type, + TypeConfig, + QuantGroupSize, + ck_tile::QuantType::BQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings( + {"bf8i4", "bquant", "non-preshuffleb", "preshufflequant", "1x64x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using TypeConfig = decltype(GemmQuantTypeConfig{}); + using QuantGroupSize = ck_tile::QuantGroupShape>; return run_gemm_example_prec_type, TypeConfig, QuantGroupSize, diff --git 
a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 47a22cdcba..607c53d9af 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -74,9 +74,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true, ck_tile::BaseGemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BaseGemmPipelineAgBgCrCompV3>>>; + std::conditional_t< + QuantMode == ck_tile::QuantType::AQuantGrouped, + ck_tile::BaseGemmPipelineAgBgCrMem, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -145,26 +146,33 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str GemmConfig::Scheduler, has_hot_loop_v, tail_number_v>>>>; + using AQuantPipeline = + std::conditional_t, + ck_tile::AQuantGemmPipelineAgBgCrMem>; + + using BQuantPipeline = std::conditional_t< + GemmConfig::PreshuffleB, + ck_tile::WPQuantBPipelineAgBgCrV2, + std::conditional_t< + std::is_same_v, + ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + + using ABQuantPipeline = + std::conditional_t, + ck_tile::ABQuantGemmPipelineAgBgCrCompV3>; using GemmPipeline = std::conditional_t< QuantMode == ck_tile::QuantType::RowColQuant || QuantMode == ck_tile::QuantType::TensorQuant, ck_tile::GemmPipelineAgBgCrCompV3, - std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped, - std::conditional_t, - ck_tile::AQuantGemmPipelineAgBgCrMem>, - std::conditional_t< - QuantMode == ck_tile::QuantType::ABQuantGrouped, - ck_tile::ABQuantGemmPipelineAgBgCrCompV3, - std::conditional_t< - GemmConfig::PreshuffleB == true, - ck_tile::WPQuantBPipelineAgBgCrV2, - 
std::conditional_t< - std::is_same_v, - ck_tile::MxFp4GemmPipelineAgBgCrCompV3, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>>>; + std::conditional_t>>; constexpr bool TiledPermuteN = (BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN; @@ -173,77 +181,30 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str printf( "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN); } - - // Epilogue selection: use chainer for RowCol/Tensor quant, standard for others - // Toggle to switch between chainer-based and standard CShuffleEpilogue - constexpr bool UseChainerEpilogue = true; - - // Define the schedule tag based on quant mode - using ScheduleTag = - std::conditional_t>; - - using GemmEpilogue = std::conditional_t< - UseChainerEpilogue && (QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant), - // Chainer-based epilogue for RowCol/Tensor quant modes - ck_tile::EpilogueChainer, - typename TypeConfig::ADataType, - typename TypeConfig::BDataType>, - ck_tile::tuple<>, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledPermuteN>, - ScheduleTag>>, - // Standard CShuffleEpilogue for other modes - ck_tile::CShuffleEpilogue, typename TypeConfig::ADataType, - std::conditional_t< - std::is_same_v, - typename TypeConfig::ADataType, - typename TypeConfig::BDataType>, - ck_tile::tuple<>, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - 
GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledPermuteN>>>; - + typename TypeConfig::BDataType>, + ck_tile::tuple<>, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + 1, + false, + 1, + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; @@ -579,7 +540,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, QuantMode == ck_tile::QuantType::RowColQuant) { bq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout))); + ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout))); } else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) { @@ -955,8 +916,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - if((QuantMode == ck_tile::QuantType::ABQuantGrouped || - QuantMode == ck_tile::QuantType::AQuantGrouped || + if((QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::RowColQuant || std::is_same_v) && GemmConfig::PreshuffleB) @@ -985,7 +945,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser) if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::ABQuantGrouped) && - !GemmConfig::PreshuffleQuant) + !GemmConfig::PreshuffleQuant && !GemmConfig::PreshuffleB) { if(a_layout == "R" && b_layout == "R") { diff --git a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp index d3ee9fe9c6..828c861349 100644 --- 
a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp +++ b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp @@ -48,112 +48,87 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, GemmConfiguration::NUM_WAVE_GROUPS, GemmConfiguration::PRESHUFFLE>; - const auto runKernel = [&](const auto memory_operation) -> std::tuple { - // We create the GEMM pipeline without specifying has_hot_loop or tail_num. - // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K - // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K - // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + // We create the GEMM pipeline without specifying has_hot_loop or tail_num. + // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K + // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K + // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. 
+ using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::StreamKKernel; + using Kernel = ck_tile::StreamKKernel; - auto kernel_args = Kernel::MakeKernelArgs(args); - const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args); - ck_tile::DeviceMem workspace_data(workspace_size); + auto kernel_args = Kernel::MakeKernelArgs(args); + const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args); + ck_tile::DeviceMem workspace_data(workspace_size); + workspace_data.SetZero(); + kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer(); + + dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner); + dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kernel_args)) + { + // Clear the output C tensor results after each repetition of the kernel + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_)); + } + + if(stream_config.log_level_ > 0) + { + // Reset sk flags to zero before each repetition of the kernel workspace_data.SetZero(); - kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer(); + } - dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner); - dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kernel_args)) + auto reset_data_buffers = [&]() { + if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + // Clear the output C tensor results after each repetition of the kernel + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_)); } - - if(stream_config.log_level_ > 0) + else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + // Reset sk flags to zero before each repetition of the kernel + workspace_data.SetZero(); } - - auto reset_data_buffers = [&]() { - if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) - { - // Clear the output C tensor results after each repetition of the kernel - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_)); - } - else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) - { - // Reset sk flags to zero before each repetition of the kernel - workspace_data.SetZero(); - } - }; - - std::function preprocess = reset_data_buffers; - - float average_time = - ck_tile::launch_kernel_time_mask(stream_config, - preprocess, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kernel_args)); - - ck_tile::index_t num_wgs_per_tile = - kernel_args.tile_partitioner.estimate_num_wgs_per_tile(); - return std::tuple{average_time, num_wgs_per_tile}; }; - if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy) - { - return runKernel(ck_tile::integral_constant{}); - } - else // We are using ck_tile::StreamKReductionStrategy::Reduction - { - return runKernel(ck_tile::integral_constant{}); - } + std::function 
preprocess = reset_data_buffers; + + float average_time = + ck_tile::launch_kernel_time_mask(stream_config, + preprocess, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kernel_args)); + + ck_tile::index_t num_wgs_per_tile = kernel_args.tile_partitioner.estimate_num_wgs_per_tile(); + return std::tuple{average_time, num_wgs_per_tile}; } #include "run_gemm_example.inc" diff --git a/example/ck_tile/41_batched_contraction/batched_contraction.cpp b/example/ck_tile/41_batched_contraction/batched_contraction.cpp index f9f13c6e85..1e159a5615 100644 --- a/example/ck_tile/41_batched_contraction/batched_contraction.cpp +++ b/example/ck_tile/41_batched_contraction/batched_contraction.cpp @@ -92,67 +92,59 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = GEMM_PIPELINE; - using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = + ck_tile::BatchedContractionKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = - ck_tile::BatchedContractionKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::GetBlockSize(); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::GetBlockSize(); + if(!Kernel::IsSupportedArguments(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n"); + } - if(!Kernel::IsSupportedArguments(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping contraction!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } + auto kernel = ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs); - auto kernel = ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs); - - return ck_tile::launch_kernel(s, kernel); - }; - - return Run(); + return ck_tile::launch_kernel(s, kernel); } #define HANDLE_CASE(G, M, N, K) \ diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp index 817432081b..32161a234a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -51,13 +51,13 @@ struct ConvBwdWeightWmmaFactory static_assert(InputVectorTransferLimits, "Invalid A block transfer config"); static_assert(InputVectorTransferLimits, "Invalid B block transfer config"); static_assert(OutputVectorTransferLimits, "Invalid C block transfer config"); - static_assert(AccessOrderLimits4D, + static_assert(AccessOrderLimits3D, "Invalid A thread cluster access 
order"); - static_assert(AccessOrderLimits4D, + static_assert(AccessOrderLimits3D, "Invalid B thread cluster access order"); - static_assert(AccessOrderLimits4D, + static_assert(AccessOrderLimits3D, "Invalid A source access order"); - static_assert(AccessOrderLimits4D, + static_assert(AccessOrderLimits3D, "Invalid B source access order"); // The forward convolution kernel class instance. diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 319293cff1..e235db4bb0 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -112,7 +112,7 @@ constexpr auto make_conv_instance() return typename ReferenceFactory::Instance{}; } // CK Tile supports common factory for each direction - if constexpr(TileAlgorithm) + else if constexpr(TileAlgorithm) { return typename ConvTileFactory::Instance{}; } diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp index cce95cb3f1..6ce508b47d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp @@ -116,7 +116,6 @@ struct ConvTileFactory BLOCK_GEMM.warp_tile.k, GroupedConvTraitsType::FixedGemmParams::TransposeC, // TODO:: This template parameter will be moved inside the kernel - ck_tile::memory_operation_enum::set, BLOCK_GEMM.num_wave_groups, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, SCALAR_PER_VECTOR.c>>; diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp index 1cecb8d43b..aa938aa544 100644 --- 
a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp @@ -47,6 +47,11 @@ struct DataTypeToCK { using type = ck::f8_t; }; +template <> +struct DataTypeToCK +{ + using type = uint8_t; +}; struct CK_empty_tuple { diff --git a/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp index 0246c805c2..0748725c96 100644 --- a/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp @@ -125,9 +125,9 @@ struct ReferenceFactory // Direct Run method (simpler interface, direction-agnostic) template - static void Run(InPtrType input, - WeiPtrType weight, - OutPtrType output, + static void Run(InPtrType* input, + WeiPtrType* weight, + OutPtrType* output, int G, int N, int K, @@ -142,9 +142,9 @@ struct ReferenceFactory if constexpr(ConvDirectionIsForward) { ck_tile::naive_grouped_conv_fwd( - input, - weight, - output, + static_cast(input), + static_cast(weight), + static_cast(output), G, N, K, @@ -160,9 +160,9 @@ struct ReferenceFactory { ck_tile:: naive_grouped_conv_bwd_data( - input, - weight, - output, + static_cast(input), + static_cast(weight), + static_cast(output), G, N, K, @@ -179,19 +179,20 @@ struct ReferenceFactory ck_tile::naive_grouped_conv_bwd_weight(input, - weight, - output, - G, - N, - K, - C, - input_spatial, - filter_spatial, - output_spatial, - strides, - dilations, - left_pads); + OutDataType>( + static_cast(input), + static_cast(weight), + static_cast(output), + G, + N, + K, + C, + input_spatial, + filter_spatial, + output_spatial, + strides, + dilations, + left_pads); } } diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp index 0264264372..3240033c55 
100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp @@ -7,11 +7,14 @@ #include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" #include "ck_tile/builder/testing/testing.hpp" -#include "ck_tile/builder/testing/extent.hpp" +#include "ck_tile/builder/testing/filter_extent.hpp" #include "ck_tile/builder/testing/tensor_buffer.hpp" #include "ck_tile/builder/testing/tensor_initialization.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/validation.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + /// This file implements common functionality for invoking/testing grouped /// forward convolutions created through the CK Builder API. The main item /// of it is the ConvArgs structure - which contains a complete description @@ -37,12 +40,12 @@ namespace ck_tile::builder::test { template struct ConvTensorLengths { - size_t batch_size = 1; // N - size_t groups = 1; // G - size_t input_channels = 1; // C - size_t output_channels = 1; // K - Extent image = {}; // W, H, D - Extent filter = {}; // X, Y, Z + size_t batch_size = 1; // N + size_t groups = 1; // G + size_t input_channels = 1; // C + size_t output_channels = 1; // K + FilterExtent image = {}; // W, H, D + FilterExtent filter = {}; // X, Y, Z }; /// @brief `Args` specialization for forward convolution. 
@@ -59,6 +62,14 @@ struct Args constexpr static auto WEIGHT_TYPE = SIGNATURE.data_type; constexpr static auto OUTPUT_TYPE = SIGNATURE.data_type; + constexpr static int INPUT_RANK = 3 + SPATIAL_DIM; + constexpr static int WEIGHT_RANK = 3 + SPATIAL_DIM; + constexpr static int OUTPUT_RANK = 3 + SPATIAL_DIM; + + using InputDescriptor = TensorDescriptor; + using WeightDescriptor = TensorDescriptor; + using OutputDescriptor = TensorDescriptor; + // TODO: We shouldn't need to call into an internal namespace here. using Ops = factory::internal::ElementwiseOps; @@ -72,10 +83,10 @@ struct Args // implementation (based on ConvParam in old CK / CK Tile) does not // support strides at all. - Extent filter_strides; - Extent filter_dilation; - Extent input_left_pad; - Extent input_right_pad; + FilterExtent filter_strides; + FilterExtent filter_dilation; + FilterExtent input_left_pad; + FilterExtent input_right_pad; Ops::AElementwiseOp a_elementwise_op; Ops::BElementwiseOp b_elementwise_op; @@ -84,7 +95,7 @@ struct Args /// This function returns the `TensorDescriptor` corresponding to /// the input-tensor of the convolution problem. This can then /// be used to, for example, allocate memory. - TensorDescriptor make_input_descriptor() const + InputDescriptor make_input_descriptor() const { // TODO: We're using old CK functionality to compute the right // values here, mainly because CK tile does not support the @@ -95,31 +106,37 @@ struct Args const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< typename Layouts::ALayout>(param); - return TensorDescriptor(desc.GetLengths(), desc.GetStrides()); + using Extent = typename InputDescriptor::Extent; + return InputDescriptor(Extent::from_vector(desc.GetLengths()), + Extent::from_vector(desc.GetStrides())); } /// This function returns the `TensorDescriptor` corresponding to /// the weight-tensor of the convolution problem. 
This can then /// be used to, for example, allocate memory. - TensorDescriptor make_weight_descriptor() const + WeightDescriptor make_weight_descriptor() const { // See note in implementation of `make_input_descriptor`. const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed< typename Layouts::BLayout>(param); - return TensorDescriptor(desc.GetLengths(), desc.GetStrides()); + using Extent = typename WeightDescriptor::Extent; + return WeightDescriptor(Extent::from_vector(desc.GetLengths()), + Extent::from_vector(desc.GetStrides())); } /// This function returns the `TensorDescriptor` corresponding to /// the output-tensor of the convolution problem. This can then /// be used to, for example, allocate memory. - TensorDescriptor make_output_descriptor() const + OutputDescriptor make_output_descriptor() const { // See note in implementation of `make_input_descriptor`. const auto param = to_ck_conv_param(); const auto desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed< typename Layouts::ELayout>(param); - return TensorDescriptor(desc.GetLengths(), desc.GetStrides()); + using Extent = typename OutputDescriptor::Extent; + return OutputDescriptor(Extent::from_vector(desc.GetLengths()), + Extent::from_vector(desc.GetStrides())); } /// Convert the Args structure into a CK conv_param structure. 
This @@ -244,12 +261,11 @@ UniqueInputs alloc_inputs(const Args& args) /// /// @see alloc_inputs() template - requires ValidConvSignature && ConvDirectionIsForward && - ValidUniqueInputs -void init_inputs(const Args& args, UniqueInputs& inputs) + requires ValidConvSignature && ConvDirectionIsForward +void init_inputs(const Args& args, Inputs inputs) { - init_tensor_buffer_uniform_fp(inputs.input_buf, args.make_input_descriptor(), -2.0f, 2.0f); - init_tensor_buffer_uniform_fp(inputs.weight_buf, args.make_weight_descriptor(), -2.0f, 2.0f); + init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f); + init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f); } /// @brief `alloc_outputs()` specialization for forward convolution. @@ -267,4 +283,19 @@ UniqueOutputs alloc_outputs(const Args& args) }; } +/// @brief `validate()` specialization for forward convolution. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// +/// @see validate() +template + requires ValidConvSignature && ConvDirectionIsForward +ValidationReport +validate(const Args& args, Outputs actual, Outputs expected) +{ + ValidationReport report; + report.check("output", args.make_output_descriptor(), actual.output, expected.output); + return report; +} + } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp index cc5c613d95..499e0ef3de 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp @@ -3,10 +3,10 @@ #pragma once -#include -#include - #include "ck_tile/builder/testing/conv_fwd.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include +#include /// This file contains the implementation details for invoking/testing /// grouped convolution operations in old CK. 
The main item is the @@ -15,6 +15,63 @@ namespace ck_tile::builder::test { +namespace detail { + +/// @brief Concept for checking whether this is the reference convolution +/// implementation. +/// +/// This is the same as `::ck_tile::builder::test::CkConvInstance`, except +/// with some utility aliases. For that reason, its moved to this detail +/// namespace. +template > +concept CkConvInstance = requires(Conv& conv, + // TODO: This should be changed depending on IsMultiA etc. + // Currently that is not yet supported elsewhere anyway. + const void* p_a, + const void* p_b, + void* p_e, + std::array lengths, + std::array strides, + std::array filter, + Ops::AElementwiseOp elementwise_a, + Ops::BElementwiseOp elementwise_b, + Ops::CDEElementwiseOp elementwise_cde) { + { + conv.MakeArgument(p_a, + p_b, + // TODO: Support multiple D outputs. + {}, + p_e, + // A lengths/strides + lengths, + strides, + // B lengths/strides + lengths, + strides, + // TODO: Ds lengths/strides + {}, + {}, + // E lengths/strides + lengths, + strides, + // strides/dilations/pads + filter, + filter, + filter, + filter, + // element-wise operations. + elementwise_a, + elementwise_b, + elementwise_cde) + }; +}; + +} // namespace detail + /// @brief Concept for checking whether a convolution is invoked like old CK. /// /// This concept is used to tell whether a convolution implementation is @@ -24,13 +81,8 @@ namespace ck_tile::builder::test { /// /// - SIGNATURE is the operation signature. /// - Conv is a convolution instance created by the CK Builder API. -template -concept IsCkConvInstance = - // TODO: This should be implemented by converting the signature into the - // type parameters for DeviceGroupedConvFwdMultipleABD. For now, just leave - // it empty. Improve when needed, you get the point. Also we should probably - // move this to the ck conv factory helper. 
- true; +template +concept CkConvInstance = detail::CkConvInstance; /// @brief `run()` specialization for forward convolution and old CK. /// @@ -39,10 +91,9 @@ concept IsCkConvInstance = /// operation. This should be caught and reported by the testing framework. /// /// @see run() -template - requires ValidConvSignature && ConvDirectionIsForward && - IsCkConvInstance -void run(Conv& conv, +template + requires ValidConvSignature && ConvDirectionIsForward +void run(CkConvInstance auto& conv, const Args& args, const Inputs& inputs, const Outputs& outputs) diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp new file mode 100644 index 0000000000..85493e32eb --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp @@ -0,0 +1,114 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/conv_fwd.hpp" +#include +#include + +/// This file contains the implementation details for invoking/testing +/// grouped convolution operations using the reference implementation. +/// The main item is the `run()` function, which is the primary way to +/// invoke the reference execution mechanism. +/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`, +/// but its made specific to the reference implementation, which is +/// invoked in a slightly different way. + +namespace ck_tile::builder::test { + +/// @brief Concept for checking whether this is the reference convolution +/// implementation. +/// +/// This concept is used to tell whether a convolution implementation is +/// likely to be the reference implementation - that is, whether we should +/// invoke it like the reference kernel. This is mainly used with `run()` to +/// differentiate which implementation that should be invoked. 
+/// +/// - SIGNATURE is the operation signature. +/// - Conv is a convolution instance created by the CK Builder API. +template +concept RefConvInstance = requires(Conv& conv, + const void* input, + const void* weight, + void* output, + int G, + int N, + int K, + int C, + std::vector dims) { + { + conv.Run(input, + weight, + output, + G, + N, + K, + C, + dims, // input_spatial + dims, // filter_spatial + dims, // output_spatial + dims, // strides + dims, // dilations + dims // left_pads + ) + }; +}; + +/// @brief `run()` specialization for forward convolution and the reference +/// implementation. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// @throws std::runtime_error if the arguments weren't actually valid for the +/// operation. This should be caught and reported by the testing framework. +/// +/// @see run() +template + requires ValidConvSignature && + // TODO: Maybe we can unify this implementation for bwd/weight too? + // for now, just concern outselves with reference and see when the + // rest of the bwd/weight plumbing is there. + ConvDirectionIsForward +void run(RefConvInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + // We don't want to compute the output dims manually, just get + // them via the existing infrastructure + const auto param = args.to_ck_conv_param(); + + // TODO: The reference convolution is currently missing a few features. + // Just throw for now, but regard these as TODO items that should be resolved + // eventually. + + // Right pads are not supported right now for some reason. 
+ for(auto right_pad : param.input_right_pads_) + { + if(right_pad != 0) + throw std::runtime_error("TODO: Support right pad in reference conv"); + } + + if(!args.make_input_descriptor().is_packed()) + throw std::runtime_error("TODO: Support non-packed input tensor in reference conv"); + if(!args.make_weight_descriptor().is_packed()) + throw std::runtime_error("TODO: Support non-packed weight tensor in reference conv"); + if(!args.make_output_descriptor().is_packed()) + throw std::runtime_error("TODO: Support non-packed output tensor in reference conv"); + + conv.Run(inputs.input, + inputs.weight, + outputs.output, + param.G_, + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, + param.output_spatial_lengths_, + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/error.hpp b/experimental/builder/include/ck_tile/builder/testing/error.hpp new file mode 100644 index 0000000000..242f2a8e51 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/error.hpp @@ -0,0 +1,150 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include + +/// This file defines some utilities for dealing with HIP errors. In the CK-Builder +/// testing code, we'd like to just turn them into exceptions: This cleans up testing +/// code as we don't need to think about returning error codes, but its still much +/// cleaner than just creating a hard crash and thereby possibly interrupting other +/// units in the same test. The testing framework can catch these exceptions where +/// necessary. 
+/// +/// While the exceptions defined in this file are in principle suitable for general +/// usage, HIP functions which return HIP error codes (`hipError_t`) should be +/// checked using the `check_hip` function. + +namespace ck_tile::builder::test { + +/// @brief Generic HIP exception. +/// +/// This is a derivation of `std::runtime_error` which represents a HIP error code. +/// +/// @see std::runtime_error +/// @see hipError_t +struct HipError : std::runtime_error +{ + /// @brief Utility for formatting HIP error messages + /// + /// Returns a human-readable description of a HIP error. Given a description of the + /// activity that the user tried to perform, this function appends the HIP-specific + /// information such as the stringified version of the error code, and the error + /// code itself (for reference). + /// + /// @param user_msg User-given message about the activity at time of error. + /// @param code The status to report. + /// @param src The location where this error was discovered. + static std::string + format_error(std::string_view user_msg, hipError_t code, std::source_location src) + { + std::stringstream msg; + msg << user_msg << ": " << hipGetErrorString(code) << " (" << code << ")"; + if(src.function_name()) + msg << " in function '" << src.function_name(); + msg << "' at " << src.file_name() << ":" << src.line() << ":" << src.column(); + return msg.str(); + } + + /// @brief Construct a generic HIP error. + /// + /// @param msg User-given message about the activity at time of error. + /// @param code The status to report. + /// @param src The location where this error was discovered. Defaults to the caller's + /// location. + HipError(std::string_view msg, + hipError_t code, + std::source_location src = std::source_location::current()) + : std::runtime_error(format_error(msg, code, src)), code_(code) + { + } + + /// @brief Retrieve the inner error code. 
+ /// + /// This function returns the status code that was encountered while checking an + /// operation for errors. + hipError_t code() const { return code_; } + + private: + hipError_t code_; +}; + +/// @brief HIP out of memory error. +/// +/// This a derivation of `HipError` which is specialized for Out-of-memory errors. This +/// makes it easier to attach additional context, and to match on these errors while +/// using `catch` blocks. +/// +/// @see HipError +struct OutOfDeviceMemoryError : HipError +{ + /// @brief Construct an out-of-device-memory error. + /// + /// @param msg User-given message about the activity at time of error. + /// @param src The location where this error was discovered. Defaults to the caller's + /// location. + OutOfDeviceMemoryError(std::string_view msg = "failed to allocate device memory", + std::source_location src = std::source_location::current()) + : HipError(msg, hipErrorOutOfMemory, src) + { + } +}; + +/// @brief Check HIP status for errors. +/// +/// This function checks a HIP status code (obtained from a HIP function call) for any +/// errors. If the status `code` is not `hipSuccess`, this function throws an instance of +/// `HipError`. The exact type thats thrown depends on the status. If `code` represents +/// an out-of-memory error `hipErrorOutOfMemory`, then `OutOfDeviceMemoryError` will be +/// thrown instead. +/// +/// @param msg User-given message about the activity at possible time of error. +/// @param code The HIP status code to examine. +/// @param src The location where this status was set. Defaults to the caller's location. +/// +/// @throws HipError if `code` is not `hipSuccess`. +/// +/// @see HipError +/// @see OutOfDeviceMemoryError +inline void check_hip(std::string_view msg, + hipError_t code, + std::source_location src = std::source_location::current()) +{ + // -Wswitch-enum throws a warning if this code is changed into a switch, even with + // the `default` label... 
+ + if(code == hipSuccess) + // When you beat the error allegations + return; + else if(code == hipErrorOutOfMemory) + throw OutOfDeviceMemoryError(msg, src); + else + throw HipError(msg, code, src); +} + +/// @brief Check HIP status for errors. +/// +/// This function is similar to `check_hip(std::string_view, hipError_t)`, except that a +/// default message is given. +/// +/// @param code The HIP status code to examine. +/// @param src The location where this status was set. Defaults to the caller's location. +/// +/// @throws HipError if `code` is not `hipSuccess`. +/// +/// @see HipError +/// @see OutOfDeviceMemoryError +/// @see check_hip(std::string_view, hipError_t) +inline void check_hip(hipError_t code, std::source_location src = std::source_location::current()) +{ + check_hip(code == hipErrorOutOfMemory ? "failed to allocate device memory" + : "HIP runtime error", + code, + src); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/extent.hpp b/experimental/builder/include/ck_tile/builder/testing/filter_extent.hpp similarity index 50% rename from experimental/builder/include/ck_tile/builder/testing/extent.hpp rename to experimental/builder/include/ck_tile/builder/testing/filter_extent.hpp index a2d9b3ff4c..3587ac406f 100644 --- a/experimental/builder/include/ck_tile/builder/testing/extent.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/filter_extent.hpp @@ -5,28 +5,29 @@ namespace ck_tile::builder::test { -/// This structure describes a 1-, 2-, or 3-D extent. Its used to -/// communicate 1-, 2- or 3-D sizes and strides of tensors. -/// Depending on the dimension, the structure will have the `width`, -/// `height`, and `depth` fields available. +/// This structure describes a 1-, 2-, or 3-D extent for convolution +/// filters. Its used to communicate 1-, 2- or 3-D sizes and strides +/// of tensors, specifically for convolution filters. 
Depending on the +/// dimension, the structure will have the `width`, `height`, and +/// `depth` fields available. template -struct Extent; +struct FilterExtent; template <> -struct Extent<1> +struct FilterExtent<1> { size_t width = 1; }; template <> -struct Extent<2> +struct FilterExtent<2> { size_t width = 1; size_t height = 1; }; template <> -struct Extent<3> +struct FilterExtent<3> { size_t width = 1; size_t height = 1; diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp index 42f85f8017..6043ba2103 100644 --- a/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_buffer.hpp @@ -3,19 +3,15 @@ #pragma once +#include "ck_tile/builder/testing/error.hpp" +#include #include #include -#include -#include -#include -#include -#include "ck_tile/builder/conv_signature_concepts.hpp" -#include "ck_tile/builder/testing/type_traits.hpp" -#include "ck_tile/host/host_tensor.hpp" +#include -/// This file deals with tensor memory allocation: Both the act of allocating -/// and (automatically) deallocating memory, as well as utilities for managing -/// the layout of tensor data in memory. +/// This file deals with tensor memory management and allocation. The main +/// item is the `DeviceBuffer`: An owned piece of device memory, which is +/// automatically freed when it goes out of scope. namespace ck_tile::builder::test { @@ -39,31 +35,6 @@ struct DeviceMemoryDeleter } }; -/// @brief HIP out of memory error -/// -/// This is a derivation of `std::runtime_error` specialized for HIP -/// out-of-memory errors. -/// -/// @see std::runtime_error -struct OutOfDeviceMemoryError : std::runtime_error -{ - /// @brief Utility for formatting out-of-memory error messages - /// - /// Returns a human-readable description of a HIP out-of-memory error. 
- /// - /// @param status The status to report - static std::string format_error(hipError_t status) - { - return std::string("failed to allocate hip memory: ") + hipGetErrorString(status) + " (" + - std::to_string(status) + ")"; - } - - /// @brief Construct an out-of-memory error using `status` as message. - /// - /// @param status A HIP error status that was encountered while allocating memory. - OutOfDeviceMemoryError(hipError_t status) : std::runtime_error(format_error(status)) {} -}; - /// @brief Automatically managed GPU memory. /// /// The `DeviceBuffer` is an automatically managed pointer for GPU memory. When @@ -96,117 +67,18 @@ inline DeviceBuffer alloc_buffer(size_t size) std::byte* d_buf = nullptr; if(const auto status = hipMalloc(&d_buf, size); status != hipSuccess) { - throw OutOfDeviceMemoryError(status); + // Add some additional context + + size_t free, total; + check_hip("failed to get HIP memory info", hipMemGetInfo(&free, &total)); + + std::stringstream ss; + ss << "failed to allocate device memory (tried to allocate " << size << " bytes with only " + << free << " available)"; + + throw OutOfDeviceMemoryError(ss.str()); } return DeviceBuffer(d_buf); } -/// @brief Type managing tensor data layout in memory. -/// -/// This structure describes a tensor in memory. It does not actually hold any -/// reference to memory, it just describes how the memory should be laid out if it -/// were. -/// -/// @note This type is very much like ck_tile::HostTensorDescriptor, except that it -/// also includes the data type of the elements of htis tensor. This is mainly to -/// make the descriptor a _complete_ description of a tensor rather than just the -/// dimensions in strides, which helps in reducing clutter in uses of this type. -/// -/// @note All strides are still in _elements_. -/// -/// @tparam DT The conceptual data type of the tensor elements. This need not be the -/// type that the data is actually stored as in memory. 
-template -struct TensorDescriptor -{ - // For now, the implementation of this type is based on - // `ck_tile::HostTensorDescriptor`, so that we can prototype without - // reimplementing the `HostTensorDescriptor` for the 3rd time. You can regard - // the use of `ck_tile::HostTensorDescriptor` here as an implementation detail. - - /// The conceptual data type of the tensor elements. This need not be the type - /// that the data is actually stored as in memory. - constexpr static DataType data_type = DT; - - /// @brief Create a tensor descriptor from lengths and strides. - /// - /// @param lengths A sequence of tensor lengths, the conceptial dimensions of - /// the tensor in elements. - /// @param strides A sequence of in-memory strides of the tensor, measured in - /// elements. Each element of `strides`` corresponds to one at the same index - /// in `lengths`, the amount of elements to skip in memory to find the next - /// element along that axis. - TensorDescriptor(std::span lengths, std::span strides) - : inner_descriptor_(lengths, strides) - { - // TODO: Validation of strides? For now we just delegate the details of the - // construction to the CK Tile HostTensorDescriptor. - } - - /// Query the conceptual dimensions of the tensor. - /// - /// @returns A span of tensor dimensions, one for every axis. Note that the order - /// does *not* correspond with memory layout, query the in-memory strides for - /// that. - /// - /// @see get_strides() - std::span get_lengths() const { return inner_descriptor_.get_lengths(); } - - /// Query the in-memory strides of the tensor. - /// - /// @returns A span of tensor dimensions, one for every axis. Each element - /// corresponds directly with the stride in elements at the same index in the - /// tensor dimensions. - /// - /// @see get_lengths() - std::span get_strides() const { return inner_descriptor_.get_strides(); } - - /// @brief Compute total tensor size in elements. 
- /// - /// This function returns the total size of the memory backing a tensor with - /// this descriptor in *elements*, including required extra size for strides. - /// - /// @see get_element_space_size_in_bytes() - size_t get_element_space_size() const { return inner_descriptor_.get_element_space_size(); } - - /// @brief Compute total tensor size in bytes. - /// - /// This function is like `get_element_space_size()`, except that the returned - /// value is measured in *bytes* rather than *elements*. Use this function for - /// figuring out how much memory needs to be allocated for a particular tensor. - /// - /// @see get_element_space_size() - size_t get_element_space_size_in_bytes() const - { - // For now, the backing type is the naive C++-type that represents the data - // type. When we are going to support packed types such as i4 and fp6, this - // is going to become more complicated. - return get_element_space_size() * data_type_sizeof(DT); - } - - private: - ck_tile::HostTensorDescriptor inner_descriptor_; -}; - -/// @brief Allocate automatically managed GPU memory corresponding to a tensor descriptor. -/// -/// This function is similar to `alloc_buffer()`, except that the required size is -/// derived automatically from a tensor descriptor. The returned buffer is valid for -/// tensors with that layout. Strides are also taken into account when computing the -/// required size. -/// -/// @tparam DT The conceptual datatype of the elements of the tensor. -/// @param descriptor A descriptor of the memory layout of the tensor to allocate. -/// @throws OutOfDeviceMemoryError if memory allocation failed. -/// -/// @see TensorDescriptor -/// @see DeviceBuffer -/// @see OutOfDeviceMemoryError -/// @see hipMalloc() -template -DeviceBuffer alloc_tensor_buffer(const TensorDescriptor
& descriptor) -{ - return alloc_buffer(descriptor.get_element_space_size_in_bytes()); -} - } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp new file mode 100644 index 0000000000..15fe4d89db --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_descriptor.hpp @@ -0,0 +1,474 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/testing/type_traits.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/host/host_tensor.hpp" + +/// This file deals with tensor memory layout. The `TensorDescriptor` is the +/// main item, which is a type that describes (but not manages!) the layout +/// of tensor memory. There are also some related utilities. + +namespace ck_tile::builder::test { + +/// @brief Tensor dimensions type +/// +/// An Extent describes size in tensor space, usually either the tensor lengths +/// (conceptual size) or the tensor strides (memory layout). This type is mainly +/// used by the `TensorDescriptor`. This type is based on `std::array` +/// and supports all relevant operations on that. +/// +/// @note In practical terms, this type is not just an alias of `std::array` for +/// two reasons: First, writing a separate type allows us to write a custom +/// CTAD deduction guideline. This allows users to write `Extent{1, 2, 3}` and +/// get an instance of the correct type, whereas `std::array{1, 2, 3}` yields an +/// instance of `std::array`. This, in turn, allows inferring the rank +/// from the instance (useful in combination with `make_descriptor`), as it alows +/// us to write `function(Extent{1, 2, 3})`. 
Note that `function({1, 2, 3})` is +/// not valid before C++26 because `{1, 2, 3}` is an initializer list (even if +/// `function` accepts an instance of `Extent`), which does not have a known size +/// at compile time. Second, creating a separate struct for the `Extent` allows +/// additional (static) member functions. +/// +/// @tparam RANK The rank (number of spatial dimensions) of the tensor that this +/// extent describes a size of. +/// +/// @see TensorDescriptor +/// @see make_descriptor +template +struct Extent : std::array +{ + using Base = std::array; + // Note: Default constructor inherited from std::array. + + /// @brief Construct an extent from an `std::vector`. + /// + /// This function can be used to turn an `std::vector` into an `Extent`. + /// Because this code is mainly intended for testing, the vector's size is + /// checked. If it's not equal to `RANK`, an exception is thrown. + /// + /// @throws std::runtime_error if the size of `extent` is not equal to `RANK`. + static Extent from_vector(const std::vector& extent) + { + if(extent.size() != RANK) + { + std::stringstream msg; + msg << "invalid rank! expected: " << RANK << ", got: " << extent.size(); + throw std::runtime_error(msg.str()); + } + + Extent result; + std::copy_n(extent.begin(), RANK, result.begin()); + return result; + } + + // Note: std::array doesn't like generating indexing code when the RANK + // is zero. Looks like there is a missing __device__ overload in ROCm 7.1 + // at least. It's not terribly important, but just override the default + // operator[] to fix it. + + /// @brief Array indexing operator + /// + /// `std::array` has issues with this operator when RANK=0, this version + /// fixes that. + /// + /// @param i The index to index the array with.
+ /// + /// @see std::array::operator[] + __device__ __host__ size_t operator[](size_t i) const + { + if constexpr(RANK > 0) + { + return Base::operator[](i); + } + else + { + __builtin_unreachable(); + } + } + + /// @brief Array indexing operator + /// + /// `std::array` has issues with this operator when RANK=0, this version + /// fixes that. + /// + /// @param i The index to index the array with. + /// + /// @see std::array::operator[] + __device__ __host__ size_t& operator[](size_t i) + { + if constexpr(RANK > 0) + { + return Base::operator[](i); + } + else + { + __builtin_unreachable(); + } + } +}; + +// This is a deduction guideline necessary to resolve `Extent{1, 2, 3}` to the +// correct type. This definition is practically the same as that of `std::array`. +template +Extent(T...) -> Extent; + +/// @brief Concept for automatically deriving tensor memory layout. +/// +/// A `TensorStridesGenerator` is a type which can be used to automatically +/// derive the strides (memory layout) of a tensor, given the tensor lengths. +/// This is mainly used to avoid manually computing strides. +/// +/// Implementors of this concept are required to implement `operator()`, +/// which accepts an instance of `Extent` (the tensor lengths) and +/// yields another instance of `Extent` (the tensor strides). Note +/// that the returned strides are expected to be "pre-scanned", meaning +/// that the offset in memory of a tensor can be computed as +/// `dot(index * strides)` (where `*` is element-wise multiplication). +/// +/// @see TensorDescriptor +/// @see PackedRightLayout +/// @see PackedLeftLayout +template +concept TensorStridesGenerator = requires(const G& generator, const Extent& lengths) { + { generator(lengths) } -> std::convertible_to>; +}; + +/// @brief Layout generator where right-most dimension has stride 1 and +/// all dimensions are packed. 
+/// +/// This structure implements a `TensorStridesGenerator` which generates +/// a memory layout which has the right-most dimension equal to 1, and +/// all other strides increase right-to-left as a products of the extent. +/// This corresponds with a row-major layout. +/// +/// @see TensorStridesGenerator +/// @see TensorDescriptor +struct PackedRightLayout +{ + /// @brief Stride generation implementation. + /// + /// This is the main function which implements the stride generation + /// + /// @tparam RANK The rank of the tensor. + /// + /// @param lengths The lengths of the tensor. + /// + /// @returns The tensor's memory layout according to the definition + /// of `PackedRightLayout`. + /// + /// @see TensorStridesGenerator + template + Extent operator()(const Extent& lengths) const + { + Extent strides = {}; + size_t numel = 1; + + for(size_t i = RANK; i > 0; --i) + { + strides[i - 1] = numel; + numel *= lengths[i - 1]; + } + + return strides; + } +}; +static_assert(TensorStridesGenerator, + "PackedRightLayout should be a TensorStridesGenerator!"); + +/// @brief Layout generator where left-most dimension has stride 1 and +/// all dimensions are packed. +/// +/// This structure implements a `TensorStridesGenerator` which generates +/// a memory layout which has the left-most dimension equal to 1, and +/// all other strides increase left-to-right as a products of the extent. +/// This corresponds with a column-major layout. +/// +/// @see TensorStridesGenerator +/// @see TensorDescriptor +struct PackedLeftLayout +{ + /// @brief Stride generation implementation. + /// + /// This is the main function which implements the stride generation + /// + /// @tparam RANK The rank of the tensor. + /// + /// @param lengths The lengths of the tensor. + /// + /// @returns The tensor's memory layout according to the definition + /// of `PackedLeftLayout`. 
+ /// + /// @see TensorStridesGenerator + template + Extent operator()(const Extent& lengths) const + { + Extent strides = {}; + size_t numel = 1; + + for(size_t i = 0; i < RANK; ++i) + { + strides[i] = numel; + numel *= lengths[i]; + } + + return strides; + } +}; +static_assert(TensorStridesGenerator, + "PackedLeftLayout should be a TensorStridesGenerator!"); + +/// @brief Type managing tensor data layout in memory. +/// +/// This structure describes a tensor in memory. It does not actually hold any +/// reference to memory, it just describes how the memory should be laid out if it +/// were. +/// +/// @note This type is very much like ck_tile::HostTensorDescriptor, except that it +/// also includes the data type of the elements of this tensor. This is mainly to +/// make the descriptor a _complete_ description of a tensor rather than just the +/// dimensions and strides, which helps in reducing clutter in uses of this type. +/// +/// @note All strides are still in _elements_. +/// +/// @tparam DT The conceptual data type of the tensor elements. This need not be the +/// type that the data is actually stored as in memory. +/// @tparam RANK The tensor "rank": the number of conceptual spatial dimensions that +/// the tensor covers. +template +struct TensorDescriptor +{ + // For now, the implementation of this type is based on + // `ck_tile::HostTensorDescriptor`, so that we can prototype without + // reimplementing the `HostTensorDescriptor` for the 3rd time. You can regard + // the use of `ck_tile::HostTensorDescriptor` here as an implementation detail. + + /// @brief Tensor extent alias + /// + /// This alias represents a std::array which holds tensor dimensions. There is one + /// item for each dimension in the tensor, and each item corresponds with the + /// value for that dimension. + using Extent = ::ck_tile::builder::test::Extent; + + /// The conceptual data type of the tensor elements.
This need not be the type + /// that the data is actually stored as in memory. + constexpr static DataType data_type = DT; + + /// The tensor "rank": the number of conceptual spatial dimensions that the + /// tensor covers. + constexpr static size_t rank = RANK; + + /// @brief Create a tensor descriptor from lengths and strides. + /// + /// @param lengths A sequence of tensor lengths, the conceptual dimensions of + /// the tensor in elements. + /// @param strides A sequence of in-memory strides of the tensor, measured in + /// elements. Each element of `strides` corresponds to one at the same index + /// in `lengths`, the amount of elements to skip in memory to find the next + /// element along that axis. + TensorDescriptor(const Extent& lengths, const Extent& strides) + : inner_descriptor_(lengths, strides) + { + // TODO: Validation of strides? For now we just delegate the details of the + // construction to the CK Tile HostTensorDescriptor. + } + + /// @brief Create a tensor descriptor with lengths and automatic layout. + /// + /// This function initializes a tensor descriptor using lengths, and by deriving + /// the memory layout from the layout generator `Generator`. The tensor will be + /// initialized with the strides yielded from `Generator`. + /// + /// @tparam Generator The generator type to generate the strides with. For example, + /// `PackedRightLayout` or `PackedLeftLayout`. + /// + /// @param lengths A sequence of tensor lengths, the conceptual dimensions of + /// the tensor in elements. + /// @param gen An instance of `Generator` to generate the strides with. + /// + /// @see TensorStridesGenerator + /// @see PackedLeftLayout + /// @see PackedRightLayout + template + requires TensorStridesGenerator + TensorDescriptor(const Extent& lengths, const Generator& gen) + : TensorDescriptor(lengths, gen(lengths)) + { + } + + /// Query the conceptual dimensions of the tensor. + /// + /// @returns A span of tensor dimensions, one for every axis.
Note that the order + /// does *not* correspond with memory layout, query the in-memory strides for that. + /// + /// @see get_strides() + Extent get_lengths() const + { + // TODO: This is ugly for now. We should ditch the HostTensorDescriptor, and + // after that this can just be `return lengths_;` (and make it const Extent&). + Extent result; + std::copy_n(inner_descriptor_.get_lengths().begin(), RANK, result.begin()); + return result; + } + + /// Query the in-memory strides of the tensor. + /// + /// @returns A span of tensor dimensions, one for every axis. Each element + /// corresponds directly with the stride in elements at the same index in the + /// tensor dimensions. + /// + /// @see get_lengths() + Extent get_strides() const + { + // TODO: This is ugly for now. We should ditch the HostTensorDescriptor, and + // after that this can just be `return strides_;` (and make it const Extent&). + Extent result; + std::copy_n(inner_descriptor_.get_strides().begin(), RANK, result.begin()); + return result; + } + + /// @brief Compute conceptual tensor size in elements. + /// + /// This function returns the size of the tensor in elements. This function only + /// takes the lengths into account, not the strides. In order to allocate memory + /// for the tensor, use `get_element_space_size()`. + /// + /// @see get_lengths + /// @see get_element_space_size + size_t get_element_size() const { return inner_descriptor_.get_element_size(); } + + /// @brief Compute total tensor space size in elements. + /// + /// This function returns the total size of the memory backing a tensor with + /// this descriptor in *elements*, including required extra size for strides. + /// + /// @see get_element_space_size_in_bytes() + size_t get_element_space_size() const { return inner_descriptor_.get_element_space_size(); } + + /// @brief Compute total tensor size in bytes. 
+ /// + /// This function is like `get_element_space_size()`, except that the returned + /// value is measured in *bytes* rather than *elements*. Use this function for + /// figuring out how much memory needs to be allocated for a particular tensor. + /// + /// @see get_element_space_size() + size_t get_element_space_size_in_bytes() const + { + // For now, the backing type is the naive C++-type that represents the data + // type. When we are going to support packed types such as i4 and fp6, this + // is going to become more complicated. + return get_element_space_size() * data_type_sizeof(DT); + } + + /// @brief Check if a tensor is packed in memory. + /// + /// This function checks whether the tensor memory is "packed", that is, whether + /// all elements are contiguous in memory with no gaps. + bool is_packed() const + { + // First sort by stride, then check if they match the scan of the + // sizes. + const auto& lengths = inner_descriptor_.get_lengths(); + const auto& strides = inner_descriptor_.get_strides(); + + std::array indices; + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](auto i, auto j) { + return strides[i] < strides[j]; + }); + + size_t x = 1; + for(size_t i = 0; i < RANK; ++i) + { + if(strides[indices[i]] != x) + return false; + + x *= lengths[indices[i]]; + } + + return true; + } + + /// @brief Get a tensor descriptor for the space backing a tensor. + /// + /// This function returns a tensor descriptor which represents the buffer space + /// required to back a tensor with this descriptor. This is mainly useful to process + /// buffers with functions which normally operate over tensor descriptors. The + /// resulting tensor descriptor describes a 1D tensor with the same number of + /// elements as in the space.
+ /// + /// @see get_element_space_size() + TensorDescriptor get_space_descriptor() const + { + ck_tile::builder::test::Extent<1> lengths = {this->get_element_space_size()}; + ck_tile::builder::test::Extent<1> strides = {1}; + return TensorDescriptor(lengths, strides); + } + + private: + ck_tile::HostTensorDescriptor inner_descriptor_; +}; + +/// @brief Tensor descriptor construction helper. +/// +/// This function can be used to create a tensor descriptor. It accepts the same +/// parameters as the constructor of `TensorDescriptor`, that is, a sequence of +/// lengths and a sequence of strides (or a generator to generate the strides). +/// The main use of this function is that it allows automatic inference of the `RANK` +/// parameter. C++ constructors do not allow partial specification of type parameters, +/// and so it's impossible to write `TensorDescriptor
x(Extent{1, 2, 3}, ...)` +/// and have the `RANK` be automatically inferred. Functions do allow this though, +/// so this function can be used to write `make_descriptor(Extent{1, 2, 3}, ...)` +/// +/// @tparam DT The conceptual data type of the tensor elements. This need not be the +/// type that the data is actually stored as in memory. +/// @tparam RANK The tensor "rank": the number of conceptual spatial dimensions that +/// the tensor covers. +/// +/// @param lengths A sequence of tensor lengths, the conceptual dimensions of +/// the tensor in elements. +/// @param strides A sequence of in-memory strides of the tensor, or a generator +/// to generate those strides from the tensor lengths. +/// +/// @see TensorDescriptor +template +TensorDescriptor make_descriptor(const Extent& lengths, const auto& strides) +{ + return TensorDescriptor(lengths, strides); +} + +/// @brief Allocate automatically managed GPU memory corresponding to a tensor descriptor. +/// +/// This function is similar to `alloc_buffer()`, except that the required size is +/// derived automatically from a tensor descriptor. The returned buffer is valid for +/// tensors with that layout. Strides are also taken into account when computing the +/// required size. +/// +/// @tparam DT The conceptual datatype of the elements of the tensor. +/// @tparam RANK The conceptual rank (number of dimensions) of the tensor. +/// +/// @param descriptor A descriptor of the memory layout of the tensor to allocate. +/// +/// @throws OutOfDeviceMemoryError if memory allocation failed.
+/// +/// @see TensorDescriptor +/// @see DeviceBuffer +/// @see OutOfDeviceMemoryError +/// @see hipMalloc() +template +DeviceBuffer alloc_tensor_buffer(const TensorDescriptor& descriptor) +{ + return alloc_buffer(descriptor.get_element_space_size_in_bytes()); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_foreach.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_foreach.hpp new file mode 100644 index 0000000000..f078a1ac82 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_foreach.hpp @@ -0,0 +1,258 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include +#include +#include + +/// This file implements a generic GPU tensor "foreach" function. This +/// functionality turned out useful in separate parts of the testing +/// system, hence its implemented in a separate file. This version is +/// not particularly efficient (but it should at least be readable), +/// but it should be easy to replace the implementation in the future, +/// should that be needed. + +namespace ck_tile::builder::test { + +/// @brief Concept for constraining tensor iteration functors. +/// +/// This concept checks that a functor has the correct signature for +/// use with the `tensor_foreach` function. +template +concept ForeachFunctor = requires(const F& f, const Extent& index) { + { f(index) } -> std::same_as; +}; + +namespace detail { + +/// @brief Default foreach kernel block size +/// +/// This value is the default number of threads in each block when +/// executing the foreach kernel. This value is mostly arbitrary, +/// 256 is usually a good default for AMD GPUs. 
+/// +/// @see tensor_foreach +constexpr int DEVICE_FOREACH_BLOCK_SIZE = 256; + +/// @brief Tensor iteration kernel +/// +/// This kernel implements the actual iteration logic, and is intended +/// to be used solely by `tensor_foreach` to iterate & invoke the +/// actual callback. +/// +/// @tparam BLOCK_SIZE The number of threads in each block on the GPU. +/// @tparam RANK The rank (number of spatial dimensions) of the tensor to +/// iterate. +/// @tparam F The type of the callback to invoke. This function must be +/// compatible with execution as a __device__ function. +/// +/// @param numel The total number of elements in the tensor. +/// @param shape_scan A right-exclusive scan of the shape of the tensor. +/// @param f The callback to invoke for each index of the tensor. This +/// functor must be eligible for running on the GPU. +template + requires ForeachFunctor +__global__ __launch_bounds__(BLOCK_SIZE) // + void foreach_kernel(const size_t numel, Extent shape_scan, F f) +{ + const auto gid = blockIdx.x * BLOCK_SIZE + threadIdx.x; + for(size_t flat_idx = gid; flat_idx < numel; flat_idx += gridDim.x * BLOCK_SIZE) + { + // Compute the current index. + Extent index = {}; + + size_t idx = flat_idx; + for(size_t i = 0; i < RANK; ++i) + { + const auto scanned_dim = shape_scan[i]; + index[i] = idx / scanned_dim; + idx %= scanned_dim; + } + + // Then invoke the callback with the index. + f(index); + } +} + +/// @brief A utility to get a C++ type for a CKB type +/// +/// Right now this is just an alias of an internal CKB helper, +/// but this should probably be moved elsewhere. +template +using cpp_type_t = typename builder::factory::internal::DataTypeToCK
::type; + +} // namespace detail + +/// @brief Calculate tensor memory offset given index and strides. +/// +/// This function returns the offset in memory in a tensor, given a particular +/// multi-dimensional index and a particular set of strides. Each value in the +/// index corresponds one-to-one with a value in the strides, which are the +/// index and stride at that dimension in the tensor. These strides must be +/// pre-scanned, meaning that each index is the absolute stride of elements +/// along that axis. In essence, this means that you should pass the output of +/// `TensorDescriptor::get_strides()` into this function. +/// +/// @pre The index must be inside the tensor space. +/// +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param index A multi-dimensional index inside the tensor space. +/// @param strides A set of strides, one for each dimension. +/// +/// @see TensorDescriptor +template +__host__ __device__ size_t calculate_offset(const Extent& index, const Extent& strides) +{ + size_t offset = 0; +#pragma unroll + for(size_t i = 0; i < RANK; ++i) + { + offset += index[i] * strides[i]; + } + return offset; +} + +/// @brief Invoke a callback on the GPU for every index in a tensor. +/// +/// This function invokes a callback functor on the GPU, for each index in +/// a tensor. This function _only_ takes care of iterating over all indices +/// in a tensor of a particular shape; this function does not handle or know +/// about actual tensor data. +/// +/// @note This function is currently implemented relatively naively: The +/// iteration order is always row-wise, implemented as a persistent kernel. +/// The main objective of this function is to be used with the CK-Builder +/// testing system, and so readability and correctness should be preferred +/// over performance. If this is ever a source of performance problems, +/// feel free to replace the implementation with something better. 
+/// +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param shape The shape of the tensor to iterate over. +/// @param f The callback to invoke for each index of the tensor. This +/// functor must be eligible for running on the GPU. +/// +/// @see ForeachFunctor +/// @see detail::foreach_kernel +template +void tensor_foreach(const Extent& shape, ForeachFunctor auto f) +{ + constexpr int block_size = detail::DEVICE_FOREACH_BLOCK_SIZE; + const auto kernel = detail::foreach_kernel; + + int occupancy; + check_hip(hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, block_size, 0)); + + int device; + check_hip(hipGetDevice(&device)); + + int multiprocessors; + check_hip( + hipDeviceGetAttribute(&multiprocessors, hipDeviceAttributeMultiprocessorCount, device)); + + // Pre-scan the shape to help indexing in the kernel. + // Note: the order is not that important, so long as the iteration + // order in the kernel is from large-to-small. Right layout is the + // easiest solution for that. + + Extent shape_scan; + size_t numel = 1; + for(int i = RANK; i > 0; --i) + { + shape_scan[i - 1] = numel; + numel *= shape[i - 1]; + } + + // Reset any errors from previous launches. + (void)hipGetLastError(); + + kernel<<>>(numel, shape_scan, f); + check_hip(hipGetLastError()); +} + +/// @brief Concept for tensor initializing functors. +/// +/// This concept checks that a functor has the correct signature for +/// use with the `fill_tensor` function. +template +concept FillTensorFunctor = requires(const F& f, const Extent& index) { + { f(index) } -> std::convertible_to>; +}; + +/// @brief Utility for initializing tensors. +/// +/// This function is a utility helper for initializing tensors. It accepts a +/// tensor descriptor, buffer, and a callback. The callback is invoked for every +/// coordinate (which is passed to the callback), and the tensor is initialized +/// with resulting value. 
+/// +/// @tparam DT The tensor element datatype +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param desc The descriptor of the tensor to initialize. +/// @param buffer The memory of the tensor to initialize. +/// @param f A functor used to get the value at a particular coordinate. +/// +/// @see FillTensorFunctor +template +void fill_tensor(const TensorDescriptor& desc, + void* buffer, + FillTensorFunctor auto f) +{ + const auto strides = desc.get_strides(); + tensor_foreach(desc.get_lengths(), [buffer, f, strides](const auto& index) { + using T = detail::cpp_type_t
; + auto* ptr = static_cast(buffer); + const auto offset = calculate_offset(index, strides); + + ptr[offset] = f(index); + }); +} + +/// @brief Concept for tensor buffer initializing functors. +/// +/// This concept checks that a functor has the correct signature for +/// use with the `fill_tensor_buffer` function. +template +concept FillTensorBufferFunctor = requires(const F& f, size_t index) { + { f(index) } -> std::convertible_to>; +}; + +/// @brief Utility for initializing tensor buffers. +/// +/// This function is a utility for initializing memory backing a tensor buffer. In +/// contrast to `fill_tensor`, this function first extracts the backing space of +/// the tensor, and then invokes the callback for each (flat) index. This function +/// is particularly useful for initializing out-of-bounds indices with a +/// known value. +/// +/// @tparam DT The tensor element datatype +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param desc The descriptor of the tensor to initialize. +/// @param buffer The memory of the tensor to initialize. +/// @param f A functor used to get the value at a particular index. +/// +/// @see FillTensorBufferFunctor +template +void fill_tensor_buffer(const TensorDescriptor& desc, + void* buffer, + FillTensorBufferFunctor
auto f) +{ + fill_tensor(desc.get_space_descriptor(), buffer, [f](auto index) { return f(index[0]); }); +} + +template +void clear_tensor_buffer(const TensorDescriptor& desc, + void* buffer, + detail::cpp_type_t
value = detail::cpp_type_t
{0}) +{ + fill_tensor_buffer(desc, buffer, [value]([[maybe_unused]] size_t i) { return value; }); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp index 15cb43f369..2976e6c14b 100644 --- a/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp @@ -19,15 +19,30 @@ namespace ck_tile::builder::test { -template -void init_tensor_buffer_uniform_int(const DeviceBuffer& buf, - const TensorDescriptor
& descriptor, - int min_val, - int max_val) +/// @brief Initialize tensor data with a uniform int distribution +/// +/// This function initializes a tensor's device memory with random integer data, +/// drawn from a uniform distribution. The initialization is done directly on the +/// GPU. Note that the entire buffer is filled with the specified distribution +/// regardless of whether the layout is packed. +/// +/// @tparam DT The data type of the tensor memory to initialize +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param buf The device memory to initialize +/// @param descriptor A tensor descriptor describing the precise layout of the +/// tensor memory. +/// @param min_value The minimum value of the distribution (inclusive). +/// @param max_value The maximum value of the distribution (exclusive). +template +void init_tensor_buffer_uniform_int(void* buf, + const TensorDescriptor& descriptor, + int min_value, + int max_value) { size_t size = descriptor.get_element_space_size_in_bytes(); - if(max_val - min_val <= 1) + if(max_value - min_value <= 1) { throw std::runtime_error("Error while filling device tensor with random integer data: max " "value must be at least 2 greater than min value, otherwise " @@ -38,19 +53,34 @@ void init_tensor_buffer_uniform_int(const DeviceBuffer& buf, // we might be asked to generate int values on fp data types that don't have the required // precision - if(static_cast(max_val - 1) == static_cast(min_val)) + if(static_cast(max_value - 1) == static_cast(min_value)) { throw std::runtime_error("Error while filling device tensor with random integer data: " "insufficient precision in specified range"); } size_t packed_size = ck::packed_size_v; fill_tensor_uniform_rand_int_values<<<256, 256>>>( - static_cast(buf.get()), min_val, max_val, (size * packed_size) / sizeof(ck_type)); + static_cast(buf), min_value, max_value, (size * packed_size) / sizeof(ck_type)); } -template -void 
init_tensor_buffer_uniform_fp(const DeviceBuffer& buf, - const TensorDescriptor
& descriptor, +/// @brief Initialize tensor data with a uniform float distribution +/// +/// This function initializes a tensor's device memory with random floating data, +/// drawn from a uniform distribution. The initialization is done directly on the +/// GPU. Note that the entire buffer is filled with the specified distribution +/// regardless of whether the layout is packed. +/// +/// @tparam DT The data type of the tensor memory to initialize +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param buf The device memory to initialize +/// @param descriptor A tensor descriptor describing the precise layout of the +/// tensor memory. +/// @param min_value The minimum value of the distribution (inclusive). +/// @param max_value The maximum value of the distribution (exclusive). +template +void init_tensor_buffer_uniform_fp(void* buf, + const TensorDescriptor& descriptor, float min_value, float max_value) { @@ -59,15 +89,30 @@ void init_tensor_buffer_uniform_fp(const DeviceBuffer& buf, using ck_type = factory::internal::DataTypeToCK
::type; size_t packed_size = ck::packed_size_v; - fill_tensor_uniform_rand_fp_values<<<256, 256>>>(reinterpret_cast(buf.get()), + fill_tensor_uniform_rand_fp_values<<<256, 256>>>(reinterpret_cast(buf), min_value, max_value, (size * packed_size) / sizeof(ck_type)); } -template -void init_tensor_buffer_normal_fp(const DeviceBuffer& buf, - const TensorDescriptor
& descriptor, +/// @brief Initialize tensor data with a normal float distribution +/// +/// This function initializes a tensor's device memory with random floating data, +/// drawn from a normal distribution. The initialization is done directly on the +/// GPU. Note that the entire buffer is filled with the specified distribution +/// regardless of whether the layout is packed. +/// +/// @tparam DT The data type of the tensor memory to initialize +/// @tparam RANK The rank (number of spatial dimensions) of the tensor. +/// +/// @param buf The device memory to initialize +/// @param descriptor A tensor descriptor describing the precise layout of the +/// tensor memory. +/// @param sigma The standard deviation of the distribution. +/// @param mean The mean of the distribution. +template +void init_tensor_buffer_normal_fp(void* buf, + const TensorDescriptor& descriptor, float sigma, float mean) { @@ -76,7 +121,7 @@ void init_tensor_buffer_normal_fp(const DeviceBuffer& buf, using ck_type = factory::internal::DataTypeToCK
::type; size_t packed_size = ck::packed_size_v; fill_tensor_norm_rand_fp_values<<<256, 256>>>( - static_cast(buf.get()), sigma, mean, (size * packed_size) / sizeof(ck_type)); + static_cast(buf), sigma, mean, (size * packed_size) / sizeof(ck_type)); } } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/testing.hpp b/experimental/builder/include/ck_tile/builder/testing/testing.hpp index a0dfa27409..609c93cacf 100644 --- a/experimental/builder/include/ck_tile/builder/testing/testing.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/testing.hpp @@ -5,6 +5,8 @@ #include +#include "ck_tile/builder/testing/validation.hpp" + /// This file is the main header for the CK-Builder testing system. A high-level /// description of this testing system is documented in /// `ck_tile/builder/testing/README.md`. This file deals mainly deals with the @@ -78,7 +80,7 @@ namespace ck_tile::builder::test { /// that this structure is an aggregrate so that it can be initialized using C++20 /// designated initializers to keep the tests readable. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. template struct Args; @@ -98,7 +100,7 @@ struct Args; /// structure is an aggregrate so that it can be initialized using C++20 /// designated initializers to keep the tests readable. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. template struct Inputs; @@ -118,7 +120,7 @@ struct Inputs; /// structure is an aggregrate so that it can be initialized using C++20 /// designated initializers to keep the tests readable. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. 
template struct Outputs; @@ -133,7 +135,7 @@ struct Outputs; /// @note The easiest way to implement this type is to use the `DeviceBuffer` /// type to allocate individual device buffers for each input tensor. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. /// /// @see alloc_inputs() /// @see ValidUniqueInputs @@ -152,7 +154,7 @@ struct UniqueInputs; /// @note The easiest way to implement this type is to use the `DeviceBuffer` /// type to allocate individual device buffers for each output tensor. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. /// /// @see alloc_outputs() /// @see ValidUniqueOutputs @@ -195,7 +197,9 @@ concept ValidUniqueOutputs = requires(UniqueOutputs& inputs) { /// amount of memory required and then allocate it on the device, for example /// using `alloc_buffer` or `alloc_tensor_buffer`. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. +/// +/// @param args The run-time arguments of the operation. /// /// @see Inputs /// @see UniqueInputs @@ -208,16 +212,21 @@ UniqueInputs alloc_inputs(const Args& args); /// @brief Allocate inputs corresponding to a signature. /// /// The `init_inputs()` function is used to initialize pseudo-random data -/// to the tensors specified in the Inputs structure. +/// to the tensors specified in the Inputs structure. Implementors should +/// fill each of the tensors in `inputs` with appropriate random data. /// /// @tparam SIGNATURE the signature to specialize the structure for. /// +/// @param args The run-time arguments of the operation. +/// @param inputs The operation inputs to initialize with random data. +/// +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. 
+/// /// @see Inputs -/// @see UniqueInputs /// @see tensor_initialization template - requires ValidUniqueInputs -void init_inputs(const Args& args, UniqueInputs& inputs); +void init_inputs(const Args& args, Inputs inputs) = delete; /// @brief Allocate outputs corresponding to a signature. /// @@ -226,7 +235,12 @@ void init_inputs(const Args& args, UniqueInputs& inputs); /// amount of memory required and then allocate it on the device, for example /// using `alloc_buffer` or `alloc_tensor_buffer`. /// -/// @tparam SIGNATURE the signature to specialize the structure for. +/// @tparam SIGNATURE The signature to specialize the structure for. +/// +/// @param args The run-time arguments of the operation. +/// +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. /// /// @see Outputs /// @see UniqueOutputs @@ -234,7 +248,34 @@ void init_inputs(const Args& args, UniqueInputs& inputs); /// @see alloc_tensor_buffer() template requires ValidUniqueOutputs -UniqueInputs alloc_outputs(const Args& args); +UniqueInputs alloc_outputs(const Args& args) = delete; + +/// @brief Compare device operation outputs. +/// +/// This function implements the main comparison functionality, used to compare +/// the output of one implementation for a particular `SIGNATURE` with that of +/// another. Usually, the `expected` output should be computed by a reference +/// implementation. +/// +/// The implementation of this function generates a "report", which includes +/// detailed information about which tensors are different, how many elements +/// were incorrect, and where (a subset of) those elements are located within +/// the tensor. See `ValidationReport` for more information about the report. +/// +/// @tparam SIGNATURE The signature to specialize the structure for. +/// +/// @param args The run-time arguments of the operation. +/// @param actual The actual results, the results of the operation to-be-tested. 
+/// @param expected The expected results, the results of the reference implementation. +/// +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. +/// +/// @see ValidationReport +template +ValidationReport validate(const Args& args, + Outputs actual, + Outputs expected) = delete; /// @brief Invoke a device operation created by CK Builder. /// @@ -257,7 +298,7 @@ UniqueInputs alloc_outputs(const Args& args); /// @post The tensors in `outputs` are overwritten with the outputs of the device /// operation. /// -/// @tparam SIGNATURE the signature to specialize this function for +/// @tparam SIGNATURE The signature to specialize this function for /// @tparam Operation the kernel of the operation to invoke. This type should be /// one that is created using the Builder API. /// @param operation An instance of the operation to invoke. @@ -265,10 +306,13 @@ UniqueInputs alloc_outputs(const Args& args); /// @param inputs The input tensor data. Will not be modified by this function. /// @param outputs The output tensor data. The contents will be overwritten by /// this function. +/// +/// @note This function is explicitly deleted to generate compile errors +/// for missing implementations. template void run(Operation& operation, const Args& args, const Inputs& inputs, - const Outputs& outputs); + const Outputs& outputs) = delete; } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/validation.hpp b/experimental/builder/include/ck_tile/builder/testing/validation.hpp new file mode 100644 index 0000000000..267bf8d2ac --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/validation.hpp @@ -0,0 +1,205 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/error.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_foreach.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/type_convert.hpp" +#include +#include +#include +#include +#include + +/// This file implements functionality related to "validation", ie, functionality +/// to compare tensors. The functionality in this file should be testing-framework +/// agnostic, and it should NOT generate any error messages by itself. Instead, +/// all relevant information should be stored in the `ValidationReport` structure. +/// This structure should then be used to generate error messages, explanations, +/// etc, by the actual testing framework that the user has chosen. + +namespace ck_tile::builder::test { + +/// @brief Information about how a set of comparisons failed or succeeded. +/// +/// This structure represents a "report" generated by comparing sets of tensors. +/// It's intended to be used as the result of `ckt::validate()`, where `check()` +/// is invoked for each of the output tensors of a particular device operation. +/// The test should be considered successful if _all_ of those checks pass, +/// which can be inspected by asserting that `get_errors().size()` is 0. +struct ValidationReport +{ + /// @brief Information related to a single tensor comparison. + /// + /// This structure holds the information about the result of comparing + /// two particular tensors. + struct Case + { + /// The name of the tensor that was compared here, stored here for convenience + /// so that reporting any errors is easier. + std::string tensor_name; + + /// The number of elements which were different between the two compared tensors. + uint64_t wrong_elements; + + /// The total number of elements in each tensor. 
+ uint64_t total_elements; + + /// The number of elements which were bitwise 0 in both tensors. + uint64_t zero_elements; + + /// @brief Check whether the output and reference tensors were both all zeros. + /// + /// If both tensors are all zero, it indicates either an incorrect testing setup + /// or an issue with the testing framework. For that reason we also consider that + /// a failure. + bool is_all_zero() const { return zero_elements == total_elements; } + + /// @brief Return whether the check associated to this case was successful. + /// + /// This function returns whether the check associated to this case was successful, + /// which is directly derived from checking whether the number of incorrect elements + /// was 0 AND whether the tensor was not all zero. + bool is_ok() const { return wrong_elements == 0 && !is_all_zero(); } + }; + + /// @brief Get comparison cases which were incorrect. + /// + /// This function returns a vector of comparison cases that did not succeed, ie, for + /// which `Case::is_ok` returns false. In order to check whether validation passed, it + /// is sufficient to assert that this function returns no cases. + std::vector get_errors() const + { + std::vector errors; + std::copy_if(reports_.begin(), + reports_.end(), + std::back_inserter(errors), + [](const auto& report) { return !report.is_ok(); }); + return errors; + } + + /// @brief Compare two tensors and record the results in the report. + /// + /// This is the main function used to compare two tensors. The results of this + /// comparison, including any supplemental information, are recorded into the report. + /// + /// @returns `false` if the comparison failed. If so, the details can be found via + /// `get_errors()`. + /// + /// @tparam DT The data type of the tensors to check. + /// @tparam RANK The rank (number of spatial dimensions) of the tensor to check. + /// + /// @param tensor_name The name of the tensors to check. 
This should be a value by which + /// whoever is debugging the associated test later can easily find out which of the + /// outputs of a device operation was incorrect. + /// @param descriptor The descriptor (memory layout) of the tensor. + /// @param actual The device buffer with the values of the tensor to-be-tested, ie, the + /// results of the device operation. + /// @param expected The device buffer with the values of the reference tensor. These are + /// treated as a "golden standard", and should usually be generated by a reference + /// implementation. + /// @param rtol The relative acceptable tolerance between two values. + /// @param atol The absolute acceptable tolerance between two values. + template + bool check(std::string_view tensor_name, + const TensorDescriptor& descriptor, + const void* actual, + const void* expected, + double rtol = 1e-3, + double atol = 1e-3); + + private: + std::vector reports_; +}; + +template +bool ValidationReport::check(std::string_view tensor_name, + const TensorDescriptor& descriptor, + const void* actual_data, + const void* expected_data, + double rtol, + double atol) +{ + const auto strides = descriptor.get_strides(); + + // During development and CI, only the kernels that were changed would fail, and so we can + // assume that the average case does not have errors. Therefore, split out testing into a + // quick test which just counts the incorrect elements, and a more in-depth test that also + // returns the indices of the incorrect items. + + // Initial pass: count errors + + // Allocate and reset counter + auto d_counters = alloc_buffer(sizeof(uint64_t) * 2); + check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2)); + + auto d_error_count = &reinterpret_cast(d_counters.get())[0]; + auto d_zero_count = &reinterpret_cast(d_counters.get())[1]; + + tensor_foreach(descriptor.get_lengths(), [=](auto index) { + using CKType = typename factory::internal::DataTypeToCK
::type; + + const auto* actual = static_cast(actual_data); + const auto* expected = static_cast(expected_data); + + static_assert(!std::is_same_v, + "TODO implement compare_kernel() for double"); + + const auto offset = calculate_offset(index, strides); + + const auto a = actual[offset]; + const auto b = expected[offset]; + + const auto o = static_cast(type_convert(a)); + const auto r = static_cast(type_convert(b)); + const auto err = std::abs(o - r); + + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + // We expect the number of errors to be very low, so just use an atomic + // for now. + atomicAdd(d_error_count, 1); + } + + // Now compare the numbers as bitwise too. + // Update the counter if they're both zero. + using Bytes = std::array; + bool all_zero = true; + for(auto x : std::bit_cast(a)) + { + if(x != std::byte{0}) + all_zero = false; + } + for(auto x : std::bit_cast(b)) + { + if(x != std::byte{0}) + all_zero = false; + } + if(all_zero) + { + atomicAdd(d_zero_count, 1); + } + }); + + uint64_t error_count = 0; + check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost)); + uint64_t zero_count = 0; + check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost)); + + // TODO: Gather detailed coordinates. 
+ + reports_.push_back(Case{ + .tensor_name = std::string(tensor_name), + .wrong_elements = error_count, + .total_elements = descriptor.get_element_size(), + .zero_elements = zero_count, + }); + + return reports_.back().is_ok(); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 667490151f..71b56182c5 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -81,33 +81,36 @@ add_ck_builder_test(test_ckb_conv_builder test_instance_traits_util.cpp unit_device_buffer.cpp unit_tensor_descriptor.cpp + unit_tensor_foreach.cpp + unit_error.cpp + unit_validation.cpp unit_conv_elementwise_op.cpp unit_conv_tensor_layout.cpp unit_conv_tensor_type.cpp unit_conv_thread_block.cpp unit_conv_tuning_params.cpp) - - # Tests the inline diff utility used for comparing strings in tests assertions - add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp) - # GPU reference validation tests (in validation/ folder) - # 1. Reference kernel execution and InstanceTraits - add_ck_builder_test(test_ckb_reference_execution - validation/test_reference_execution.cpp - validation/test_reference_instance_traits.cpp) - target_link_libraries(test_ckb_reference_execution PRIVATE utility) - - # Note: Optimized kernel validation tests will be added after merging dev branch - # with kernel Run() implementation from colleague's work +# Tests the inline diff utility used for comparing strings in tests assertions +add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp) + +# GPU reference validation tests (in validation/ folder) +# 1. 
Reference kernel execution and InstanceTraits +add_ck_builder_test(test_ckb_reference_execution + validation/test_reference_execution.cpp + validation/test_reference_instance_traits.cpp) +target_link_libraries(test_ckb_reference_execution PRIVATE utility) + +# Note: Optimized kernel validation tests will be added after merging dev branch +# with kernel Run() implementation from colleague's work + +# Tests convolution trait selection and configuration +add_ck_builder_test(test_ckb_conv_traits + conv/ck/test_conv_traits.cpp) + +# Tests convolution problem description and parameter handling +add_ck_builder_test(test_ckb_conv_description + test_conv_description.cpp) - # Tests convolution trait selection and configuration - add_ck_builder_test(test_ckb_conv_traits - conv/ck/test_conv_traits.cpp) - - # Tests convolution problem description and parameter handling - add_ck_builder_test(test_ckb_conv_description - test_conv_description.cpp) - ################################################################################ # REGRESSION TESTS - Integration Tests (With Kernel Compilation) ################################################################################ diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp index d5051a50c8..a559d3ee47 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_wmma_cshuffle.cpp @@ -22,7 +22,7 @@ constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 3, constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle{} .with_thread_block(cku::ThreadBlock_64_32x32x32) .with_gemm_config(cku::GemmParams_Wmma_16x16_2x1_per_wave) - .with_transfer(cku::BwdTransfer_4x64x1) + .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3) .with_bwd_specialization(ckb::ConvSpecialization::DEFAULT) 
.with_prefetch_config(1, ckb::PipelineScheduler::DEFAULT) .with_gemm_pipeline(ckb::PipelineVersion::V1); diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index 152409396e..628394e3ca 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -5,12 +5,16 @@ #include "utils/ckb_conv_test_utils.hpp" #include "utils/conv_algorithm_type_utils.hpp" #include "ck_tile/builder/testing/conv_fwd_ck.hpp" +#include "ck_tile/builder/testing/conv_fwd_reference.hpp" #include "ck_tile/host/device_prop.hpp" +#include "testing_utils.hpp" namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; +using ck_tile::test::MatchesReference; + constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, .direction = ckb::ConvDirection::FORWARD, @@ -31,6 +35,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xd using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; +using Reference = ckb::ConvBuilder::Instance; + TEST(Fwd2DFp16_CShufV3_GNHWC, Create) { const auto expected_transfer_parameters = to_string(ALGORITHM); @@ -78,11 +84,17 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd) .cde_elementwise_op = {}, }; - auto inputs = alloc_inputs(args); - auto outputs = alloc_outputs(args); + auto inputs = ckt::alloc_inputs(args); + auto outputs = ckt::alloc_outputs(args); + auto reference = ckt::alloc_outputs(args); - init_inputs(args, inputs); + ckt::init_inputs(args, inputs.get()); auto conv = Instance{}; ckt::run(conv, args, inputs.get(), outputs.get()); + + auto ref_conv = Reference{}; + ckt::run(ref_conv, args, inputs.get(), reference.get()); + + EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); } diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp 
b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp index b79fdf513a..89baf9b51b 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp @@ -40,7 +40,6 @@ TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D "Default", "Intrawave", "CShuffleEpilogue", - "set", "pipeline_AgBgCrCompV3", "DoubleSmemBuffer_0", "NumWaveGroups_1", diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp index a5801b0e85..292d852b91 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -40,7 +40,6 @@ TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_ "Default", "Intrawave", "CShuffleEpilogue", - "set", "pipeline_AgBgCrCompV3", "DoubleSmemBuffer_0", "NumWaveGroups_1", diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp index 9a8a4ce753..2c35fb5076 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp @@ -39,7 +39,6 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 "Default", "Intrawave", "CShuffleEpilogue", - "set", "pipeline_AgBgCrCompV3", "DoubleSmemBuffer_0", "NumWaveGroups_1", diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 0c665c8321..c8cf809d3a 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -610,6 +610,32 @@ using 
ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle = ConvSpecializationBwdWeight_, MultipleDSpecialization_>; +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_, + TransposeParams_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_TwoStage_Wmma_CShuffle_V3 = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + BlockGemm_, + TransposeParams_, + GemmBatchOptions_, + TwoStageSpecialization_>; + +using ConvAlgorithm_DeviceGroupedConvBwdWeight_Wmma_CShuffle = + ConvAlgorithmTemplate, + ConvSpecializationBwdWeight_, + GridGemm_, + Prefetch_>; + using ConvAlgorithm_DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 = ConvAlgorithmTemplate>; diff --git a/experimental/builder/test/test_bwd_weight_instance_traits.cpp b/experimental/builder/test/test_bwd_weight_instance_traits.cpp index c7c4e370e2..dbb3a0a8fc 100644 --- a/experimental/builder/test/test_bwd_weight_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_weight_instance_traits.cpp @@ -184,7 +184,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/, GroupedConvTraitsType::FixedGemmParams::TransposeC, - ck_tile::memory_operation_enum::set /*memory_operation*/, 1 /*kNumWaveGroups*/, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 498de9a42f..9e8008ccf0 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -161,8 +161,9 @@ struct DefaultAlgorithm ckb::ConvSpecialization fwd_specialization = ckb::ConvSpecialization::DEFAULT; ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::Default; - ckb::test::BlockGemm block_gemm{.pipeline_version = ckb::PipelineVersion::V4, - .scheduler = 
ckb::PipelineScheduler::INTRAWAVE}; + ckb::test::BlockGemmPipeline block_gemm_pipeline{.pipeline_version = ckb::PipelineVersion::V4, + .scheduler = + ckb::PipelineScheduler::INTRAWAVE}; }; static_assert(ckb::ConvAlgorithmDescriptor); diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index 6dd2a4eada..ad0a2cadc6 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -795,7 +795,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/, GroupedConvTraitsType::FixedGemmParams::TransposeC, - ck_tile::memory_operation_enum::set /*memory_operation*/, 1 /*kNumWaveGroups*/, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; diff --git a/experimental/builder/test/test_inline_diff.cpp b/experimental/builder/test/test_inline_diff.cpp index 8d3a90c95f..6a7a7ac8f7 100644 --- a/experimental/builder/test/test_inline_diff.cpp +++ b/experimental/builder/test/test_inline_diff.cpp @@ -5,8 +5,7 @@ #include "testing_utils.hpp" -namespace ck_tile::builder { -namespace { +using ck_tile::test::inlineDiff; TEST(InlineDiff, simpleColorDiff) { @@ -16,8 +15,8 @@ TEST(InlineDiff, simpleColorDiff) // some easy tests // you can veryfy the ungodly strings are meaningful by running echo -e "" - EXPECT_THAT(test::inlineDiff(str1, str2, true), "hello"); - EXPECT_THAT(test::inlineDiff(str1, str3, true), + EXPECT_THAT(inlineDiff(str1, str2, true), "hello"); + EXPECT_THAT(inlineDiff(str1, str3, true), "[\x1B[36mwor\x1B[0m|\x1B[35mhel\x1B[0m]l[\x1B[36md\x1B[0m|\x1B[35mo\x1B[0m]"); } @@ -28,8 +27,8 @@ TEST(InlineDiff, noColorDiff) std::string str3{"world"}; // some easy tests without color - EXPECT_THAT(test::inlineDiff(str1, str2, false), "hello"); - EXPECT_THAT(test::inlineDiff(str1, str3, false), "[wor|hel]l[d|o]"); + EXPECT_THAT(inlineDiff(str1, str2, false), 
"hello"); + EXPECT_THAT(inlineDiff(str1, str3, false), "[wor|hel]l[d|o]"); } TEST(InlineDiff, complexColorDiff) @@ -42,11 +41,8 @@ TEST(InlineDiff, complexColorDiff) "this part has degeahc, this part has, this part added, this part has ana extra letter"}; EXPECT_THAT( - test::inlineDiff(str5, str4, true), + inlineDiff(str5, str4, true), "this part has [\x1B[36mchanged\x1B[0m|\x1B[35mdegeahc\x1B[0m], this part has[\x1B[36m " "been left out\x1B[0m|\x1B[35m\x1B[0m], this part[\x1B[36m\x1B[0m|\x1B[35m added\x1B[0m], " "this part has an[\x1B[36m\x1B[0m|\x1B[35ma\x1B[0m] extra letter"); }; - -} // namespace -} // namespace ck_tile::builder diff --git a/experimental/builder/test/testing_utils.hpp b/experimental/builder/test/testing_utils.hpp index 7a03851ac4..b84d53b6df 100644 --- a/experimental/builder/test/testing_utils.hpp +++ b/experimental/builder/test/testing_utils.hpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT #include +#include "ck_tile/builder/testing/testing.hpp" #include #include #include @@ -21,6 +22,16 @@ /// dedicated function to override to provide printing support. std::ostream& operator<<(std::ostream& os, hipError_t status); +namespace ck_tile::builder::test { + +template +std::ostream& operator<<(std::ostream& os, [[maybe_unused]] Outputs outputs) +{ + return os << ""; +} + +} // namespace ck_tile::builder::test + namespace ck_tile::test { static bool isTerminalOutput() { return isatty(fileno(stdout)) || isatty(fileno(stderr)); } @@ -150,4 +161,47 @@ struct HipStatusMatcher : public ::testing::MatcherInterface /// @param error The error to expect. 
::testing::Matcher HipError(hipError_t error); +template +struct ReferenceOutputMatcher + : public ::testing::MatcherInterface> +{ + ReferenceOutputMatcher(const builder::test::Args& args, + builder::test::Outputs expected) + : args_(&args), expected_(expected) + { + } + + bool MatchAndExplain(builder::test::Outputs actual, + [[maybe_unused]] ::testing::MatchResultListener* listener) const override + { + const auto report = ck_tile::builder::test::validate(*args_, actual, expected_); + const auto errors = report.get_errors(); + + if(listener->IsInterested() && !errors.empty()) + { + *listener << errors.size() << " tensors failed to validate"; + } + + return errors.empty(); + } + + void DescribeTo(std::ostream* os) const override { *os << ""; } + + void DescribeNegationTo(std::ostream* os) const override + { + *os << "isn't equal to "; + } + + const builder::test::Args* args_; + builder::test::Outputs expected_; +}; + +template +::testing::Matcher> +MatchesReference(const builder::test::Args& args, + builder::test::Outputs expected) +{ + return ::testing::MakeMatcher(new ReferenceOutputMatcher(args, expected)); +} + } // namespace ck_tile::test diff --git a/experimental/builder/test/unit_conv_tensor_type.cpp b/experimental/builder/test/unit_conv_tensor_type.cpp index 7ffd446966..b385210cea 100644 --- a/experimental/builder/test/unit_conv_tensor_type.cpp +++ b/experimental/builder/test/unit_conv_tensor_type.cpp @@ -11,40 +11,27 @@ namespace { namespace ckb = ck_tile::builder; using ck_tile::builder::factory::internal::DataTypeToCK; -TEST(ConvTensorType, AssignsTypesForFP16) -{ - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} +template +constexpr auto check_same = std::is_same_v::type, T>; -TEST(ConvTensorType, AssignsTypesForBF16) +TEST(ConvTensorType, Exhaustive) { - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} + using enum ckb::DataType; -TEST(ConvTensorType, AssignsTypesForFP32) -{ - using CKType = 
DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} - -TEST(ConvTensorType, AssignsTypesForINT32) -{ - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} - -TEST(ConvTensorType, AssignsTypesForI8) -{ - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); -} - -TEST(ConvTensorType, AssignsTypesForFP8) -{ - using CKType = DataTypeToCK::type; - EXPECT_TRUE((std::is_same_v)); + const auto type = FP32; + // This switch ensures that we get a warning (error with -Werror) if + // a variant is missing. + switch(type) + { + case UNDEFINED_DATA_TYPE: break; + case FP32: EXPECT_TRUE((check_same)); break; + case FP16: EXPECT_TRUE((check_same)); break; + case BF16: EXPECT_TRUE((check_same)); break; + case INT32: EXPECT_TRUE((check_same)); break; + case FP8: EXPECT_TRUE((check_same)); break; + case I8: EXPECT_TRUE((check_same)); break; + case U8: EXPECT_TRUE((check_same)); break; + } } } // namespace diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp index ee1388a77f..9005742930 100644 --- a/experimental/builder/test/unit_conv_tuning_params.cpp +++ b/experimental/builder/test/unit_conv_tuning_params.cpp @@ -19,7 +19,7 @@ TEST(ConvTuningParams, AssignsBlockGemmParams) { ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3; ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE; - } block_gemm; + } block_gemm_pipeline; } kAlgorithm; constexpr auto block_gemm = SetBlockGemm(); @@ -42,10 +42,7 @@ TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion) { constexpr struct Algorithm { - struct GridwiseGemm - { - ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; - } gridwise_gemm; + ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; } kAlgorithm; constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion(); diff --git a/experimental/builder/test/unit_device_buffer.cpp 
b/experimental/builder/test/unit_device_buffer.cpp index 75408acc16..c7180395b7 100644 --- a/experimental/builder/test/unit_device_buffer.cpp +++ b/experimental/builder/test/unit_device_buffer.cpp @@ -2,10 +2,11 @@ // SPDX-License-Identifier: MIT #include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" #include "testing_utils.hpp" #include #include -#include +#include namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; @@ -54,6 +55,11 @@ TEST(DeviceBuffer, AutoFree) // Trying to use a pointer after freeing should return en error in HIP. EXPECT_THAT(hipMemset(ptr, 0xFF, size), HipError(hipErrorInvalidValue)); + + // Reset internal HIP error state. + // Otherwise, the error may leak into other tests, triggering anything that + // checks the output of hipGetLastError(); + (void)hipGetLastError(); } TEST(DeviceBuffer, ThrowsOnOom) @@ -62,13 +68,16 @@ TEST(DeviceBuffer, ThrowsOnOom) auto check = [] { auto buffer = ckt::alloc_buffer(size); }; EXPECT_THAT(check, Throws()); + + // Reset internal HIP error state. + // Otherwise, the error may leak into other tests, triggering anything that + // checks the output of hipGetLastError(); + (void)hipGetLastError(); } TEST(DeviceBuffer, AllocTensorBuffer) { - std::vector lengths = {128, 128, 128}; - std::vector strides = {128 * 128, 128, 1}; - ckt::TensorDescriptor descriptor(lengths, strides); + ckt::TensorDescriptor descriptor({128, 128, 128}, {128 * 128, 128, 1}); auto buffer = ckt::alloc_tensor_buffer(descriptor); diff --git a/experimental/builder/test/unit_error.cpp b/experimental/builder/test/unit_error.cpp new file mode 100644 index 0000000000..201780cc6a --- /dev/null +++ b/experimental/builder/test/unit_error.cpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck_tile/builder/testing/error.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "testing_utils.hpp" +#include +#include + +namespace ckt = ck_tile::builder::test; + +using ::testing::AllOf; +using ::testing::HasSubstr; +using ::testing::Throws; +using ::testing::ThrowsMessage; + +[[noreturn]] void throw_error() { throw ckt::HipError("test error", hipErrorInvalidValue); } + +TEST(HipError, SourceInfo) +{ + EXPECT_THAT(throw_error, + ThrowsMessage(AllOf( + // The error message should include... + // ...the user message + HasSubstr("test error"), + // ...the HIP message + HasSubstr("invalid argument"), + // ...the HIP status code, + HasSubstr("(1)"), + // ...the filename + HasSubstr("experimental/builder/test/unit_error.cpp"), + // ...the function name + HasSubstr("throw_error") + // Note: Don't include the row/column so that we can move + // stuff around in this file. + ))); +} + +TEST(CheckHip, BasicUsage) +{ + EXPECT_THAT([] { ckt::check_hip(hipSuccess); }, Not(Throws())); + EXPECT_THAT([] { ckt::check_hip(hipErrorNotMapped); }, Throws()); + EXPECT_THAT([] { ckt::check_hip(hipErrorOutOfMemory); }, Throws()); + EXPECT_THAT([] { ckt::check_hip("test message", hipErrorAlreadyMapped); }, + ThrowsMessage(HasSubstr("test message"))); +} diff --git a/experimental/builder/test/unit_tensor_descriptor.cpp b/experimental/builder/test/unit_tensor_descriptor.cpp index 07abfe44bd..672ebbd88a 100644 --- a/experimental/builder/test/unit_tensor_descriptor.cpp +++ b/experimental/builder/test/unit_tensor_descriptor.cpp @@ -1,25 +1,28 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" #include "testing_utils.hpp" #include #include +#include #include namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; using ::testing::ElementsAreArray; -using ::testing::Ge; +using ::testing::Eq; +using ::testing::Throws; TEST(TensorDescriptor, Basic) { - constexpr auto dt = ckb::DataType::FP16; - std::vector lengths = {123, 456, 789}; - std::vector strides = {456 * 789, 789, 1}; + constexpr auto dt = ckb::DataType::FP16; + constexpr size_t rank = 3; + ckt::Extent lengths = {123, 456, 789}; + ckt::Extent strides = {456 * 789, 789, 1}; - ckt::TensorDescriptor
descriptor(lengths, strides); + ckt::TensorDescriptor descriptor(lengths, strides); EXPECT_THAT(descriptor.get_lengths(), ElementsAreArray(lengths)); EXPECT_THAT(descriptor.get_strides(), ElementsAreArray(strides)); @@ -27,21 +30,162 @@ TEST(TensorDescriptor, Basic) TEST(TensorDescriptor, ComputeSize) { - constexpr auto dt = ckb::DataType::FP32; - std::vector lengths = {305, 130, 924}; - std::vector strides = {1000 * 1000, 1, 1000}; + constexpr auto dt = ckb::DataType::FP32; + constexpr size_t rank = 3; + ckt::Extent lengths = {305, 130, 924}; + ckt::Extent strides = {1001 * 1000, 1, 1000}; - ckt::TensorDescriptor
descriptor(lengths, strides); + ckt::TensorDescriptor descriptor(lengths, strides); - // Compute the location of the last item in memory, then add one - // to get the minimum size. - size_t expected_size = 1; + // Compute the location of the last item in memory, + // then add one to get the minimum size. + size_t expected_size = 1; + size_t expected_numel = 1; for(size_t i = 0; i < lengths.size(); ++i) { expected_size += (lengths[i] - 1) * strides[i]; + expected_numel *= lengths[i]; } - EXPECT_THAT(descriptor.get_element_space_size(), Ge(expected_size)); + EXPECT_THAT(descriptor.get_element_size(), Eq(expected_numel)); + EXPECT_THAT(descriptor.get_element_space_size(), Eq(expected_size)); EXPECT_THAT(descriptor.get_element_space_size_in_bytes(), - Ge(expected_size * ckt::data_type_sizeof(dt))); + Eq(expected_size * ckt::data_type_sizeof(dt))); +} + +TEST(TensorDescriptor, PackedRightLayout) +{ + const ckt::Extent lengths = {5125, 623, 1177, 1534}; + const auto strides = ckt::PackedRightLayout{}(lengths); + + EXPECT_THAT(strides, ElementsAreArray({623 * 1177 * 1534, 1177 * 1534, 1534, 1})); +} + +TEST(TensorDescriptor, PackedLeftLayout) +{ + const ckt::Extent lengths = {4, 15, 925, 662, 1462}; + const auto strides = ckt::PackedLeftLayout{}(lengths); + + EXPECT_THAT(strides, ElementsAreArray({1, 4, 4 * 15, 4 * 15 * 925, 4 * 15 * 925 * 662})); +} + +TEST(TensorDescriptor, MakeDescriptor) +{ + { + const ckt::Extent lengths = {10, 11, 12, 13, 14}; + + // Note: automatic inference of RANK. + const auto desc = + ckt::make_descriptor(lengths, ckt::PackedRightLayout{}); + + EXPECT_THAT(desc.get_lengths(), ElementsAreArray(lengths)); + EXPECT_THAT(desc.get_strides(), + ElementsAreArray({11 * 12 * 13 * 14, 12 * 13 * 14, 13 * 14, 14, 1})); + } + + { + const ckt::Extent lengths = {4, 3, 2}; + const ckt::Extent strides = {60, 1, 7}; + + // Note: automatic inference of RANK. 
+ const auto desc = ckt::make_descriptor(lengths, strides); + + EXPECT_THAT(desc.get_lengths(), ElementsAreArray(lengths)); + EXPECT_THAT(desc.get_strides(), ElementsAreArray(strides)); + } +} + +TEST(TensorDescriptor, GetSpaceDescriptor) +{ + { + const auto desc = ckt::make_descriptor(ckt::Extent{4, 4, 4}, + ckt::PackedLeftLayout{}); + const auto space = desc.get_space_descriptor(); + + const auto expected = 4 * 4 * 4; + + EXPECT_THAT(decltype(space)::data_type, Eq(ckb::DataType::FP32)); + EXPECT_THAT(decltype(space)::rank, Eq(1)); + + EXPECT_THAT(decltype(space)::data_type, Eq(ckb::DataType::FP32)); + EXPECT_THAT(decltype(space)::rank, Eq(1)); + EXPECT_THAT(space.get_lengths(), ElementsAreArray({expected})); + EXPECT_THAT(space.get_strides(), ElementsAreArray({1})); + EXPECT_THAT(space.get_element_size(), Eq(expected)); + EXPECT_THAT(space.get_element_space_size(), Eq(expected)); + } + + { + const ckt::Extent lengths = {6, 3, 4}; + const ckt::Extent strides = {102, 1, 2002}; + const auto desc = ckt::make_descriptor(lengths, strides); + const auto space = desc.get_space_descriptor(); + + // Compute the location of the last item in memory, + // then add one to get the minimum size. 
+ size_t expected_size = 1; + for(size_t i = 0; i < lengths.size(); ++i) + { + expected_size += (lengths[i] - 1) * strides[i]; + } + + EXPECT_THAT(decltype(space)::data_type, Eq(ckb::DataType::FP32)); + EXPECT_THAT(decltype(space)::rank, Eq(1)); + EXPECT_THAT(space.get_lengths(), ElementsAreArray({expected_size})); + EXPECT_THAT(space.get_strides(), ElementsAreArray({1})); + EXPECT_THAT(space.get_element_size(), Eq(expected_size)); + EXPECT_THAT(space.get_element_space_size(), Eq(expected_size)); + } +} + +TEST(TensorDescriptor, EmptyExtent) +{ + // A rank-0 tensor points to a single element + const auto desc = ckt::make_descriptor(ckt::Extent{}, ckt::Extent{}); + EXPECT_THAT(decltype(desc)::rank, Eq(0)); + EXPECT_THAT(desc.get_lengths().size(), Eq(0)); + EXPECT_THAT(desc.get_strides().size(), Eq(0)); + EXPECT_THAT(desc.get_element_size(), Eq(1)); + EXPECT_THAT(desc.get_element_space_size(), Eq(1)); + EXPECT_THAT(desc.get_element_space_size_in_bytes(), Eq(2)); + + // We expect a rank-1 tensor with the one dimension being 1. + const auto space = desc.get_space_descriptor(); + + const auto expected = 1; + + EXPECT_THAT(decltype(space)::rank, Eq(1)); + EXPECT_THAT(space.get_lengths(), ElementsAreArray({expected})); + EXPECT_THAT(space.get_strides(), ElementsAreArray({1})); + EXPECT_THAT(space.get_element_size(), Eq(expected)); + EXPECT_THAT(space.get_element_space_size(), Eq(expected)); + EXPECT_THAT(space.get_element_space_size_in_bytes(), Eq(2)); +} + +TEST(TensorDescriptor, ExtentFromVector) +{ + EXPECT_THAT(ckt::Extent<4>::from_vector(std::vector{1, 2, 3, 4}), + ElementsAreArray({1, 2, 3, 4})); + + EXPECT_THAT([] { return ckt::Extent<5>::from_vector(std::vector{1, 2}); }, + Throws()); +} + +TEST(TensorDescriptor, IsPacked) +{ + constexpr auto dt = ckb::DataType::INT32; // Irrelevant for this test + EXPECT_TRUE( + ckt::make_descriptor
(ckt::Extent{101, 43, 25, 662, 654}, ckt::PackedLeftLayout{}) + .is_packed()); + EXPECT_TRUE( + ckt::make_descriptor
(ckt::Extent{5334, 235, 1563, 256, 23}, ckt::PackedRightLayout{}) + .is_packed()); + EXPECT_TRUE(ckt::make_descriptor
(ckt::Extent{}, ckt::Extent{}).is_packed()); + EXPECT_TRUE( + ckt::make_descriptor
(ckt::Extent{461, 345, 5, 93}, ckt::Extent{160425, 5, 1, 1725}) + .is_packed()); + EXPECT_FALSE( + ckt::make_descriptor
(ckt::Extent{10, 11, 12}, ckt::Extent{1, 100, 1100}).is_packed()); + EXPECT_FALSE( + ckt::make_descriptor
(ckt::Extent{30, 20, 10}, ckt::Extent{1, 1, 1}).is_packed()); } diff --git a/experimental/builder/test/unit_tensor_foreach.cpp b/experimental/builder/test/unit_tensor_foreach.cpp new file mode 100644 index 0000000000..de635bc09b --- /dev/null +++ b/experimental/builder/test/unit_tensor_foreach.cpp @@ -0,0 +1,205 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_foreach.hpp" +#include "testing_utils.hpp" +#include +#include +#include +#include + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; + +using ::testing::Each; +using ::testing::Eq; + +TEST(TensorForeach, CalculateOffset) +{ + EXPECT_THAT(ckt::calculate_offset(ckt::Extent{1, 2, 3}, ckt::Extent{100, 10, 1}), Eq(123)); + EXPECT_THAT(ckt::calculate_offset(ckt::Extent{523, 266, 263}, ckt::Extent{1, 545, 10532}), + Eq(2915409)); + EXPECT_THAT(ckt::calculate_offset(ckt::Extent{}, ckt::Extent{}), Eq(0)); + // Note: >4 GB overflow test + EXPECT_THAT(ckt::calculate_offset(ckt::Extent{8, 2, 5, 7, 0, 4, 1, 3, 6, 9}, + ckt::Extent{1'000, + 1'000'000, + 10'000'000, + 1'000'000'000, + 1, + 10'000, + 100, + 10, + 100'000'000, + 100'000}), + Eq(size_t{7'652'948'130})); +} + +TEST(TensorForeach, VisitsCorrectCount) +{ + // tensor_foreach should visit every index exactly once. + // This test checks that the count is at least correct. 
+ + const ckt::Extent shape = {10, 20, 30}; + + auto d_count = ckt::alloc_buffer(sizeof(uint64_t)); + ckt::check_hip(hipMemset(d_count.get(), 0, sizeof(uint64_t))); + + ckt::tensor_foreach(shape, [count = d_count.get()]([[maybe_unused]] const auto& index) { + atomicAdd(reinterpret_cast(count), 1); + }); + + uint64_t actual; + ckt::check_hip(hipMemcpy(&actual, d_count.get(), sizeof(uint64_t), hipMemcpyDeviceToHost)); + + const auto expected = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + + EXPECT_THAT(actual, Eq(expected)); +} + +TEST(TensorForeach, VisitsEveryIndex) +{ + const ckt::Extent shape = {5, 6, 7, 8, 9, 10, 11}; + const auto total = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + + // We know this is correct due to testing in unit_tensor_descriptor.cpp + const auto stride = ckt::PackedRightLayout{}(shape); + + auto d_output = ckt::alloc_buffer(sizeof(uint32_t) * total); + ckt::check_hip(hipMemset(d_output.get(), 0, sizeof(uint32_t) * total)); + + ckt::tensor_foreach(shape, [output = d_output.get(), stride](const auto& index) { + // We know this is correct due to the CalculateOffset test. + auto offset = ckt::calculate_offset(index, stride); + + // Use atomic add so that we can check that every index is visited exactly once. 
+ atomicAdd(&reinterpret_cast(output)[offset], 1); + }); + + std::vector actual(total); + ckt::check_hip( + hipMemcpy(actual.data(), d_output.get(), sizeof(uint32_t) * total, hipMemcpyDeviceToHost)); + + EXPECT_THAT(actual, Each(Eq(1))); +} + +TEST(TensorForeach, FillTensorBuffer) +{ + auto desc = ckt::make_descriptor(ckt::Extent{31, 54, 13}, + ckt::PackedRightLayout{}); + + auto buffer = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer(desc, buffer.get(), [](size_t i) { return static_cast(i); }); + + std::vector h_buffer(desc.get_element_space_size()); + ckt::check_hip(hipMemcpy( + h_buffer.data(), buffer.get(), h_buffer.size() * sizeof(uint32_t), hipMemcpyDeviceToHost)); + + for(size_t i = 0; i < h_buffer.size(); ++i) + { + EXPECT_THAT(h_buffer[i], Eq(static_cast(i))); + } +} + +TEST(TensorForeach, FillTensor) +{ + // FillTensor with non-packed indices should not write out-of-bounds. + const ckt::Extent shape = {4, 23, 35}; + const ckt::Extent pad = {12, 53, 100}; + auto desc = ckt::make_descriptor(shape, ckt::PackedRightLayout{}(pad)); + const auto strides = desc.get_strides(); + + auto size = desc.get_element_space_size(); + auto buffer = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer(desc, buffer.get(), []([[maybe_unused]] size_t i) { return 123; }); + + ckt::fill_tensor(desc, buffer.get(), []([[maybe_unused]] const auto& index) { return 1; }); + + auto d_error = ckt::alloc_buffer(sizeof(uint32_t) * size); + ckt::check_hip(hipMemset(d_error.get(), 0, sizeof(uint32_t))); + + ckt::tensor_foreach( + // Iterate over the entire padding so that we can check out-of-bounds elements + pad, + [shape, pad, strides, size, error = d_error.get(), tensor = buffer.get()]( + const auto& index) { + const auto offset = ckt::calculate_offset(index, strides); + const auto value = reinterpret_cast(tensor)[offset]; + + // Note: The space of the descriptor will not actually be (12, 53, 100) but + // more like (4, 53, 100), as the outer stride is irrelevant. 
So we have to + // perform an extra bounds check here. + if(offset < size) + { + // Check if the coordinate is within the shape bounds. + bool in_bounds = true; + for(size_t i = 0; i < shape.size(); ++i) + { + if(index[i] >= shape[i]) + { + in_bounds = false; + } + } + + // In-bounds elements are 1, out-of-bounds is 123. + if(in_bounds && value != 1) + { + atomicAdd(reinterpret_cast(error), 1); + } + else if(!in_bounds && value != 123) + { + atomicAdd(reinterpret_cast(error), 1); + } + } + }); + + uint32_t error_count = 0; + ckt::check_hip(hipMemcpy(&error_count, d_error.get(), sizeof(uint32_t), hipMemcpyDeviceToHost)); + + EXPECT_THAT(error_count, Eq(0)); +} + +TEST(TensorForeach, ClearTensorZeros) +{ + const ckt::Extent shape = {5, 4, 5, 4, 5, 4, 5, 6}; + const ckt::Extent pad = {6, 6, 6, 6, 6, 6, 6, 6}; + + const auto desc = + ckt::make_descriptor(shape, ckt::PackedRightLayout{}(pad)); + + auto buffer = ckt::alloc_tensor_buffer(desc); + ckt::clear_tensor_buffer(desc, buffer.get()); + + // Check that all values are zeroed. + auto d_count = ckt::alloc_buffer(sizeof(uint64_t)); + ckt::check_hip(hipMemset(d_count.get(), 0, sizeof(uint64_t))); + + { + const auto size = desc.get_element_space_size(); + const auto strides = desc.get_strides(); + auto* count = d_count.get(); + const auto* tensor = reinterpret_cast(buffer.get()); + // Note: iterate over the entire pad, so that we can check out-of-bounds elements. + ckt::tensor_foreach(pad, + [count, tensor, strides, size]([[maybe_unused]] const auto& index) { + const auto offset = ckt::calculate_offset(index, strides); + + // Note: The space of the descriptor will not actually be (6, 6, + // ...) but more like (5, 6, ...), as the outer stride is + // irrelevant. So we have to perform an extra bounds check here. 
+ if(offset < size && tensor[offset] != 0) + { + atomicAdd(reinterpret_cast(count), 1); + } + }); + } + + uint64_t actual; + ckt::check_hip(hipMemcpy(&actual, d_count.get(), sizeof(uint64_t), hipMemcpyDeviceToHost)); + + EXPECT_THAT(actual, Eq(0)); +} diff --git a/experimental/builder/test/unit_validation.cpp b/experimental/builder/test/unit_validation.cpp new file mode 100644 index 0000000000..5f6b620d6b --- /dev/null +++ b/experimental/builder/test/unit_validation.cpp @@ -0,0 +1,298 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/builder/testing/error.hpp" +#include "ck_tile/builder/testing/tensor_buffer.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" +#include "ck_tile/builder/testing/validation.hpp" +#include "ck_tile/builder/testing/tensor_foreach.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include "ck_tile/builder/testing/testing.hpp" +#include "testing_utils.hpp" +#include +#include +#include +#include + +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; + +using testing::ElementsAreArray; +using testing::Eq; +using testing::StrEq; + +using ck_tile::test::MatchesReference; +using ck_tile::test::StringEqWithDiff; + +// Googletest cannot have both type AND value parameterized tests. +// For now just act lazy and use value template parameters. 
+template +struct Param +{ + constexpr static auto data_type = DT; + constexpr static auto shape = SHAPE; + constexpr static auto strides = STRIDES; + + constexpr static auto rank = shape.size(); + + static ckt::TensorDescriptor get_descriptor() + { + return ckt::make_descriptor(shape, strides); + } +}; + +template +struct ValidationReportTests : public ::testing::Test +{ +}; + +using Types = ::testing::Types< + Param, + Param, + Param, + Param>; + +TYPED_TEST_SUITE(ValidationReportTests, Types); + +TYPED_TEST(ValidationReportTests, SingleCorrect) +{ + const auto desc = TypeParam::get_descriptor(); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::clear_tensor_buffer(desc, a.get()); + ckt::clear_tensor_buffer(desc, b.get()); + + // Generate a sort-of-random looking sequence + auto generator = [strides = desc.get_strides()](const auto& index) { + const auto flat_index = ckt::calculate_offset(index, strides); + return static_cast((flat_index + 1) * 10'000'019 % 768'351); + }; + + ckt::fill_tensor(desc, a.get(), generator); + ckt::fill_tensor(desc, b.get(), generator); + + ckt::ValidationReport report; + report.check("correct", desc, b.get(), a.get()); + + EXPECT_THAT(report.get_errors().size(), Eq(0)); +} + +TYPED_TEST(ValidationReportTests, SingleIncorrect) +{ + const auto desc = TypeParam::get_descriptor(); + const auto packed_strides = ckt::PackedRightLayout{}(desc.get_lengths()); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::clear_tensor_buffer(desc, a.get()); + ckt::clear_tensor_buffer(desc, b.get()); + + ckt::fill_tensor(desc, a.get(), []([[maybe_unused]] const auto& i) { return 123; }); + ckt::fill_tensor(desc, b.get(), [packed_strides](const auto& index) { + const auto flat_index = ckt::calculate_offset(index, packed_strides); + return flat_index == 0 ? 0 : flat_index == 12345 ? 456 : flat_index == 999999 ? 
1 : 123; + }); + + ckt::ValidationReport report; + report.check("incorrect", desc, b.get(), a.get()); + + const auto errors = report.get_errors(); + + const auto flat_size = desc.get_element_size(); + const auto expected_errors = flat_size >= 999999 ? 3 : flat_size >= 12345 ? 2 : 1; + + ASSERT_THAT(errors.size(), Eq(1)); + EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect")); + EXPECT_THAT(errors[0].wrong_elements, Eq(expected_errors)); + EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size())); +} + +TYPED_TEST(ValidationReportTests, ZeroIsIncorrect) +{ + const auto desc = TypeParam::get_descriptor(); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::clear_tensor_buffer(desc, a.get()); + ckt::clear_tensor_buffer(desc, b.get()); + + ckt::ValidationReport report; + report.check("zero_is_incorrect", desc, b.get(), a.get()); + + const auto errors = report.get_errors(); + ASSERT_THAT(errors.size(), Eq(1)); + EXPECT_THAT(errors[0].tensor_name, StrEq("zero_is_incorrect")); + EXPECT_THAT(errors[0].wrong_elements, Eq(0)); + EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size())); + EXPECT_THAT(errors[0].zero_elements, Eq(desc.get_element_size())); +} + +TEST(ValidationReportTests, MultipleSomeIncorrect) +{ + ckt::ValidationReport report; + + { + auto desc = ckt::make_descriptor({'R', 'O', 'C', 'm'}, + ckt::PackedLeftLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer( + desc, a.get(), [](size_t i) { return ck::type_convert(i % 100); }); + ckt::fill_tensor_buffer( + desc, b.get(), [](size_t i) { return ck::type_convert(i % 101); }); + + report.check("incorrect 1", desc, b.get(), a.get()); + } + + { + auto desc = + ckt::make_descriptor({'H', 'I', 'P'}, ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer(desc, a.get(), [](size_t i) { return 
"ROCm"[i % 4]; }); + ckt::fill_tensor_buffer(desc, b.get(), [](size_t i) { + switch(i % 4) + { + case 0: return 'R'; + case 1: return 'O'; + case 2: return 'C'; + case 3: return 'm'; + default: return 'x'; + } + }); + + report.check("correct", desc, b.get(), a.get()); + } + + { + auto desc = ckt::make_descriptor({'G', 'P', 'U'}, + ckt::PackedRightLayout{}); + + auto a = ckt::alloc_tensor_buffer(desc); + auto b = ckt::alloc_tensor_buffer(desc); + + ckt::fill_tensor_buffer(desc, a.get(), []([[maybe_unused]] size_t i) { return 1; }); + ckt::fill_tensor_buffer(desc, b.get(), []([[maybe_unused]] size_t i) { return 555; }); + + report.check("incorrect 2", desc, b.get(), a.get()); + } + + const auto errors = report.get_errors(); + + ASSERT_THAT(errors.size(), Eq(2)); + EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect 1")); + EXPECT_THAT(errors[0].wrong_elements, Eq(46840334)); + EXPECT_THAT(errors[1].tensor_name, StrEq("incorrect 2")); + EXPECT_THAT(errors[1].wrong_elements, Eq(482800)); +} + +// MatchesReference operates on the types defined in testing.hpp, so just +// quickly define a bunch of dummy values for that. 
+ +struct DummySignature +{ +}; + +constexpr DummySignature DUMMY_SIGNATURE = {}; + +namespace ck_tile::builder::test { +template <> +struct Args +{ + auto make_a_descriptor() const + { + return make_descriptor(Extent{5, 5, 5, 5}, PackedRightLayout{}); + } + + auto make_b_descriptor() const + { + return make_descriptor(Extent{100000}, PackedLeftLayout{}); + } +}; + +template <> +struct Outputs +{ + void* a; + void* b; +}; + +template <> +ValidationReport validate(const Args& args, + Outputs actual, + Outputs expected) +{ + ValidationReport report; + report.check("a", args.make_a_descriptor(), actual.a, expected.a); + report.check("b", args.make_b_descriptor(), actual.b, expected.b); + return report; +} + +} // namespace ck_tile::builder::test + +TEST(MatchesReference, Correct) +{ + const ckt::Args args; + + const auto a_desc = args.make_a_descriptor(); + const auto b_desc = args.make_b_descriptor(); + + auto a_actual = ckt::alloc_tensor_buffer(a_desc); + auto b_actual = ckt::alloc_tensor_buffer(b_desc); + ckt::clear_tensor_buffer(a_desc, a_actual.get(), 1); + ckt::clear_tensor_buffer(b_desc, b_actual.get(), 2); + const auto actual = ckt::Outputs{ + .a = a_actual.get(), + .b = b_actual.get(), + }; + + auto a_expected = ckt::alloc_tensor_buffer(a_desc); + auto b_expected = ckt::alloc_tensor_buffer(b_desc); + ckt::clear_tensor_buffer(a_desc, a_expected.get(), 1); + ckt::clear_tensor_buffer(b_desc, b_expected.get(), 2); + const auto expected = ckt::Outputs{ + .a = a_expected.get(), + .b = b_expected.get(), + }; + + EXPECT_THAT(actual, MatchesReference(args, expected)); +} + +TEST(MatchesReference, Incorrect) +{ + const ckt::Args args; + + const auto a_desc = args.make_a_descriptor(); + const auto b_desc = args.make_b_descriptor(); + + auto a_actual = ckt::alloc_tensor_buffer(a_desc); + auto b_actual = ckt::alloc_tensor_buffer(b_desc); + ckt::clear_tensor_buffer(a_desc, a_actual.get(), 1); + ckt::clear_tensor_buffer(b_desc, b_actual.get(), 2); + const auto actual = 
ckt::Outputs{ + .a = a_actual.get(), + .b = b_actual.get(), + }; + + auto a_expected = ckt::alloc_tensor_buffer(a_desc); + auto b_expected = ckt::alloc_tensor_buffer(b_desc); + ckt::clear_tensor_buffer(a_desc, a_expected.get(), 2); + ckt::clear_tensor_buffer(b_desc, b_expected.get(), 2); + const auto expected = ckt::Outputs{ + .a = a_expected.get(), + .b = b_expected.get(), + }; + + testing::StringMatchResultListener listener; + EXPECT_TRUE(!ExplainMatchResult(MatchesReference(args, expected), actual, &listener)); + + EXPECT_THAT(listener.str(), StringEqWithDiff("1 tensors failed to validate")); +} diff --git a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp index ff47817227..bad29a65c0 100644 --- a/experimental/builder/test/utils/conv_algorithm_type_utils.hpp +++ b/experimental/builder/test/utils/conv_algorithm_type_utils.hpp @@ -399,7 +399,7 @@ inline std::string to_string(t)) << "," << to_string(static_cast(t)) - << "," << to_string(static_cast>(t)); + << "," << to_string(static_cast>(t)); return oss.str(); } diff --git a/include/ck/library/utility/device_tensor_generator.hpp b/include/ck/library/utility/device_tensor_generator.hpp index 4da38bf399..60bc3110d4 100644 --- a/include/ck/library/utility/device_tensor_generator.hpp +++ b/include/ck/library/utility/device_tensor_generator.hpp @@ -7,7 +7,6 @@ #include "ck/utility/common_header.hpp" #include "ck/library/utility/device_tensor_generator.hpp" #include "ck/utility/data_type.hpp" -#include // use xorshift for now since it is simple. 
Should be suitable enough, but feel free to switch in // the future @@ -107,6 +106,7 @@ template __global__ void fill_tensor_norm_rand_fp_values(T* p, float sigma, float mean, uint64_t buffer_element_size) { + static constexpr float PI = 3.141592653f; // initial values ran_state_u32 s = ran_init(); float norm[2]; @@ -115,12 +115,11 @@ fill_tensor_norm_rand_fp_values(T* p, float sigma, float mean, uint64_t buffer_e { if(j % (2 / ck::packed_size_v) == 0) { - float u1 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); - float u2 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); - norm[0] = - sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::cos(2.0f * M_PI * u2) + mean; - norm[1] = - sigma * std::sqrt(-2.0f * ck::math::log(u1)) * std::sin(2.0f * M_PI * u2) + mean; + float u1 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); + float u2 = ran_gen_round_u32(s) * (1.0f / 4294967296.0f); + float scale = sigma * ck::math::sqrt(-2.0f * ck::math::log(u1)); + norm[0] = scale * ck::math::cos(2.0f * PI * u2) + mean; + norm[1] = scale * ck::math::sin(2.0f * PI * u2) + mean; } if constexpr(ck::is_same_v) diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index 35389bda37..057687985d 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -10,7 +10,8 @@ namespace ck { #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ - defined(__gfx1103__) || defined(__gfx11_generic__) + defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \ + defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) #define __gfx11__ #endif diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 562b246ac3..9f9770df1b 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -2376,12 +2376,23 @@ 
amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); #else - thread_buffer tmp = - amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); if constexpr(oob_conditional_check) - return src_thread_element_valid ? tmp : thread_buffer{numeric::zero()}; + { + if(src_thread_element_valid) + { + return amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + } + else + { + return thread_buffer{numeric::zero()}; + } + } else - return tmp; + { + return amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + } #endif } diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index a162195390..97e962f5a3 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -87,6 +87,7 @@ enum struct amdgcn_target_id GFX1150 = 0x1150, GFX1151 = 0x1151, GFX1152 = 0x1152, + GFX1153 = 0x1153, GFX11_GENERIC = 0x11FF, GFX1200 = 0x1200, GFX1201 = 0x1201, @@ -282,6 +283,7 @@ constexpr auto get_compiler_target() MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1150, GFX1150); MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1151, GFX1151); MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1152, GFX1152); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1153, GFX1153); MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX11_GENERIC, GFX11_GENERIC); MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1200, GFX1200); MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1201, GFX1201); @@ -348,6 +350,7 @@ CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target_id(char const* MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1150", GFX1150); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1151", GFX1151); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1152", GFX1152); + 
MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1153", GFX1153); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx11_generic", GFX11_GENERIC); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1200", GFX1200); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_TARGET_ID("gfx1201", GFX1201); @@ -603,6 +606,7 @@ CK_TILE_HOST_DEVICE constexpr auto get_compiler_target() MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1150, GFX1150); MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1151, GFX1151); MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1152, GFX1152); + MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX1153, GFX1153); MAP_COMPILER_STATE_TO_GFX11_TARGET(CK_TILE_ARCH_GFX11_GENERIC, GFX11_GENERIC); MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1200, GFX1200); MAP_COMPILER_STATE_TO_GFX12_TARGET(CK_TILE_ARCH_GFX1201, GFX1201); @@ -683,6 +687,7 @@ CK_TILE_HOST auto hip_device_prop_gcn_arch_name_to_amdgcn_target(char const* tes MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1150", GFX1150); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1151", GFX1151); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1152", GFX1152); + MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx1153", GFX1153); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX11_TARGET("gfx11_generic", GFX11_GENERIC); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx1200", GFX1200); MAP_HIP_DEVICE_PROP_GCN_ARCH_NAME_STRING_TO_GFX12_TARGET("gfx1201", GFX1201); @@ -1119,8 +1124,14 @@ CK_TILE_DEVICE static constexpr auto get_device_arch() { // FIXME(0): on all devices except gfx11 it returns gfx12_t // FIXME(1): during the host compilation pass it returns gfx12_t -#if defined(__gfx11__) +#if defined(__gfx103__) + return gfx103_t{}; +#elif defined(__gfx11__) return gfx11_t{}; +#elif defined(__gfx950__) + return gfx950_t{}; +#elif defined(__gfx9__) + return gfx9_t{}; #else return gfx12_t{}; #endif @@ -1141,26 +1152,10 @@ 
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; } CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx_invalid_t) { return 0; } -CK_TILE_DEVICE static constexpr auto arch_tag_dispatch() -{ -#if defined(__gfx103__) - return gfx103_t{}; -#elif defined(__gfx11__) - return gfx11_t{}; -#elif defined(__gfx12__) - return gfx12_t{}; -#elif defined(__gfx950__) - return gfx950_t{}; -#elif defined(__gfx9__) - return gfx9_t{}; -#else - return gfx_invalid_t{}; -#endif -} } // namespace detail CK_TILE_DEVICE static constexpr auto get_n_lds_banks() { - return detail::get_n_lds_banks(detail::arch_tag_dispatch()); + return detail::get_n_lds_banks(get_device_arch()); } enum LLVMSchedGroupMask : int32_t diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 7830749efb..fed9209bad 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -315,6 +315,7 @@ namespace ck_tile::core { * @var CK_TILE_ARCH_GFX1102 Indicates if the compiler target architecture is GFX1102. * @var CK_TILE_ARCH_GFX1151 Indicates if the compiler target architecture is GFX1151. * @var CK_TILE_ARCH_GFX1152 Indicates if the compiler target architecture is GFX1152. + * @var CK_TILE_ARCH_GFX1153 Indicates if the compiler target architecture is GFX1153. * @var CK_TILE_ARCH_GFX11_GENERIC Indicates if the compiler target architecture is GFX11 generic. * @var CK_TILE_ARCH_GFX1200 Indicates if the compiler target architecture is GFX1200. * @var CK_TILE_ARCH_GFX1201 Indicates if the compiler target architecture is GFX1201. 
@@ -468,6 +469,12 @@ struct amdgcn_compiler_target_state static constexpr bool CK_TILE_ARCH_GFX1152 = false; #endif // __gfx1152__ +#if defined(__gfx1153__) + static constexpr bool CK_TILE_ARCH_GFX1153 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1153 = false; +#endif // __gfx1153__ + #if defined(__gfx11_generic__) static constexpr bool CK_TILE_ARCH_GFX11_GENERIC = true; #else @@ -538,6 +545,7 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1150, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1151, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1152, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1153, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX11_GENERIC, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1200, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1201, \ diff --git a/include/ck_tile/core/tensor/transpose_tile.hpp b/include/ck_tile/core/tensor/transpose_tile.hpp index e5a0664ec9..50927c5ca4 100644 --- a/include/ck_tile/core/tensor/transpose_tile.hpp +++ b/include/ck_tile/core/tensor/transpose_tile.hpp @@ -34,46 +34,23 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor, constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor(); constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor(); - // y_dim_out_to_in - // For swapped Hs tile case I need only get_rh_minor_to_y - // since rh_major are already swapped due to swapped Hs. - constexpr auto get_rh_minor_to_y = [](auto dstr_tensor) { - using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode; - - map rh_minor_to_y_; - - static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) { - constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i]; - - rh_minor_to_y_(rh_minor) = i; - }); - - return rh_minor_to_y_; - }; - // In swapped Hs case -> tile // we have same rh_major, but reversed rh_minor! 
- constexpr auto rh_minor_to_y_in = get_rh_minor_to_y(InTensor{}); - constexpr auto rh_minor_to_y_out = get_rh_minor_to_y(OutTensor{}); + constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y(); - // Is this really needed?? Should we have simple reverse here?? constexpr auto y_dim_out_to_in = [&] { map y_dim_out_to_in_; - for(const auto& [rh_minor, y_out] : rh_minor_to_y_out) - { - y_dim_out_to_in_(y_out) = rh_minor_to_y_in[rh_minor]; - } + static_for<0, NDimY, 1>{}([&](auto i) { y_dim_out_to_in_(i) = NDimY - 1 - i; }); return y_dim_out_to_in_; }(); - constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y(); constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths()); // input and output vector dim in the order of input Y dims constexpr index_t y_dim_vec_in = NDimY - 1; - constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1]; + constexpr index_t y_dim_vec_out = 0; // vector lengths constexpr index_t vec_length_in = y_lengths[y_dim_vec_in]; diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 4bbf8cbf3f..bddc0ae2d2 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -55,9 +55,10 @@ struct FillUniformDistribution const auto total_bytes = total * sizeof(T_iter); // max 80 threads; at least 2MB per thread - const size_t available_cpu_cores = get_available_cpu_cores(); - const size_t num_thread = - min(80UL, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL)); + const size_t available_cpu_cores = get_available_cpu_cores(); + constexpr uint64_t MAX_THREAD_COUNT = 80; + const size_t num_thread = min( + MAX_THREAD_COUNT, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL)); constexpr size_t BLOCK_BYTES = 64; constexpr size_t BLOCK_SIZE = BLOCK_BYTES / sizeof(T_iter); const size_t num_blocks = integer_divide_ceil(total_bytes, BLOCK_BYTES); diff --git a/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp 
b/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp index e141d842dd..95ab1258d6 100644 --- a/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp +++ b/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp @@ -3,6 +3,7 @@ #pragma once +#include #include #include @@ -28,7 +29,7 @@ CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor& input, output.get_num_of_dimension() == NDimSpatial + 3)) { - printf("%lu %lu %lu", + printf("%" PRIu64 " %" PRIu64 " %" PRIu64, input.get_num_of_dimension(), weight.get_num_of_dimension(), output.get_num_of_dimension()); diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 53bfa6041d..97f936fde9 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -30,7 +30,6 @@ template struct CShuffleEpilogueProblem { - using AsDataType = remove_cvref_t; - using BsDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using DsDataType = remove_cvref_t; - using DsLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; - static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size(); - static constexpr index_t kMPerBlock = kM_; - static constexpr index_t kNPerBlock = kN_; - static constexpr index_t MWave = MWave_; - static constexpr index_t NWave = NWave_; - static constexpr index_t MPerXdl = MPerXdl_; - static constexpr index_t NPerXdl = NPerXdl_; - static constexpr index_t KPerXdl = KPerXdl_; - static constexpr index_t isCTransposed = isCTransposed_; - static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; - static constexpr bool FixedVectorSize = FixedVectorSize_; - static constexpr index_t VectorSizeC = VectorSizeC_; - static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_; - static constexpr bool DoubleSmemBuffer = 
DoubleSmemBuffer_; - static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_; - static constexpr index_t kNumWaveGroups = kNumWaveGroups_; - static constexpr index_t NumDTensor = DsDataType::size(); + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size(); + static constexpr index_t kMPerBlock = kM_; + static constexpr index_t kNPerBlock = kN_; + static constexpr index_t MWave = MWave_; + static constexpr index_t NWave = NWave_; + static constexpr index_t MPerXdl = MPerXdl_; + static constexpr index_t NPerXdl = NPerXdl_; + static constexpr index_t KPerXdl = KPerXdl_; + static constexpr index_t isCTransposed = isCTransposed_; + static constexpr bool FixedVectorSize = FixedVectorSize_; + static constexpr index_t VectorSizeC = VectorSizeC_; + static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_; + static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; + static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_; + static constexpr index_t kNumWaveGroups = kNumWaveGroups_; + static constexpr index_t NumDTensor = DsDataType::size(); static_assert(NumDTensor == DsLayout::size(), "The size of DsDataType and DsLayout should be the same"); @@ -105,28 +103,27 @@ struct CShuffleEpilogue ADataType, BDataType>; - using ELayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; - static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; - static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kMPerBlock = Problem::kMPerBlock; - static constexpr index_t kNPerBlock = Problem::kNPerBlock; - static constexpr index_t MWave = Problem::MWave; - static constexpr index_t NWave = Problem::NWave; - 
static constexpr index_t MPerXdl = Problem::MPerXdl; - static constexpr index_t NPerXdl = Problem::NPerXdl; - static constexpr index_t KPerXdl = Problem::KPerXdl; - static constexpr index_t isCTransposed = Problem::isCTransposed; - static constexpr bool FixedVectorSize = Problem::FixedVectorSize; - static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; - static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp; - static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - static constexpr index_t VectorSizeC = Problem::VectorSizeC; - static constexpr index_t MPerIteration = MPerXdl * MWave; - static constexpr index_t NPerIteration = NPerXdl * NWave; - static constexpr index_t NumDTensor = Problem::NumDTensor; - static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave); - static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave); + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kMPerBlock = Problem::kMPerBlock; + static constexpr index_t kNPerBlock = Problem::kNPerBlock; + static constexpr index_t MWave = Problem::MWave; + static constexpr index_t NWave = Problem::NWave; + static constexpr index_t MPerXdl = Problem::MPerXdl; + static constexpr index_t NPerXdl = Problem::NPerXdl; + static constexpr index_t KPerXdl = Problem::KPerXdl; + static constexpr index_t isCTransposed = Problem::isCTransposed; + static constexpr bool FixedVectorSize = Problem::FixedVectorSize; + static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; + static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp; + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t VectorSizeC = Problem::VectorSizeC; + static constexpr index_t MPerIteration = MPerXdl * MWave; + static constexpr index_t NPerIteration = NPerXdl * NWave; + static constexpr index_t NumDTensor = 
Problem::NumDTensor; + static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave); + static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave); CDElementwise elfunc_; @@ -142,8 +139,7 @@ struct CShuffleEpilogue concat('x', MWave, NWave), concat('x', MPerXdl, NPerXdl, KPerXdl), VectorSizeC, - isCTransposed ? "CTransposed" : "CNotTransposed", - mem_op_string()); + isCTransposed ? "CTransposed" : "CNotTransposed"); // clang-format on } @@ -337,14 +333,30 @@ struct CShuffleEpilogue { constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp; // BlockedLayout - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<1, 2, 2>, - sequence<0, 0, 2>>{}; + // this branch is for original a16w4 + if constexpr(is_any_of::value || + is_any_of::value) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}; + } + else + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2, 2>, + sequence<0, 0, 1>>{}; + } } }(); constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding( @@ -355,7 +367,8 @@ struct CShuffleEpilogue CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType); + constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); + return lds_block_desc.get_element_space_size() * sizeof(ODataType); } template @@ -445,7 +458,8 @@ struct CShuffleEpilogue CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window, const COutTensor& c_out_tensor) { - if constexpr(MemoryOperation == memory_operation_enum::set) + if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp == + memory_operation_enum::set) { store_tile(out_dram_window, c_out_tensor); } @@ -617,7 +631,8 @@ struct CShuffleEpilogue }); // store/update - if constexpr(MemoryOperation 
== memory_operation_enum::set) + if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp == + memory_operation_enum::set) { store_tile(out_dram_window, c_out_tensor); } diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index cc2303582e..aafe7b9f58 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -15,17 +15,15 @@ template + bool UseRawStore_ = true> struct Default2DEpilogueProblem { - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; - static constexpr bool UseRawStore = UseRawStore_; - static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; - static constexpr index_t NumDTensor = 0; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool UseRawStore = UseRawStore_; + static constexpr index_t NumDTensor = 0; }; template -struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem + bool UseRawStore_ = true> +struct DefaultGemm2DEpilogueProblem + : public Default2DEpilogueProblem { using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; @@ -81,7 +74,6 @@ struct Default2DEpilogue static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool UseRawStore = Problem::UseRawStore; - static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } @@ -102,7 +94,10 @@ struct Default2DEpilogue // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) { - if constexpr(MemoryOperation == memory_operation_enum::set) + // FIXME? 
+ // if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp == + // memory_operation_enum::set) + if constexpr(true) { if constexpr(is_partition_index) { @@ -123,7 +118,10 @@ struct Default2DEpilogue } else { - if constexpr(MemoryOperation == memory_operation_enum::set) + // FIXME? + // if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp == + // memory_operation_enum::set) + if constexpr(true) { if constexpr(is_partition_index) { diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 9a33801c8f..42dab68e91 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -558,21 +558,19 @@ struct FlatmmKernel return DTesnorIsValid; } - template - CK_TILE_DEVICE static auto - MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_flat_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + template + CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t block_idx_m) { + // Step 1: Create tensor view const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.M, k_size), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); @@ -581,25 +579,81 @@ struct FlatmmKernel { return make_naive_tensor_view( a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(k_size, kargs.M), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); } }(); - index_t kFlatK = - FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2)); - index_t kFlatN = kargs.N * kargs.K / kFlatK; - const auto& b_flat_tensor_view = [&]() { - return make_naive_tensor_view( - b_flat_ptr, - make_tuple(kFlatN, kFlatK), - 
make_tuple(kFlatK, 1), - number{}, - number<1>{}); + // Step 2: Create padded view + const auto& a_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } }(); + // Step 3: Create tile window + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {block_idx_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, block_idx_m}); + } + } + + template + CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr, + const KernelArgs& kargs, + const index_t block_idx_n) + { + // Step 1: Create tensor view + index_t kFlatK = + FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2)); + index_t kFlatN = kargs.N * kargs.K / kFlatK; + + const auto& b_flat_tensor_view = make_naive_tensor_view( + b_flat_ptr, + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + + // Step 2: No padding needed for b_flat + // Step 3: Create tile window + return make_tile_window( + b_flat_tensor_view, + make_tuple(number{}, + number{}), + {static_cast(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0}); + } + + template + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -625,7 +679,56 @@ struct FlatmmKernel }, number{}); - // TODO: enable vector write for C in ColMajor + // Step 2: Create padded views + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + 
number{}), + sequence{}); + } + else + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows + return generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_m, block_idx_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_n, block_idx_m}); + } + }, + number{}); + } + + template + CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor view const auto& e_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -647,98 +750,8 @@ struct FlatmmKernel } }(); - constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN; - constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN; - - constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK; - constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK; - - auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale - : 1; // per-token scale - auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale - : 1; // per-channel scale - - static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1, - "only support per-tensor or per-row scaling"); - static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1, - "only support per-tensor or per-column scaling"); - - const auto scale_m_view = make_naive_tensor_view( - kargs.scale_m_ptr.ptr, - make_tuple(kargs.M / ScaleGranularityM, - ScaleGranularityKA == 0 - ? 1 - : splitk_batch_offset.splitted_k / - (ScaleGranularityKA != 0 ? ScaleGranularityKA : 1)), - make_tuple(scale_stride_m, 0), - number < ScaleGranularityM == 1 ? 
FlatmmPipeline::GetVectorSizeA() : 1 > {}, - number<1>{}); - const auto scale_n_view = make_naive_tensor_view( - kargs.scale_n_ptr.ptr, - make_tuple(ScaleGranularityKB == 0 - ? 1 - : (splitk_batch_offset.splitted_k / - (ScaleGranularityKB != 0 ? ScaleGranularityKB : 1)), - kargs.N / ScaleGranularityN), - make_tuple(0, scale_stride_n), - number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {}, - number<1>{}); - - return make_tuple(a_tensor_view, - b_flat_tensor_view, - ds_tensor_view, - e_tensor_view, - scale_m_view, - scale_n_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - const auto& b_flat_tensor_view = views.at(I1); - - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor + // Step 2: Create padded view const auto& e_pad_view = [&]() { - const auto& e_tensor_view = views.at(I3); if constexpr(std::is_same_v) { return pad_tensor_view(e_tensor_view, @@ -755,93 +768,72 @@ struct FlatmmKernel } }(); - return make_tuple(a_pad_view, - b_flat_tensor_view, - ds_pad_view, - e_pad_view, - views.at(number<4>{}), - views.at(number<5>{})); - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& a_pad_view = 
views.at(I0); - const auto& b_flat_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - if constexpr(std::is_same_v) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } - }(); - - const auto& b_flat_block_window = - make_tile_window(b_flat_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), 0}); - - const auto ds_block_window = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_n, i_m}); - } - }, - number{}); - - auto e_block_window = make_tile_window( + // Step 3: Create tile window + return make_tile_window( e_pad_view, make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); + } - constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK; - constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK; + template + CK_TILE_DEVICE static auto MakeScaleMWindow(const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m) + { + constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN; + constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK; - auto scale_m_window = make_tile_window(views.at(number<4>{}), - make_tuple(number{}, - number < ScaleGranularityKA == 0 - ? TilePartitioner::NPerBlock - : TilePartitioner::KPerBlock > {}), - {i_m, 0}); - auto scale_n_window = make_tile_window(views.at(number<5>{}), - make_tuple(number < ScaleGranularityKB == 0 - ? 
TilePartitioner::MPerBlock - : TilePartitioner::KPerBlock > {}, - number{}), - {0, i_n}); + auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale + : 1; // per-token scale - return make_tuple(a_block_window, - b_flat_block_window, - ds_block_window, - e_block_window, - scale_m_window, - scale_n_window); + // Step 1: Create tensor view + const auto scale_m_view = make_naive_tensor_view( + kargs.scale_m_ptr.ptr, + make_tuple(kargs.M / ScaleGranularityM, + ScaleGranularityKA == 0 + ? 1 + : (splitk_batch_offset.splitted_k / ScaleGranularityKA)), + make_tuple(scale_stride_m, 0), + number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {}, + number<1>{}); + + // Step 2: Create tile window + return make_tile_window(scale_m_view, + make_tuple(number{}, + number < ScaleGranularityKA == 0 + ? TilePartitioner::NPerBlock + : TilePartitioner::KPerBlock > {}), + {block_idx_m, 0}); + } + + template + CK_TILE_DEVICE static auto MakeScaleNWindow(const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_n) + { + constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN; + constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK; + + auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale + : 1; // per-channel scale + + // Step 1: Create tensor view + const auto scale_n_view = make_naive_tensor_view( + kargs.scale_n_ptr.ptr, + make_tuple( + ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB), + kargs.N / ScaleGranularityN), + make_tuple(0, scale_stride_n), + number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {}, + number<1>{}); + + // Step 2: Create tile window + return make_tile_window(scale_n_view, + make_tuple(number < ScaleGranularityKB == 0 + ? 
TilePartitioner::MPerBlock + : TilePartitioner::KPerBlock > {}, + number{}), + {0, block_idx_n}); } template @@ -857,45 +849,74 @@ struct FlatmmKernel const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); + const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m); + const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n); const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); // Run GEMM cooperatively by whole workgroup. 
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_flat_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = FlatmmPipeline{}.template operator()( + const auto& c_block_tile = FlatmmPipeline{}.template operator()( a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong); - auto scale_m_window = gemm_tile_windows.at(number<4>{}); - auto scale_n_window = gemm_tile_windows.at(number<5>{}); - - // Run Epilogue Pipeline + // Run Epilogue Pipeline with k_batch dispatching if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1) { - auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}.template - operator()( - c_block_window, - c_block_tile, - d_block_window, - smem_ptr_ping, - scale_m_window, - scale_n_window); + if(kargs.k_batch == 1) + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{} + .template operator()(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + scale_m_window, + scale_n_window); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{} + .template operator()(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + scale_m_window, + scale_n_window); + } } else if(UseDefaultScheduler || (get_warp_id() == 0)) { - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_ping); + if(kargs.k_batch == 1) + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{} + .template operator()( + e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{} + .template operator()( + 
e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } } } @@ -924,8 +945,7 @@ struct FlatmmKernel __shared__ char smem_ptr_ping[GetSmemPingSize()]; __shared__ char smem_ptr_pong[GetSmemPongSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value)) { constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1); diff --git a/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp index 05d50666a5..61001522b0 100644 --- a/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp @@ -100,21 +100,19 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel - CK_TILE_DEVICE static auto - MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_flat_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + template + CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t block_idx_m) { + // Step 1: Create tensor view const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.M, k_size), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); @@ -123,25 +121,80 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel( a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(k_size, kargs.M), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); } }(); + // Step 2: Create padded view + const auto& a_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return 
pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {block_idx_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, block_idx_m}); + } + } + + template + CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr, + const KernelArgs& kargs, + const index_t block_idx_n) + { + // Step 1: Create tensor view index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1); index_t kFlatN = kargs.N * kargs.K / kFlatK; - const auto& b_flat_tensor_view = [&]() { - return make_naive_tensor_view( - b_flat_ptr, - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); - }(); + const auto& b_flat_tensor_view = make_naive_tensor_view( + b_flat_ptr, + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + // Step 2: No padding needed for b_flat + // Step 3: Create tile window + return make_tile_window( + b_flat_tensor_view, + make_tuple(number{}, + number{}), + {static_cast(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0}); + } + + template + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -167,7 +220,56 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel{}); - // TODO: enable vector write for C in ColMajor + // Step 2: Create padded views + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + 
number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows + return generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_m, block_idx_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_n, block_idx_m}); + } + }, + number{}); + } + + template + CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor view const auto& e_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -189,70 +291,8 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel( - reinterpret_cast(scale_n.ptr), - make_tuple(FlatScaleN, FlatScaleK), - make_tuple(FlatScaleK, 1), - number<8>{}, - number<1>{}); - - return make_tuple( - a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - const auto& b_flat_tensor_view = views.at(I1); - - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor + // Step 2: Create padded view const auto& e_pad_view = 
[&]() { - const auto& e_tensor_view = views.at(I3); if constexpr(std::is_same_v) { return pad_tensor_view(e_tensor_view, @@ -269,77 +309,37 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& a_pad_view = views.at(I0); - const auto& b_flat_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - if constexpr(std::is_same_v) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } - }(); - - const auto& b_flat_block_window = - make_tile_window(b_flat_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), 0}); - - const auto ds_block_window = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_n, i_m}); - } - }, - number{}); - - auto e_block_window = make_tile_window( + // Step 3: Create tile window + return make_tile_window( e_pad_view, make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); + } - auto scale_block_window = - make_tile_window(views.at(I4), - make_tuple(number{}, - number{}), - {i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0}); + template + CK_TILE_DEVICE static auto MakeScaleBBlockWindow(const KernelArgs& kargs, + const index_t block_idx_n) + { + auto scale_n = kargs.scale_n_ptr; - return make_tuple(a_block_window, - b_flat_block_window, - ds_block_window, - e_block_window, - scale_block_window); + // Step 1: Create tensor view + index_t FlatScaleK = + (kargs.K / 
decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1); + index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1); + + const auto scale_b_flat_view = make_naive_tensor_view( + reinterpret_cast(scale_n.ptr), + make_tuple(FlatScaleN, FlatScaleK), + make_tuple(FlatScaleK, 1), + number<8>{}, + number<1>{}); + + // Step 2: Create tile window + return make_tile_window( + scale_b_flat_view, + make_tuple(number{}, + number{}), + {block_idx_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0}); } template @@ -355,21 +355,15 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel( - a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); + const auto& scale_block_window = MakeScaleBBlockWindow(kargs, block_idx_n); const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); - // Run GEMM cooperatively by whole workgroup. 
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_flat_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& scale_block_window = gemm_tile_windows.at(I4); - static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK || ScaleM::GranularityMN == -1 // or ScaleA is disable || ScaleN::GranularityMN == -1, // or ScaleB is disable @@ -378,6 +372,7 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); + } } else if(UseDefaultScheduler || (get_warp_id() == 0)) { - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping); + if(kargs.k_batch == 1) + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } } } @@ -434,8 +453,7 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel::value)) { constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1); diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index b47ec4a829..604089b7c4 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ 
b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -1476,7 +1476,8 @@ struct MoeFlatmmKernel c_scatter_valids[mIter]); if constexpr(!IsInputGemm || - EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add) + decltype(c_block_window.get_bottom_tensor_view())::DstInMemOp == + memory_operation_enum::atomic_add) c_scatter_tile_window.update(c_out_tensor); else c_scatter_tile_window.store(c_out_tensor); diff --git a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp index 799f8f26a9..a58d71c790 100644 --- a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp @@ -113,32 +113,50 @@ struct MXFlatmmKernel : FlatmmKernel - CK_TILE_DEVICE static auto - MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_flat_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + template + CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t block_idx_m) { + // Step 1: Create tensor view const auto& a_tensor_view = [&]() { static_assert(std::is_same_v, "A tensor for mx must be RowMajor"); return make_naive_tensor_view( a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.M, k_size), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); }(); + // Step 2: Create padded view + const auto& a_pad_view = pad_tensor_view( + a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + return make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), + {block_idx_m, 0}); + } + + template + CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr, + const KernelArgs& kargs, + const index_t block_idx_n) + { + // Step 1: Create tensor view with special flat layout constexpr index_t kKPerBlock = 
MXFlatmmPipeline::kKPerBlock; constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1); constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile; const index_t kFlatKBlocks = kargs.K / kKPerBlock; const index_t kFlatN = kargs.N / kNWarpTile; - const auto& b_flat_tensor_view = [&]() { + + const auto& b_flat_tensor_view = [&]() { static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0, "wrong! vector size for B tensor"); auto&& naive_desc = make_naive_tensor_descriptor_packed( @@ -153,6 +171,22 @@ struct MXFlatmmKernel : FlatmmKernel(b_flat_ptr, desc); }(); + // Step 2: No padding for flat B + // Step 3: Create tile window + return make_tile_window( + b_flat_tensor_view, + make_tuple(number{}, + number{}), + {static_cast(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0}); + } + + template + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -178,7 +212,56 @@ struct MXFlatmmKernel : FlatmmKernel{}); - // TODO: enable vector write for C in ColMajor + // Step 2: Create padded views + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows + return generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_m, block_idx_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_n, block_idx_m}); + } + 
}, + number{}); + } + + template + CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor view const auto& e_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -200,92 +283,8 @@ struct MXFlatmmKernel : FlatmmKernel{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tensor_view( - reinterpret_cast(scale_a.ptr), scale_a_desc); - }(); - - // B scale tensor view - const auto& scale_b_tensor_view = [&]() { - const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed( - make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl)); - const auto scale_b_desc = transform_tensor_descriptor( - scale_b_navie_desc, - make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)), - make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), - make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tensor_view( - reinterpret_cast(scale_b.ptr), scale_b_desc); - }(); - - return make_tuple(a_tensor_view, - b_flat_tensor_view, - ds_tensor_view, - e_tensor_view, - scale_a_tensor_view, - scale_b_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - static_assert(std::is_same_v, - "A tensor for mx must be RowMajor"); - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& b_flat_tensor_view = views.at(I1); - - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - 
number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor + // Step 2: Create padded view const auto& e_pad_view = [&]() { - const auto& e_tensor_view = views.at(I3); if constexpr(std::is_same_v) { return pad_tensor_view(e_tensor_view, @@ -302,79 +301,71 @@ struct MXFlatmmKernel : FlatmmKernel - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& a_pad_view = views.at(I0); - const auto& b_flat_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - static_assert(std::is_same_v, - "A tensor for mx must be RowMajor"); - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - }(); - - const auto& b_flat_block_window = - make_tile_window(b_flat_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), 0}); - - const auto ds_block_window = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_n, i_m}); - } - }, - number{}); - - auto e_block_window = make_tile_window( + // Step 3: Create tile window + return make_tile_window( e_pad_view, make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); + } + template + CK_TILE_DEVICE static auto MakeScaleABlockWindow(const KernelArgs& kargs, + const index_t block_idx_m) + { static constexpr int BlockScaleSize = 32; - auto scale_a_block_window = make_tile_window( - views.at(I4), + const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl)); + const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl); + + // Step 1: Create tensor view + const auto 
scale_a_naive_desc = make_naive_tensor_descriptor_packed( + make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl)); + const auto scale_a_desc = transform_tensor_descriptor( + scale_a_naive_desc, + make_tuple(make_merge_transform(make_tuple(scale_packs_m, MThreadPerXdl)), + make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto& scale_a_tensor_view = make_tensor_view( + reinterpret_cast(kargs.scale_m_ptr.ptr), scale_a_desc); + + // Step 2: Create tile window + return make_tile_window( + scale_a_tensor_view, make_tuple(number{}, number{}), - {i_m / MXdlPack, 0}); + {block_idx_m / MXdlPack, 0}); + } - auto scale_b_block_window = make_tile_window( - views.at(I5), + template + CK_TILE_DEVICE static auto MakeScaleBBlockWindow(const KernelArgs& kargs, + const index_t block_idx_n) + { + static constexpr int BlockScaleSize = 32; + + const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl)); + const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl); + + // Step 1: Create tensor view + const auto scale_b_naive_desc = make_naive_tensor_descriptor_packed( + make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl)); + const auto scale_b_desc = transform_tensor_descriptor( + scale_b_naive_desc, + make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)), + make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto& scale_b_tensor_view = make_tensor_view( + reinterpret_cast(kargs.scale_n_ptr.ptr), scale_b_desc); + + // Step 2: Create tile window + return make_tile_window( + scale_b_tensor_view, make_tuple(number{}, number{}), - {i_n / NXdlPack, 0}); - - return make_tuple(a_block_window, - b_flat_block_window, - ds_block_window, - e_block_window, - 
scale_a_block_window, - scale_b_block_window); + {block_idx_n / NXdlPack, 0}); } template @@ -390,22 +381,16 @@ struct MXFlatmmKernel : FlatmmKernel( - a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); + const auto& scale_a_block_window = MakeScaleABlockWindow(kargs, block_idx_m); + const auto& scale_b_block_window = MakeScaleBBlockWindow(kargs, block_idx_n); const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); - // Run GEMM cooperatively by whole workgroup. 
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_flat_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& scale_a_block_window = gemm_tile_windows.at(I4); - const auto& scale_b_block_window = gemm_tile_windows.at(I5); - static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK || ScaleM::GranularityMN == -1 // or ScaleA is disable || ScaleN::GranularityMN == -1, // or ScaleB is disable @@ -422,22 +407,46 @@ struct MXFlatmmKernel : FlatmmKernel( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); + } } else if(UseDefaultScheduler || (get_warp_id() == 0)) { - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping); + if(kargs.k_batch == 1) + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } } } @@ -466,27 +475,17 @@ struct MXFlatmmKernel : FlatmmKernel::value)) - { - constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1); - RunFlatmm(a_ptr, - b_flat_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_ping, - smem_ptr_pong, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else - { - static_assert(false, - "Unimplemented: atomic_add with odd vector 
size for fp16/bf16"); - } + constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1); + RunFlatmm(a_ptr, + b_flat_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_ping, + smem_ptr_pong, + kargs, + splitk_batch_offset, + i_m, + i_n); partition_idx += gridDim.x; } while(UsePersistentKernel && partition_idx < total_work_tile_cnt); } diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 20714397c9..eb4aa16d05 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -3,6 +3,7 @@ #pragma once #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp new file mode 100644 index 0000000000..c79e639469 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp @@ -0,0 +1,32 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile { + +// KV cache memory layout selector. +// +// Layout summary (kVectorSize = 16 / sizeof(KDataType)): +// - VECTORIZED_LAYOUT (swizzled): +// K: [NumBlocks, NumHeads, HeadDim/kVectorSize, PageSize, kVectorSize] +// V: [NumBlocks, NumHeads, PageSize/kVectorSize, HeadDim, kVectorSize] +// - LINEAR_LAYOUT: +// K: [NumBlocks, PageSize, NumHeads, HeadDim] +// V: [NumBlocks, PageSize, NumHeads, HeadDim] +enum class BlockAttentionKVCacheMemoryLayoutEnum +{ + VECTORIZED_LAYOUT = 0, + LINEAR_LAYOUT = 1, +}; + +// KV cache lookup table layout selector. +// - VLLM_BLOCK_TABLE_2D: block_table[batch, max_blocks_per_seq] +// - SGLANG_PAGE_TABLE_1D: kv_page_indices[kv_indptr[b] ... 
kv_indptr[b+1]) +enum class BlockAttentionKVCacheLookupTableEnum +{ + VLLM_BLOCK_TABLE_2D = 0, + SGLANG_PAGE_TABLE_1D = 1, +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 73b6a329d1..9afd097eed 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" @@ -56,12 +57,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; + static constexpr auto kKVMemoryLayout = FmhaPipeline::Problem::kKVMemoryLayout; + static constexpr auto kKVLookupTable = FmhaPipeline::Problem::kKVLookupTable; + static constexpr index_t kPageBlockSize = FmhaPipeline::kPageBlockSize; + static constexpr index_t kVectorSize = FmhaPipeline::kVectorSize; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; - template // to avoid duplicated base class prblem, introduce an template // arg struct FmhaFwdEmptyKargs @@ -71,6 +75,26 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel // kargs use aggregate initializer, so no constructor will provided // use inheritance to minimize karg size // user need to use MakeKargs() function to create kargs. 
+ struct SglangPageTableKargs + { + const int32_t* kv_indptr; + const int32_t* kv_page_indices; + const int32_t* kv_last_page_lens; + }; + + struct VllmPageTableKargs + { + const int32_t* block_table_ptr; + ck_tile::index_t batch_stride_block_table; + const int32_t* seqlen_k_ptr; + }; + + using PageBlockTableKargs = + std::conditional_t; + struct FmhaFwdCommonKargs { const void* q_ptr; @@ -89,14 +113,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t nhead_ratio_qk; int32_t num_total_pages; - const int32_t* kv_indptr; - const int32_t* kv_page_indices; -#if 0 // we assume page_block_size=1 for now - const int32_t* kv_last_page_lens; ck_tile::index_t page_block_size; -#else - static constexpr ck_tile::index_t page_block_size = 1; -#endif + PageBlockTableKargs page_table; float scale_s; @@ -295,12 +313,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, - const void* kv_indptr, - const void* kv_page_indices, -#if 0 // we assume page_block_size=1 for now - const void* kv_last_page_lens, ck_tile::index_t page_block_size, -#endif + const PageBlockTableKargs& page_table, float scale_s, [[maybe_unused]] float scale_p, [[maybe_unused]] float scale_o, @@ -345,12 +359,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel num_head_q, nhead_ratio_qk, num_total_pages, - reinterpret_cast(kv_indptr), - reinterpret_cast(kv_page_indices), -#if 0 // we assume page_block_size=1 for now - reinterpret_cast(kv_last_page_lens), page_block_size, -#endif + page_table, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale_s * ck_tile::log2e_v<>), #else @@ -453,12 +463,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, - const void* kv_indptr, - const void* kv_page_indices, -#if 0 // we assume page_block_size=1 for now - const void* kv_last_page_lens, ck_tile::index_t page_block_size, -#endif + const 
PageBlockTableKargs& page_table, float scale_s, [[maybe_unused]] float scale_p, [[maybe_unused]] float scale_o, @@ -498,12 +504,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel num_head_q, nhead_ratio_qk, num_total_pages, - reinterpret_cast(kv_indptr), - reinterpret_cast(kv_page_indices), -#if 0 // we assume page_block_size=1 for now - reinterpret_cast(kv_last_page_lens), page_block_size, -#endif + page_table, #if CK_TILE_FMHA_FWD_FAST_EXP2 static_cast(scale_s * ck_tile::log2e_v<>), #else @@ -700,10 +702,46 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; - const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch]; -#if 0 // we assume page_block_size=1 for now - const int32_t last_page_len = kargs.kv_last_page_lens[i_batch]; -#endif + const index_t seqlen_k = [&]() { + if constexpr(kKVLookupTable == + BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D) + { + const int32_t page_start = kargs.page_table.kv_indptr[i_batch]; + const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1]; + const int32_t num_page_blocks = page_end - page_start; + const int32_t last_page_len = [&]() { + if constexpr(kPageBlockSize == 1) + return static_cast(kPageBlockSize); + else + return kargs.page_table.kv_last_page_lens[i_batch]; + }(); + return num_page_blocks > 0 + ? 
static_cast((num_page_blocks - 1) * kargs.page_block_size + + last_page_len) + : 0; + } + else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D + { + if(kargs.page_table.seqlen_k_ptr != nullptr) + return static_cast(kargs.page_table.seqlen_k_ptr[i_batch]); + else + return kargs.seqlen_k; + } + }(); + const int32_t* page_idx = [&]() { + if constexpr(kKVLookupTable == + BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D) + { + return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch]; + } + else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D + { + return kargs.page_table.block_table_ptr + + static_cast(i_batch) * + kargs.page_table.batch_stride_block_table; + } + }(); + if constexpr(kIsGroupMode) { // get starting offset for each batch @@ -711,8 +749,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel batch_offset_q = query_start * kargs.stride_q; - kargs.kv_page_indices += kargs.kv_indptr[i_batch]; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = query_start * kargs.stride_bias; @@ -737,18 +773,12 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel return; } -#if 0 // we assume page_block_size=1 for now - kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len; -#else - kargs.seqlen_k = num_page_blocks; -#endif + kargs.seqlen_k = seqlen_k; } else { batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - kargs.kv_page_indices += kargs.kv_indptr[i_batch]; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; @@ -764,11 +794,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; -#if 0 // we assume page_block_size=1 for now - kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len; -#else - kargs.seqlen_k = num_page_blocks; -#endif + kargs.seqlen_k = seqlen_k; } // for simplicity, 
batch stride we just modify the pointer @@ -809,60 +835,137 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel } }(); const auto k_dram = [&]() { - const auto k_dram_naive = make_naive_tensor_view( - k_ptr, - make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q), - make_tuple(kargs.stride_k, 1), - number{}, - number<1>{}); - - constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; - return pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - const auto v_dram = [&]() { - if constexpr(std::is_same_v) + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) { - const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v), - make_tuple(kargs.stride_v, 1), - number{}, + // Vectorized K Layout: [NumPages, D/kVectorSize, S, kVectorSize] + // Logical View for Pipeline: (TotalSeqK, D) + + // Define the naive physical view with 4D shape: (NumPages, HeadDim/kVectorSize, + // PageBlockSize, kVectorSize) + // Strides: (BatchStride, PageBlockSize*kVectorSize, kVectorSize, 1) + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.num_total_pages, + kargs.hdim_q / kVectorSize, + kargs.page_block_size, + kVectorSize), + make_tuple( + kargs.batch_stride_k, kargs.page_block_size * kVectorSize, kVectorSize, 1), + number{}, number<1>{}); - const auto v_dram_transposed = transform_tensor_view( - v_dram_naive, - make_tuple( - make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)), - make_tuple(sequence<1>{}, sequence<0>{}), + // Merge to (TotalSeqK, D) in a single transform: + // physical (Page, D/vec, S, vec) -> logical (TotalSeqK, D) + auto k_dram_2d = transform_tensor_view( + k_dram_naive, + make_tuple(make_merge_transform(make_tuple(kargs.num_total_pages, + kargs.page_block_size)), // TotalSeqK + make_merge_transform( + 
make_tuple(static_cast(kargs.hdim_q / kVectorSize), + static_cast(kVectorSize)))), // D + make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; return pad_tensor_view( - v_dram_transposed, + k_dram_2d, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + // Linear K Layout: [NumPages, PageSize, NumHeads, HeadDim] + // Logical View for Pipeline: (TotalSeqK, D) + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_q), + make_tuple(kargs.batch_stride_k, kargs.stride_k, 1), + number{}, + number<1>{}); + + // Merge to (TotalSeqK, D) in a single transform: + // physical (Page, S, D) -> logical (TotalSeqK, D) + auto k_dram_2d = transform_tensor_view( + k_dram_naive, + make_tuple(make_merge_transform( + make_tuple(kargs.num_total_pages, kargs.page_block_size)), + make_pass_through_transform(kargs.hdim_q)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? 
kPadSeqLenK : true; + return pad_tensor_view( + k_dram_2d, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto v_dram = [&]() { + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // Vectorized V Layout: [NumPages, S/kVectorSize, D, kVectorSize] + // Logical View for Pipeline: (D, TotalSeqK) - Transposed for GEMM + + // Define the naive physical view with 4D shape: (NumPages, + // PageBlockSize/kVectorSize, HeadDim, kVectorSize) + // Strides: (BatchStride, HeadDim*kVectorSize, kVectorSize, 1) + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.num_total_pages, + kargs.page_block_size / kVectorSize, + kargs.hdim_v, + kVectorSize), + make_tuple(kargs.batch_stride_v, kargs.hdim_v * kVectorSize, kVectorSize, 1), + number{}, + number<1>{}); + + // Merge to (D, TotalSeqK) in a single transform: + // physical (Page, S/vec, D, vec) -> logical (D, TotalSeqK) + auto v_dram_final = transform_tensor_view( + v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), // D + make_merge_transform(make_tuple(kargs.num_total_pages, + kargs.page_block_size / kVectorSize, + kVectorSize))), // TotalSeqK + make_tuple(sequence<2>{}, sequence<0, 1, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; + return pad_tensor_view( + v_dram_final, make_tuple(number{}, number{}), sequence{}); } else { + // Linear V Layout: [NumPages, PageSize, NumHeads, HeadDim] + // Logical View for Pipeline: (D, TotalSeqK) const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size), - make_tuple(kargs.stride_v, 1), + make_tuple(kargs.num_total_pages, kargs.page_block_size, kargs.hdim_v), + make_tuple(kargs.batch_stride_v, kargs.stride_v, 1), number{}, number<1>{}); - constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? 
kPadHeadDimV : false; - return pad_tensor_view( + // Merge to (D, TotalSeqK) in a single transform: + // physical (Page, S, D) -> logical (D, TotalSeqK) + auto v_dram_final = transform_tensor_view( v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_merge_transform( + make_tuple(kargs.num_total_pages, kargs.page_block_size))), + make_tuple(sequence<2>{}, sequence<0, 1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; + return pad_tensor_view( + v_dram_final, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }(); - auto q_dram_window = make_tile_window( q_dram, [&]() { @@ -1070,6 +1173,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + const index_t stride_k_for_pipeline = + kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT + ? kVectorSize + : kargs.stride_k; + const index_t stride_v_for_pipeline = + kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT + ? 
kargs.hdim_v + : kargs.stride_v; + auto o_acc_tile = [&] { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR) { @@ -1108,9 +1220,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel variant_params, block_indices, smem_ptr, - kargs.kv_page_indices, - kargs.stride_k, - kargs.stride_v, + page_idx, + stride_k_for_pipeline, + stride_v_for_pipeline, + kargs.batch_stride_k, + kargs.batch_stride_v, dropout); } else @@ -1128,9 +1242,11 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel variant_params, block_indices, smem_ptr, - kargs.kv_page_indices, - kargs.stride_k, - kargs.stride_v, + page_idx, + stride_k_for_pipeline, + stride_v_for_pipeline, + kargs.batch_stride_k, + kargs.batch_stride_v, dropout); } }(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 2102fe768f..0b47441995 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -6,12 +6,82 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { +template +CK_TILE_HOST_DEVICE void kv_offset_array_transform(const index_t* page_vec, + const index_t& stride_kv, + const index_t& page_stride_kv, + const CoordVecType& coord_vec, + OffsetVecType& kv_offset_vec, + index_t global_seq_offset = 0) +{ + const index_t& thread_coord_start = coord_vec[kCoordAxis]; + constexpr index_t kInPageOffsetMask = (1 
<< kLog2PageSize) - 1; + if constexpr(kIsKcache) + { + // for k offsets + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t global_token_idx = + global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; + const index_t page_id = global_token_idx >> kLog2PageSize; + const index_t page_offset = global_token_idx & kInPageOffsetMask; + kv_offset_vec[k0] = static_cast(page_vec[page_id]) * page_stride_kv + + static_cast(page_offset) * stride_kv; + }); + } + else + { + // for v offsets + const index_t lane0_start = __builtin_amdgcn_readfirstlane(thread_coord_start); + const index_t lane0_page_id = + (global_seq_offset + lane0_start + kLoopStart) >> kLog2PageSize; + + const long_index_t page_loc = + static_cast(page_vec[lane0_page_id]) * page_stride_kv; + + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t page_offset = + (global_seq_offset + thread_coord_start + kLoopStart + k0.value) & + kInPageOffsetMask; + + if constexpr(kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) + { + // Vectorized layout offset + // Layout: [BlockSize/kVectorSize, HeadDim, kVectorSize] + // Offset(s) = (s / kVectorSize) * (HeadDim * kVectorSize) + (s % kVectorSize) + const index_t s = page_offset; + const index_t D = stride_kv; + + const long_index_t s_offset = + static_cast((s / kVectorSize) * (D * kVectorSize)) + + (s % kVectorSize); + + kv_offset_vec[k0] = page_loc + s_offset; + } + else // BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT + { + kv_offset_vec[k0] = page_loc + static_cast(page_offset) * stride_kv; + } + }); + } +} // a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) template {}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - static constexpr auto I3 = number<3>{}; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; 
+ static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static constexpr index_t kPageBlockSize = Problem::kPageBlockSize; + static constexpr index_t kLog2PageSize = Problem::kLog2PageSize; + static constexpr index_t kVectorSize = Problem::kVectorSize; + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + static_assert(kPageBlockSize % kN0 == 0, + "V offset assumes each tile stays within a page; kPageBlockSize must be " + "divisible by kN0."); static constexpr bool kIsGroupMode = Problem::kIsGroupMode; // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) @@ -68,6 +144,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -196,6 +273,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t* page_idx, const index_t stride_k, const index_t stride_v, + const index_t page_stride_k, + const index_t page_stride_v, DropoutType& dropout) const { static_assert( @@ -325,9 +404,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync using KDstrEncode = typename decltype(k_dist)::DstrEncode; constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0]; statically_indexed_array k_offsets; - static_for<0, NRepeat, 1>{}([&](auto n0) { - k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * 
n0.value] * stride_k; - }); + index_t current_seq_k = seqlen_k_start; + kv_offset_array_transform, + decltype(k_coord), + 0, + kPageBlockSize, + kLog2PageSize, + 0, + NRepeat, + kN0 / NRepeat, + kKVMemoryLayout, + true, + kVectorSize>( + page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); + auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(), k_dram_block_window.get_window_lengths(), k_dram_block_window.get_window_origin(), @@ -360,10 +450,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync using VDstrEncode = typename decltype(v_dist)::DstrEncode; constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3]; statically_indexed_array v_offsets; - (void)stride_k; - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v; - }); + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kLog2PageSize, + 0, + V_KRepeat, + 1, + kKVMemoryLayout, + false, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); auto v_dram_window = make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(), @@ -425,13 +523,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync async_load_fence(); __builtin_amdgcn_s_barrier(); - const auto bias_tile = load_tile(bias_dram_window); // load bias tile - auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v; - }); - v_dram_window.update_page_idx(v_offsets); - __builtin_amdgcn_sched_barrier(0); { // tail gemm_0( @@ -444,49 +535,67 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } __builtin_amdgcn_sched_barrier(1); - // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - 
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); - tile_elementwise_inout( - [&](auto& x, const auto& y) { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - x += type_convert(bias_element_func(y)); -#else - x += log2e_v * - type_convert(bias_element_func(y)); -#endif - }, - s_acc, - bias_tile); - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - const auto k_origin = k_dram_block_window.get_window_origin(); - constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kLog2PageSize, + kK1, + V_KRepeat, + 1, + kKVMemoryLayout, + false, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + v_dram_window.update_page_idx(v_offsets); - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); + const auto p = [&]() { + const auto bias_tile = load_tile(bias_dram_window); // load bias tile - s_acc(i_j_idx) *= scale_s; - position_encoding.update(s_acc(i_j_idx), row, col); - }); - }); - } - else - { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - if constexpr(kHasLogitsSoftCap) + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - auto apply_logits_transform = - [&variant, &variant_params, &block_indices](auto& x) { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + 
tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = [&variant, &variant_params, &block_indices]( + auto& x) { x = variant.LogitsTransform(variant_params, variant.QueryTransform(variant_params, x), block_indices.batch_idx, @@ -494,216 +603,229 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync block_indices.kv_head_idx); }; #if !CK_TILE_FMHA_FWD_FAST_EXP2 - for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) - { - apply_logits_transform(s_acc.thread_buf_[i]); - } + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } #else - for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) - { + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { #if(defined(__gfx90a__) || defined(__gfx94__)) && \ (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \ 
CK_TILE_ATTENTION_USE_SOFTSIGN_ASM) - // Avoid data hazard if v_mfma is followed by inline asm consumer - // instructions. In this case, compiler won't add s_nop for us - if(i == s_acc.thread_buf_.size() / 2) - { - __builtin_amdgcn_sched_barrier(0); + // Avoid data hazard if v_mfma is followed by inline asm consumer + // instructions. In this case, compiler won't add s_nop for us + if(i == s_acc.thread_buf_.size() / 2) + { + __builtin_amdgcn_sched_barrier(0); + } +#endif + apply_logits_transform(s_acc.thread_buf_[i]); } #endif - apply_logits_transform(s_acc.thread_buf_[i]); - } -#endif - } - else - { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); -#endif - } - } - move_tile_window(bias_dram_window, {0, kN0}); - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) - { - const auto k_origin = k_dram_block_window.get_window_origin(); - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}), - number{}, - number{}); - - if(need_perpixel_check) - { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !variant.LogitsMask(variant_params, - block_indices.batch_idx, - row, - col, - block_indices.qo_head_idx, - block_indices.kv_head_idx); - }); - } - } - - const auto s = cast_tile(s_acc); // S{j} - auto m_local = block_tile_reduce( - s, - sequence<1>{}, - f_max, - -numeric::infinity()); // m_local = rowmax(S{j}) - block_tile_reduce_sync(m_local, f_max, bool_constant{}); - - const auto m_old = m; // m{j-1} - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} - - auto p_compute = make_static_distributed_tensor( - s.get_tile_distribution()); // Pcompute{j} - - __builtin_amdgcn_sched_barrier(0x7F); - // store & prefetch next v, after the max reduction - if 
constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_buf); - - auto v_lds_window_tmp = - get_slice_tile(v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - - store_tile( - v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch - } - else - { - auto v_lds_window_tmp = - get_slice_tile(v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_buf)); // store the prefetch - } - - if constexpr(k1_loops > 1) - { - move_tile_window( - v_dram_window, - {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... - v_buf = load_tile( - v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = - page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v; - }); - v_dram_window.update_page_idx(v_offsets); - } - __builtin_amdgcn_sched_barrier(0); - - static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration. alibi does not have this problem - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return raw_m == -numeric::infinity() - ? 
type_convert(0.f) - : raw_m; - } - else - { - return raw_m; - } - }; - - constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); -#endif - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { - if constexpr(kHasLogitsSoftCap) +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout([](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, + m, + m_old, + m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // 
Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, + kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kLog2PageSize, + 2 * kK1, + V_KRepeat, + 1, + kKVMemoryLayout, + false, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); + v_dram_window.update_page_idx(v_offsets); + } + __builtin_amdgcn_sched_barrier(0); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration. alibi does not have this problem + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? 
type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } else { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } } - } #else - p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); #endif + }); }); - }); - auto rowsum_p = block_tile_reduce( - p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); - // l{j}, Oacc{j} - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == 
BlockAttentionBiasEnum::ALIBI) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } - } - }(); + }(); #else - const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); #endif - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - // FIXME: this use different equation from FA v2 paper, - // but produce correc result. - // Is the equation wrong? - o_acc(i_j_idx) *= tmp; + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? 
+ o_acc(i_j_idx) *= tmp; + }); }); - }); - if constexpr(kHasDropout) - { - auto randval_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - dropout.template Run( - randval_ptr, - seqlen_k_start + i_total_loops * kN0, - p_compute, - randval_dram_window); - } + if constexpr(kHasDropout) + { + auto randval_ptr = reinterpret_cast(smem_ptr) + + Policy::template GetSmemSizeKV(); + dropout + .template Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } - const auto p = [&]() { #if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN // For fp32 to fp16, // impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue, @@ -727,11 +849,18 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { v_buf = load_tile( v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf - static_for<0, V_KRepeat, 1>{}([&](auto k0) { - v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 + - v_coord[VPageIndexDim] + k0.value] * - stride_v; - }); + kv_offset_array_transform, + decltype(v_coord), + VPageIndexDim, + kPageBlockSize, + kLog2PageSize, + (2 + i_k1.value) * kK1, + V_KRepeat, + 1, + kKVMemoryLayout, + false, + kVectorSize>( + page_idx, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); v_dram_window.update_page_idx(v_offsets); } block_sync_lds(); @@ -772,14 +901,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync i_total_loops++; if(i_total_loops < num_total_loop) { - page_idx += kN0; + current_seq_k += kN0; // move K tile windows move_tile_window(k_dram_block_window, {kN0, 0}); k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); - static_for<0, NRepeat, 1>{}([&](auto n0) { - k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k; - }); + kv_offset_array_transform, + decltype(k_coord), + 0, + kPageBlockSize, + kLog2PageSize, + 0, + NRepeat, + kN0 / NRepeat, + kKVMemoryLayout, + true, + kVectorSize>( + page_idx, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); 
k_dram_window.update_page_idx(k_offsets); if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) @@ -887,6 +1025,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const index_t* page_idx, const index_t stride_k, const index_t stride_v, + const index_t page_stride_k, + const index_t page_stride_v, DropoutType& dropout) const { return operator()(q_dram_block_window_tmp, @@ -913,6 +1053,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync page_idx, stride_k, stride_v, + page_stride_k, + page_stride_v, dropout); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index a192e3f7b0..f9dc94bc65 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" namespace ck_tile { @@ -65,6 +66,71 @@ struct BlockFmhaPipelineProblem static constexpr bool kHasSink = Traits::kHasSink; }; +template +struct BlockFmhaBatchPrefillPipelineProblem + : public BlockFmhaPipelineProblem +{ + static constexpr index_t kPageBlockSize = kPageBlockSize_; + static_assert(kPageBlockSize > 0, "kPageBlockSize must be positive"); + static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0, + "kPageBlockSize must be power of two"); + static constexpr index_t kLog2PageSize = []() constexpr { + index_t shift = 0; + index_t val = kPageBlockSize_; + while(val > 1) + { + val >>= 1; + shift++; + } + return shift; + }(); + + static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4 + static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout; + static constexpr auto kKVLookupTable = Traits_::kKVLookupTable; + static constexpr bool kIsVectorizedLayout = + kKVMemoryLayout == 
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT; + + static_assert(BlockFmhaShape_::kQKHeaddim % kVectorSize == 0, + "kQKHeaddim must be divisible by kVectorSize"); + static_assert(!kIsVectorizedLayout || kPageBlockSize % kVectorSize == 0, + "kPageBlockSize must be divisible by kVectorSize for vectorized layout"); + static_assert(kIsGroupMode_, "Batch prefill requires group mode"); +}; + template +struct TileFmhaBatchPrefillTraits : public TileFmhaTraits +{ + static constexpr auto kKVMemoryLayout = kKVMemoryLayout_; + static constexpr auto kKVLookupTable = kKVLookupTable_; + static constexpr index_t kPageBlockSize = kPageBlockSize_; + static_assert(kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT || + kKVMemoryLayout == BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT, + "Batch prefill only supports vectorized or linear KV cache layout."); + static_assert(kPageBlockSize > 0 && ((kPageBlockSize & (kPageBlockSize - 1)) == 0), + "kPageBlockSize should be a power of 2 to support efficient page-based KV cache " + "addressing."); +}; + template +struct BlockWeightPreshuffleASmemBRegCReg +{ + using Problem = remove_cvref_t; + using BlockPolicy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr auto config = 
BlockPolicy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM; + static constexpr index_t KPerBlockPerIter = WarpGemm::kK; + + static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read + + static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) + ? DsReadPreload + : MIterPerWarp * KIterPerWarp; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + statically_indexed_array preloaded_a_warp_tensor; + + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() + { + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence<1>>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + + template + CK_TILE_DEVICE auto MakeALoadWindows(SmemBlockWindow& a_block_window) const + { + constexpr auto a_load_dstr = make_static_tile_distribution(MakeABlockDistributionEncode()); + + // create MIterPerWarp × KIterPerWarp window + return generate_tuple( + [&](auto kIter) { + return generate_tuple( + [&](auto mIter) { + return make_tile_window( + get_slice_tile( + a_block_window, + sequence{}, + sequence<(mIter + 1) * MPerBlockPerIter, + (kIter + 1) * KPerBlockPerIter>{}), + a_load_dstr); + }, + number{}); + }, + number{}); + } + + template + CK_TILE_DEVICE void LocalPrefetch(const ALoadWindows& a_load_windows) + { + + static_for<0, m_preload, 1>{}([&](auto 
loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + + load_tile(preloaded_a_warp_tensor(loadIter), + a_load_windows[number{}][number{}]); + }); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ALoadWindows& a_load_windows, + BFlatBlockTensor& b_block_tensor, + const BFlatDistribution&) + { + constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? 
MIterPerWarp - 2 : MIterPerWarp - 1; + + using CWarpDstr = typename WarpGemm::CWarpDstr; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + using BWarpTensor = typename WarpGemm::BWarpTensor; + + constexpr auto b_block_y_lengths = + to_sequence(BFlatDistribution{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto b_block_y_index_zeros = + uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + BWarpTensor b_warp_tensor; + CWarpTensor c_warp_tensor; + + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, + typename sequence_split::right_type{}), + merge_sequences( + sequence<1, 1>{}, + typename sequence_split::right_type{})); + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WarpGemm{}( + c_warp_tensor, preloaded_a_warp_tensor(number{}), b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + + 
load_tile(preloaded_a_warp_tensor(number{}), + a_load_windows[number{}][number{}]); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 95114e8496..3f028ead2b 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -303,24 +303,15 @@ struct GroupedGemmKernel CDataType* c_ptr = static_cast(kargs.e_ptr); // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; // TO DO: // Can we simplify this branching logic? if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - RunGemmWithPipelineSelection2LDS(a_ptr, - b_ptr, - c_ptr, - kargs.ds_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); + RunGemmWithPipelineSelection2LDS( + a_ptr, b_ptr, c_ptr, kargs.ds_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } else // SingleSmemBuffer { @@ -331,7 +322,7 @@ struct GroupedGemmKernel b_ptr, kargs.ds_ptr, c_ptr, - smem_ptr_0, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -343,7 +334,7 @@ struct GroupedGemmKernel {b_ptr}, kargs.ds_ptr, c_ptr, - smem_ptr_0, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -361,6 +352,7 @@ struct GroupedGemmKernel * * @param a_ptr input A pointer * @param b_ptr input B pointer + * @param ds_ptr input Ds pointer * @param c_ptr output C pointer * @param smem_ptr_0 The start memory pointer of the shared memory block. 
* @param kargs GEMM kernel arguments @@ -381,49 +373,52 @@ struct GroupedGemmKernel const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k); + // Create block windows using specialized methods + const auto& a_block_window = + Base::MakeABlockWindows({a_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_m) + .at(Base::I0); + const auto& b_block_window = + Base::MakeBBlockWindows({b_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_n) + .at(Base::I0); + const auto& d_block_window = + Base::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I1); - const auto& d_block_window = gemm_tile_windows.at(Base::I2); - - // Get hot-loop and tail configuration const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); - const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - // Run GEMM pipeline + // Run GEMM cooperatively by whole workgroup. 
const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); + a_block_window, b_block_window, num_loop, smem_ptr_0); + // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + if(kargs.k_batch == 1) + { + auto c_block_window = Base::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { + auto c_block_window = + Base::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } /** * @brief Runs single GEMM problem cooperatively by whole workgroup. * - * @note The GEMM pipeline is selected in-kernel based on the number of K-loops - * and the tail-number. This is needed for the persistent tile-loop when - * we didn't have access to the K dimension on the host. + * @note RunGEMM2LDS with two shared memory buffers using the ping pong buffer mechanism. * * @param a_ptr input A pointer * @param b_ptr input B pointer * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. - * @param smem_ptr_1 The second start memory pointer of the shared memory block. + * @param smem_ptr The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments - * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k - * batch. + * @param splitk_batch_offset Utility structure used to calculate k batch. * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. 
* @@ -433,61 +428,45 @@ struct GroupedGemmKernel const BDataType* b_ptr, CDataType* c_ptr, const std::array& ds_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, + void* __restrict__ smem_ptr, const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k); + // Create block windows using specialized methods + const auto& a_block_window = + Base::MakeABlockWindows({a_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_m) + .at(Base::I0); + const auto& b_block_window = + Base::MakeBBlockWindows({b_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_n) + .at(Base::I0); + const auto& d_block_window = + Base::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I1); - const auto& d_block_window = gemm_tile_windows.at(Base::I2); - - // Get hot-loop and tail configuration const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - // Run GEMM pipeline with compile-time branching - const auto& c_block_tile = [&]() { - if constexpr(GemmPipeline::Preshuffle) - { - // Preshuffle version - without has_hot_loop parameter - return GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); - } - else - { - // Regular version - 
with has_hot_loop parameter - const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); - return GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); - } - }(); + // Run GEMM cooperatively by whole workgroup. + const auto& c_block_tile = + GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr); // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + if(kargs.k_batch == 1) + { + auto c_block_window = Base::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr); + } + else + { + auto c_block_window = + Base::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr); + } } CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr, diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp index d1fd32dc1b..47e59c4704 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp @@ -222,19 +222,13 @@ struct StreamKKernel const index_t block_idx_n, const index_t k_size) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - UniversalGemmKernel::template MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size); - - const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - - // Run GEMM cooperatively by 
whole workgroup. - const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0); - const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1); - const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2); + // Create block windows using specialized methods + const auto& as_block_window = + UniversalGemmKernel::MakeABlockWindows(as_ptr, kargs, k_size, block_idx_m); + const auto& bs_block_window = + UniversalGemmKernel::MakeBBlockWindows(bs_ptr, kargs, k_size, block_idx_n); + const auto& ds_block_window = + UniversalGemmKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute // has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this @@ -243,6 +237,7 @@ struct StreamKKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + // Run GEMM cooperatively by whole workgroup. 
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0], bs_block_window[UniversalGemmKernel::I0], num_loop, @@ -253,7 +248,9 @@ struct StreamKKernel if(UseDefaultScheduler || (get_warp_id() == 0)) { // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + auto c_block_window = + UniversalGemmKernel::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); } @@ -525,21 +522,13 @@ struct StreamKKernel const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + i_k_b; CDataType* c_ptr = static_cast(kargs.e_ptr); - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - UniversalGemmKernel::template MakeGemmTensorViews< - EpiloguePipeline::MemoryOperation>( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size); - - const auto& gemm_pad_views = - UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n); - - // Run GEMM cooperatively by whole workgroup. - const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0); - const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1); - const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2); + // Create block windows using specialized methods + const auto& as_block_window = + UniversalGemmKernel::MakeABlockWindows({a_ptr}, kargs, k_size, i_m); + const auto& bs_block_window = + UniversalGemmKernel::MakeBBlockWindows({b_ptr}, kargs, k_size, i_n); + const auto& ds_block_window = + UniversalGemmKernel::MakeDBlockWindows({/*ds_ptr*/}, kargs, i_m, i_n); // Since num_loop can vary per WG and per iteration of the Stream-K while loop, // we compute has_hot_loop and tail_num here. 
This is a similar pattern used by @@ -548,6 +537,7 @@ struct StreamKKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk); + // Run GEMM cooperatively by whole workgroup. const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0], bs_block_window[UniversalGemmKernel::I0], num_loop_sk, @@ -594,7 +584,8 @@ struct StreamKKernel } } - auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows< + TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n); EpiloguePipeline{}( c_block_window, accum_block_tile, ds_block_window, smem_ptr_0); } @@ -617,7 +608,8 @@ struct StreamKKernel // tensor. if(tile_started && !partner_in_tile) { - auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows< + TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n); EpiloguePipeline{}( c_block_window, accum_block_tile, ds_block_window, smem_ptr_0); break; diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp index a6022e8b8e..0b0f6c18ef 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp @@ -27,6 +27,9 @@ struct StreamKTilePartitionerBase static constexpr index_t NPerBlock = BlockGemmShapeType::kN; static constexpr index_t KPerBlock = BlockGemmShapeType::kK; static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType; + static constexpr auto MemoryOperation = (ReductionStrategy == StreamKReductionStrategy::Atomic) + ? 
memory_operation_enum::atomic_add + : memory_operation_enum::set; StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid); diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 77952c9afd..628f5f7dc8 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -254,6 +254,8 @@ struct UniversalGemmKernel static_assert(DsLayout::size() == DsDataType::size(), "The size of DsLayout and DsDataType should be the same"); + static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); + using KernelArgs = UniversalGemmKernelArgs; @@ -421,7 +423,7 @@ struct UniversalGemmKernel const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA() : GemmPipeline::template GetVectorSizeA(); - bool AsTesnorIsValid = {true}; + bool AsTensorIsValid = {true}; static_for<0, NumATensor, 1>{}([&](auto index) { using AiLayout = remove_cvref_t>; if constexpr(std::is_same_v) @@ -435,15 +437,27 @@ struct UniversalGemmKernel "Can't support K that is not a multiple of k_batch * KPerBlock " "without padding!"); } - AsTesnorIsValid = false; + AsTensorIsValid = false; } if(kargs.K % vectorSizeA != 0) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + const auto remainder = kargs.K % vectorSizeA; + constexpr ck_tile::index_t APackedSize = + ck_tile::numeric_traits::PackedSize; + const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize; + // oob can support to dword level + if(remainder_in_bytes % 4 == 0) { - CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + AsTensorIsValid = true; + } + else + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } + AsTensorIsValid = false; } - AsTesnorIsValid = false; } } else @@ -455,20 +469,33 @@ struct 
UniversalGemmKernel CK_TILE_ERROR( "Can't support M that is not a multiple of MPerBlock without padding!"); } - AsTesnorIsValid = false; + AsTensorIsValid = false; } if(kargs.M % vectorSizeA != 0) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + const auto remainder = kargs.M % vectorSizeA; + constexpr ck_tile::index_t APackedSize = + ck_tile::numeric_traits::PackedSize; + const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize; + // oob can support to dword level + if(remainder_in_bytes % 4 == 0) { - CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); + + AsTensorIsValid = true; + } + else + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); + } + AsTensorIsValid = false; } - AsTesnorIsValid = false; } } }); - bool BsTesnorIsValid = {true}; + bool BsTensorIsValid = {true}; const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB() : GemmPipeline::template GetVectorSizeB(); static_for<0, NumBTensor, 1>{}([&](auto index) { @@ -482,47 +509,72 @@ struct UniversalGemmKernel CK_TILE_ERROR( "Can't support N that is not a multiple of NPerBlock without padding!"); } - BsTesnorIsValid = false; + BsTensorIsValid = false; } if(kargs.N % vectorSizeB != 0) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + const auto remainder = kargs.N % vectorSizeB; + constexpr ck_tile::index_t BPackedSize = + ck_tile::numeric_traits::PackedSize; + const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize; + // oob can support to dword level + if(remainder_in_bytes % 4 == 0) { - CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); + BsTensorIsValid = true; + } + else + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); + } + BsTensorIsValid = false; } - BsTesnorIsValid = false; } - } - else - { - 
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && - GemmPipeline::kPadK == false) + else { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) { - CK_TILE_ERROR( - "Can't support K that is not a multiple of k_batch * KPerBlock " - "without padding!"); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + BsTensorIsValid = false; } - BsTesnorIsValid = false; - } - if(kargs.K % vectorSizeB != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + if(kargs.K % vectorSizeB != 0) { - CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!"); + const auto remainder = kargs.K % vectorSizeB; + constexpr ck_tile::index_t BPackedSize = + ck_tile::numeric_traits::PackedSize; + const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize; + // oob can support to dword level + if(remainder_in_bytes % 4 == 0) + { + BsTensorIsValid = true; + } + else + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "K is not a multiple of vector load size for B tensor!"); + } + BsTensorIsValid = false; + } } - BsTesnorIsValid = false; } } }); - bool DTesnorIsValid = {true}; + bool DTensorIsValid = {true}; static_for<0, NumDTensor, 1>{}([&](auto index) { using DiLayout = remove_cvref_t>; if(std::is_same_v == false) { - DTesnorIsValid = false; + DTensorIsValid = false; } if constexpr(std::is_same_v) { @@ -533,7 +585,7 @@ struct UniversalGemmKernel CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of " "NPerBlock without padding!"); } - DTesnorIsValid = false; + DTensorIsValid = false; } if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0) { @@ -541,7 +593,7 @@ struct UniversalGemmKernel { CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!"); } - 
DTesnorIsValid = false; + DTensorIsValid = false; } } else @@ -553,7 +605,7 @@ struct UniversalGemmKernel CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of " "MPerBlock without padding!"); } - DTesnorIsValid = false; + DTensorIsValid = false; } if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0) { @@ -561,7 +613,7 @@ struct UniversalGemmKernel { CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!"); } - DTesnorIsValid = false; + DTensorIsValid = false; } } }); @@ -606,20 +658,16 @@ struct UniversalGemmKernel return false; } } - return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid; + return AsTensorIsValid && BsTensorIsValid && DTensorIsValid; } - template CK_TILE_DEVICE static auto - MakeGemmTensorViews(const std::array& as_ptr, - const std::array& bs_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - const KernelArgs& kargs, - const index_t k_size) + MakeABlockWindows(const std::array& as_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t i_m) { - static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); - + // Step 1: Create tensor views for A tensors (from MakeGemmTensorViews) const auto& as_tensor_view = generate_tuple( [&](auto i) { using AiLayout = remove_cvref_t>; @@ -645,6 +693,58 @@ struct UniversalGemmKernel }, number{}); + // Step 2: Create padded views (from MakeGemmPadViews) + const auto& as_pad_view = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(as_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(as_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows (from MakeGemmTileWindows) + const auto& as_block_window = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(as_pad_view[i], 
+ make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {0, i_m}); + } + }, + number{}); + + return as_block_window; + } + + CK_TILE_DEVICE static auto + MakeBBlockWindows(const std::array& bs_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t i_n) + { + // Step 1: Create tensor views for B tensors (from MakeGemmTensorViews) const auto& bs_tensor_view = generate_tuple( [&](auto i) { using BiLayout = remove_cvref_t>; @@ -733,96 +833,20 @@ struct UniversalGemmKernel }, number{}); - const auto& ds_tensor_view = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - using DDataType_ = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.N, kargs.M), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - }, - number{}); - - // TODO: enable vector write for C in ColMajor - const auto& e_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm. 
- make_tuple(kargs.stride_E, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_E), - number<1>{}, - number<1>{}); - } - }(); - - return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& as_pad_view = generate_tuple( - [&](auto i) { - const auto& a_tensor_view = views.at(I0); - using AiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - const auto& b_flat_pad_view = views.at(I1); - + // Step 2: Create padded views (from MakeGemmPadViews) const auto& bs_pad_view = generate_tuple( [&](auto i) { - const auto& b_tensor_view = views.at(I1); - using BiLayout = remove_cvref_t>; + using BiLayout = remove_cvref_t>; if constexpr(std::is_same_v) { - return pad_tensor_view(b_tensor_view[i], + return pad_tensor_view(bs_tensor_view[i], make_tuple(number{}, number{}), sequence{}); } else { - return pad_tensor_view(b_tensor_view[i], + return pad_tensor_view(bs_tensor_view[i], make_tuple(number{}, number{}), sequence{}); @@ -830,86 +854,7 @@ struct UniversalGemmKernel }, number{}); - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor - const auto& e_pad_view = [&]() { - const auto& e_tensor_view = views.at(I3); - if 
constexpr(std::is_same_v) - { - return pad_tensor_view(e_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(e_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - if constexpr(GemmPipeline::Preshuffle) - { - // For flatmm, we need to use the flat B tensor view - return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view); - } - else - { - return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view); - } - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& as_pad_view = views.at(I0); - const auto& bs_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& as_block_window = generate_tuple( - [&](auto i) { - using AiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(as_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(as_pad_view[i], - make_tuple(number{}, - number{}), - {0, i_m}); - } - }, - number{}); - + // Step 3: Create tile windows (from MakeGemmTileWindows) const auto& bs_block_window = generate_tuple( [&](auto i) { using BiLayout = remove_cvref_t>; @@ -942,7 +887,63 @@ struct UniversalGemmKernel }, number{}); - const auto ds_block_window = generate_tuple( + return bs_block_window; + } + + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const KernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor views for D tensors (from MakeGemmTensorViews) + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + using DDataType_ = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + 
number<1>{}); + } + else + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.N, kargs.M), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + }, + number{}); + + // Step 2: Create padded views (from MakeGemmPadViews) + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows (from MakeGemmTileWindows) + const auto& ds_block_window = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; if constexpr(std::is_same_v) @@ -962,12 +963,62 @@ struct UniversalGemmKernel }, number{}); + return ds_block_window; + } + + template + CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr, + const KernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews) + const auto& e_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_E, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_E), + number<1>{}, + number<1>{}); + } + }(); + + // Step 2: Create padded view (from MakeGemmPadViews) + const auto& e_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window (from MakeGemmTileWindows) auto e_block_window = make_tile_window( e_pad_view, make_tuple(number{}, number{}), {i_m, i_n}); - 
return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window); + return e_block_window; } /** @@ -977,7 +1028,7 @@ struct UniversalGemmKernel * @param bs_ptr input Bs pointer * @param ds_ptr input Ds pointer * @param e_ptr output E pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param smem_ptr The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. @@ -989,96 +1040,90 @@ struct UniversalGemmKernel const std::array& bs_ptr, const std::array& ds_ptr, EDataType* e_ptr, - void* smem_ptr_0, + void* smem_ptr, const KernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& as_block_window = + MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& bs_block_window = + MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. 
- const auto& as_block_window = gemm_tile_windows.at(I0); - const auto& bs_block_window = gemm_tile_windows.at(I1); - const auto& ds_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( - as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0); + as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr); - if(UseDefaultScheduler || (get_warp_id() == 0)) + const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); + // Run Epilogue Pipeline + if(k_batch == 1) { - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - - EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + auto c_block_window = MakeCBlockWindows( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr); + } + else + { + auto c_block_window = MakeCBlockWindows( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr); } } - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. - * - * @param as_ptr input As pointer - * @param bs_ptr input Bs pointer - * @param ds_ptr input Ds pointer - * @param e_ptr output E pointer - * @param smem_ptr_0 The starting pointer of 1st shared memory block. - * @param smem_ptr_1 The starting pointer of 2nd shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset Utility structure used to calculate k batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. 
- * - */ - CK_TILE_DEVICE static void RunGemm2LDS(const std::array& as_ptr, - const std::array& bs_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) + CK_TILE_DEVICE static auto + GetTileCoordinates(const KernelArgs& kargs) -> tuple { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k); + index_t iM, iN; - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Regular launch: use 1D block indexing + const auto blockId = amd_wave_read_first_lane(blockIdx.x); + const auto [tile_m, tile_n] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); + iM = tile_m; + iN = tile_n; - const index_t num_loop = - amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); - // Run GEMM cooperatively by whole workgroup. 
- const auto& as_block_window = gemm_tile_windows.at(I0); - const auto& bs_block_window = gemm_tile_windows.at(I1); - const auto& ds_block_window = gemm_tile_windows.at(I2); + return make_tuple(i_m, i_n); + } - const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window, - AElementWise{}, - bs_block_window, - BElementWise{}, - num_loop, - smem_ptr_0, - smem_ptr_1); + // Helper functions + CK_TILE_DEVICE static auto GetBlockId() -> index_t + { + // For 1D regular launch + return amd_wave_read_first_lane(get_block_id()); + } - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + CK_TILE_DEVICE static auto GetGridSize() -> index_t + { + // For 1D regular launch + return amd_wave_read_first_lane(get_grid_size()); + } - EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + // Helper to get total number of tiles, handling both dim3 and index_t return types + template + CK_TILE_HOST_DEVICE static auto GetNumTiles(Args&&... args) -> index_t + { + auto grid_size = TilePartitioner::GridSize(std::forward(args)...); + + using GridSizeType = decltype(grid_size); + + if constexpr(std::is_same_v) + { + // GridSize returns dim3: compute total tiles as x * y * z + return amd_wave_read_first_lane(grid_size.x * grid_size.y * grid_size.z); + } + else + { + // GridSize returns scalar (index_t): use directly + return amd_wave_read_first_lane(grid_size); + } } // Non-persistent kernel entry point @@ -1114,45 +1159,12 @@ struct UniversalGemmKernel } // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - 
kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - else - { - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); - RunGemm(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } + constexpr auto scheduler_type = + GemmPipeline::DoubleSmemBuffer || (GemmPipeline::NumWaveGroups == 1); + RunGemm( + as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } // Persistent kernel entry point @@ -1199,46 +1211,19 @@ struct UniversalGemmKernel } // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; // Run the GEMM - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - else - { - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } + + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr, + kargs, + splitk_batch_offset, + i_m, + i_n); + // Advance to the next work item block_id += grid_size; if(block_id >= num_work) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 343e37ed66..4973d9c941 100644 --- 
a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -64,12 +64,17 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } - template + template CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile, SrcTileWindow& dram_tile_window, const DramTileWindowStep& dram_tile_window_step) const { - load_tile(dst_block_tile, dram_tile_window); + load_int4_tile(dst_block_tile, dram_tile_window); move_tile_window(dram_tile_window, dram_tile_window_step); } @@ -217,22 +222,17 @@ struct GemmPipelineAgBgCrImplBase return std::move(a_copy_dram_window); } - template - CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const ALdsTensorView& a_lds_block_view, - const ALdsLoadTileDistr&, - const array& offset = {0, 0}) const + template + CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view, + const ALdsLoadTileDistr&) const { - // A DRAM tile window for load - auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset); - - // A LDS tile window for store auto a_lds_shape = []() { if constexpr(is_a_load_tr) return make_tuple(number{}, number{}); else return make_tuple(number{}, number{}); }(); + auto a_copy_lds_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}); auto a_lds_load_tile_distr = []() { @@ -244,32 +244,73 @@ struct GemmPipelineAgBgCrImplBase else return ALdsLoadTileDistr{}; }(); + auto a_lds_gemm_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}, a_lds_load_tile_distr); + return make_tuple(std::move(a_copy_lds_window), std::move(a_lds_gemm_window)); + } + + template < + typename ADramBlockWindowTmp, + typename ALdsTensorView, + typename ALdsLoadTileDistr, + typename std::enable_if_t::value, bool>* = nullptr> + CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& 
a_dram_block_window_tmp, + const ALdsTensorView& a_lds_block_view, + const ALdsLoadTileDistr& a_lds_load_tile_distr, + const array& offset = {0, 0}) const + { + // A DRAM tile window for load + auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset); + + // Create LDS windows + auto [a_copy_lds_window, a_lds_gemm_window] = + MakeALdsWindows(a_lds_block_view, a_lds_load_tile_distr); + return make_tuple(std::move(a_copy_dram_window), std::move(a_copy_lds_window), std::move(a_lds_gemm_window)); } - template - CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, - const BLdsTensorView& b_lds_block_view, - const BLdsLoadTileDistr&, + // Unified GetAWindows that supports 1, 2, or 3 LDS buffers + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const ALdsTensorViewsTuple& a_lds_block_views_tuple, + const ALdsLoadTileDistr& a_lds_load_tile_distr, const array& offset = {0, 0}) const { // A DRAM tile window for load - auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset); + auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset); - // TODO: Do we really need those two tile windows??? - // They're exactly same... 
- // B LDS tile window for store + // Create LDS windows for each buffer + constexpr index_t num_buffers = ALdsTensorViewsTuple::size(); + auto a_lds_windows = generate_tuple( + [&](auto i) { + return MakeALdsWindows(a_lds_block_views_tuple[i], a_lds_load_tile_distr); + }, + number{}); + + // Return: (dram_window, lds_windows_tuple) + // lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i) + return make_tuple(std::move(a_copy_dram_window), std::move(a_lds_windows)); + } + + template + CK_TILE_DEVICE constexpr auto MakeBLdsWindows(const BLdsTensorView& b_lds_block_view, + const BLdsLoadTileDistr&) const + { auto b_lds_shape = []() { if constexpr(is_b_load_tr) return make_tuple(number{}, number{}); else return make_tuple(number{}, number{}); }(); + auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); using BLdsDataType = @@ -286,13 +327,61 @@ struct GemmPipelineAgBgCrImplBase else return BLdsLoadTileDistr{}; }(); + auto b_lds_gemm_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}, b_lds_load_tile_distr); + return make_tuple(std::move(b_copy_lds_window), std::move(b_lds_gemm_window)); + } + + template < + typename BDramBlockWindowTmp, + typename BLdsTensorView, + typename BLdsLoadTileDistr, + typename std::enable_if_t::value, bool>* = nullptr> + CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BLdsTensorView& b_lds_block_view, + const BLdsLoadTileDistr& b_lds_load_tile_distr, + const array& offset = {0, 0}) const + { + // A DRAM tile window for load + auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset); + + // Create LDS windows + auto [b_copy_lds_window, b_lds_gemm_window] = + MakeBLdsWindows(b_lds_block_view, b_lds_load_tile_distr); + return make_tuple(std::move(b_copy_dram_window), std::move(b_copy_lds_window), std::move(b_lds_gemm_window)); } + + // Unified GetBWindows that supports 1, 2, or 3 LDS buffers + template ::value, bool>* = + 
nullptr> + CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BLdsTensorViewsTuple& b_lds_block_views_tuple, + const BLdsLoadTileDistr& b_lds_load_tile_distr, + const array& offset = {0, 0}) const + { + // B DRAM tile window for load + auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset); + + // Create LDS windows for each buffer + constexpr index_t num_buffers = BLdsTensorViewsTuple::size(); + auto b_lds_windows = generate_tuple( + [&](auto i) { + return MakeBLdsWindows(b_lds_block_views_tuple[i], b_lds_load_tile_distr); + }, + number{}); + + // Return: (dram_window, lds_windows_tuple) + // lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i) + return make_tuple(std::move(b_copy_dram_window), std::move(b_lds_windows)); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 0b2cdde05e..8acfea4580 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -158,6 +158,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}; @@ -172,7 +174,8 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync(); + constexpr index_t smem_size = Policy::template GetSmemSize(); + return 2 * smem_size; } CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() @@ -240,8 +243,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync); @@ -303,8 +305,10 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}); // this pipeline has a pair of LDS buffers per logical tile - auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); - auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); + constexpr index_t smem_size = Policy::template 
GetSmemSize(); + auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem); + auto&& [a_lds_block1, b_lds_block1] = + Base::GetABLdsTensorViews(static_cast(p_smem) + smem_size); // set up LDS tile shapes constexpr auto a_lds_shape = []() { @@ -534,21 +538,18 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}.template operator()( a_dram_block_window_tmp, a_element_func, b_dram_block_window_tmp, b_element_func, num_loop, - p_smem_0, - p_smem_1); + p_smem); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); @@ -559,8 +560,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync static constexpr auto is_a_load_tr_v = bool_constant{}; static constexpr auto is_b_load_tr_v = bool_constant{}; + static_assert(DoubleSmemBuffer == true, "pipeline requires double smem buffer"); + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() { // clang-format off @@ -191,7 +193,8 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return Policy::template GetSmemSize(); + constexpr index_t smem_size = Policy::template GetSmemSize(); + return 2 * smem_size; } CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() @@ -281,8 +284,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { using ADramBlockWindowTmp = remove_cvref_t{}, AsDramBlockWindowTmp>>; @@ -324,8 +326,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // global read 0 ////////////// LDS desc, window & register ///////////////// - auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); - auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); + 
constexpr index_t smem_size = Policy::template GetSmemSize(); + auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem); + auto&& [a_lds_block1, b_lds_block1] = + Base::GetABLdsTensorViews(static_cast(p_smem) + smem_size); constexpr auto a_lds_shape = []() { if constexpr(is_a_load_tr_v()) @@ -680,8 +684,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, - void* p_smem_0, - void* p_smem_1) const + void* p_smem) const { const bool has_hot_loop = Base::BlockHasHotloop(num_loop); const auto tail_number = Base::GetBlockLoopTailNum(num_loop); @@ -693,8 +696,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 b_dram_block_window_tmp, b_element_func, num_loop, - p_smem_0, - p_smem_1); + p_smem); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); @@ -708,8 +710,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const BsDramBlockWindowTmp& b_dram_block_window_tmp, const index_t num_loop, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { const bool has_hot_loop = Base::BlockHasHotloop(num_loop); const auto tail_number = Base::GetBlockLoopTailNum(num_loop); @@ -721,8 +722,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 b_dram_block_window_tmp, [](auto& e, const BDataType& b) { e = b; }, num_loop, - p_smem_0, - p_smem_1); + p_smem); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); @@ -738,8 +738,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 index_t num_loop, bool has_hot_loop, TailNumber tail_number, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { const auto RunPipeline = [&](auto hot_loop_, auto 
tail_num_) { constexpr bool hot_loop = hot_loop_.value; @@ -751,8 +750,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 b_dram_block_window_tmp, PassThrough, num_loop, - p_smem_0, - p_smem_1); + p_smem); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } @@ -769,16 +767,14 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 const BDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, - void* p_smem_0, - void* p_smem_1) const + void* p_smem) const { return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), a_element_func, ck_tile::make_tuple(b_dram_block_window_tmp), b_element_func, num_loop, - p_smem_0, - p_smem_1); + p_smem); } template CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp, const index_t num_loop, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), ck_tile::make_tuple(b_dram_block_window_tmp), num_loop, - p_smem_0, - p_smem_1); + p_smem); } template index_t num_loop, bool has_hot_loop, TailNumber tail_number, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), ck_tile::make_tuple(b_dram_block_window_tmp), num_loop, has_hot_loop, tail_number, - p_smem_0, - p_smem_1); + p_smem); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index a45d41189b..6199142d98 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -246,9 +246,11 @@ struct 
UniversalGemmBasePolicy } else // A is in RowMajor { - constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr uint64_t MinLdsLayer = 1ULL; constexpr auto MLdsLayer = - max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); + max(MinLdsLayer, + get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); constexpr index_t NBanks = get_n_lds_banks(); static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count"); @@ -442,11 +444,13 @@ struct UniversalGemmBasePolicy } else // B is Column Major { - constexpr index_t KPack = GetSmemPackB(); - constexpr auto BK0 = number{}; - constexpr auto DataTypeSize = sizeof(BDataType); + constexpr index_t KPack = GetSmemPackB(); + constexpr auto BK0 = number{}; + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr uint64_t MinLdsLayer = 1ULL; constexpr auto NLdsLayer = - max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); + max(MinLdsLayer, + get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); constexpr index_t NBanks = get_n_lds_banks(); static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count"); @@ -841,10 +845,10 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr index_t GetSmemSizeA() { - constexpr index_t smem_size_a = - integer_least_multiple(sizeof(typename Problem::ADataType) * - Problem::BlockGemmShape::kM * Problem::BlockGemmShape::kK, - 16); + using ADataType = remove_cvref_t; + constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor(); + constexpr index_t smem_size_a = integer_least_multiple( + a_lds_block_desc.get_element_space_size() * sizeof(ADataType), 16); return smem_size_a; } @@ -855,8 +859,9 @@ struct UniversalGemmBasePolicy std::conditional_t, typename Problem::ADataType, typename Problem::BDataType>; - constexpr index_t smem_size_b = integer_least_multiple( - sizeof(BDataType) * 
Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, 16); + constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor(); + constexpr index_t smem_size_b = integer_least_multiple( + b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16); return smem_size_b; } diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 47607a40f5..5b00eb244b 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -53,11 +53,11 @@ struct TileGemmUniversalTraits static constexpr int _VectorSize = VectorSize_; static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; - using AsLayout = AsLayout_; - using BsLayout = BsLayout_; - using CLayout = CLayout_; + using AsLayout = AsLayout_; + using BsLayout = BsLayout_; + using CLayout = CLayout_; + static constexpr bool TransposeC = TransposeC_; - static constexpr bool TransposeC = TransposeC_; static constexpr bool UseStructuredSparsity = UseStructuredSparsity_; static constexpr bool UsePersistentKernel = UsePersistentKernel_; static constexpr index_t NumWaveGroups = NumWaveGroups_; diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 019a828ec0..e90c6a27d7 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" namespace ck_tile { @@ -201,6 +202,12 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy { using TileShape = typename Problem::BlockGemmShape; + constexpr index_t kNPerBlock = TileShape::kN; + constexpr index_t kKPerBlock = 
TileShape::kK; + constexpr index_t NIterPerWarp = + kNPerBlock / TileShape::BlockWarps::at(I1) / TileShape::WarpTile::at(I1); + constexpr index_t KIterPerWarp = kKPerBlock / TileShape::WarpTile::at(I2); + constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNum = BlockSize / WaveSize; @@ -213,13 +220,13 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy #endif constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; - constexpr index_t KRepeat = 1; + constexpr index_t KRepeat = KIterPerWarp; static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); constexpr index_t NBPerLoad = 1; constexpr index_t NThdPerWave = 1; constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp - constexpr index_t NRepeat = 1; + constexpr index_t NRepeat = NIterPerWarp; constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; return make_static_tile_distribution( @@ -232,8 +239,8 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy tuple, sequence<0, 1, 2>>, // which direction tuple, sequence<1, 2, 2>>, // which index // - sequence<1, 1, 2, 2>, - sequence<0, 3, 0, 3>>{}); + sequence<1, 2, 1, 2>, + sequence<0, 0, 3, 3>>{}); } template @@ -307,7 +314,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy typename Problem::CDataType, BlockWarps, WarpGemm>; - return BlockWeightPreshuffleASmemBSmemCRegV1{}; + return BlockWeightPreshuffleASmemBRegCReg{}; } /** * @brief Get the vector store size for C tensor. 
@@ -325,7 +332,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() { using BlockGemm = remove_cvref_t())>; - using WG_ = typename BlockGemm::WG; + using WG_ = typename BlockGemm::WarpGemm; constexpr bool TransposeC = Problem::TransposeC; using CLayout = typename Problem::CLayout; diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index f64901755b..c9499106de 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -32,19 +32,34 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2 template CK_TILE_HOST_DEVICE static auto - TailHandler(const RunFunction& run_func, bool, TailNumber tail_number) + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) { - if(tail_number == TailNumber::Odd) + if(has_hot_loop) { - return run_func(bool_constant{}, - integral_constant{}); + if(tail_number == TailNumber::Odd) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else // Even tail number + { + return run_func(bool_constant{}, + integral_constant{}); + } } - else // Even tail number + else { - return run_func(bool_constant{}, - integral_constant{}); + if(tail_number == TailNumber::Odd) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else // Even tail number + { + return run_func(bool_constant{}, + integral_constant{}); + } } - return run_func(bool_constant{}, integral_constant{}); } }; @@ -52,7 +67,8 @@ template { - using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using PipelineImplBase = GemmPipelineAgBgCrImplBase; using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; @@ -75,11 +91,6 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 using 
BlockWeightPreshuffle = remove_cvref_t())>; - static constexpr auto config = - BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2 static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read @@ -95,6 +106,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock; + static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; @@ -131,12 +144,16 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 using BlockWarps = remove_cvref_t; using WarpTile = remove_cvref_t; - static constexpr index_t MWarp = config.template at<1>(); - static constexpr index_t NWarp = config.template at<2>(); + static constexpr index_t MWarp = BlockWarps::at(I0); + static constexpr index_t NWarp = BlockWarps::at(I1); - static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); - static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + static constexpr index_t WarpTileM = WarpTile::at(I0); + static constexpr index_t WarpTileN = WarpTile::at(I1); + static constexpr index_t WarpTileK = WarpTile::at(I2); + + static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpTileM); + static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpTileN); + static constexpr index_t KIterPerWarp = kKPerBlock / WarpTileK; static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; @@ -154,20 +171,20 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 #else static constexpr index_t mfma_per_wg = 1; #endif - static constexpr index_t dsread_per_wg 
= - max(index_t(WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1); + static constexpr index_t dsread_per_wg = max( + index_t(WarpTileM * WarpTileK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1); #if defined(__HIP_DEVICE_COMPILE__) - static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) % + static_assert((WarpTileM * WarpTileK * sizeof(ADataType) * MIterPerWarp / WaveSize) % Problem::VectorLoadSize == 0); #endif - static constexpr index_t dsread_num_perK = - WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize / Problem::VectorLoadSize; + static constexpr index_t dsread_num_perK = WarpTileM * WarpTileK * sizeof(ADataType) * + MIterPerWarp / WaveSize / Problem::VectorLoadSize; static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp); static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp; static constexpr index_t Aload_num_perK = dswrite_num_perK; static constexpr index_t Aload_rep = dswrite_rep; - static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize; + static constexpr index_t Bload_num_perK = kNPerBlock * WarpTileK / NWarp / K1 / WaveSize; static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2; static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter; @@ -187,7 +204,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // clang-format off return concat('_', "pipeline_AGmemBGmemCRegV2", concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize), - concat('x', WG::kM, WG::kN, WG::kK), + concat('x', WarpTileM, WarpTileN, WarpTileK), concat('x', GetVectorSizeA(), GetVectorSizeB()), concat('x', kPadM, kPadN, kPadK)); @@ -195,14 +212,16 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 } static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - static constexpr index_t Preshuffle = Problem::Preshuffle; + + static constexpr index_t Preshuffle = Problem::Preshuffle; 
using Base::UsePersistentKernel; CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return PipelinePolicy::template GetSmemSize(); + constexpr index_t smem_size = PipelinePolicy::template GetSmemSize(); + return DoubleSmemBuffer ? 2 * smem_size : smem_size; } // dsread_perM: how many LDS reads want to issue in this M-iter @@ -515,515 +534,184 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // __builtin_amdgcn_sched_barrier(0); } - template ::value && - !is_detected::value, - bool>* = nullptr, - index_t UnaryOpSize_ = 8> - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const AElementFunction& a_element_func, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - index_t num_loop, - void* p_smem_ping, - void* p_smem_pong) const + struct PipelineImpl : public PipelineImplBase { - static_assert( - std::is_same_v>, - "wrong!"); + using Base = PipelineImplBase; - static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], - "wrong!"); - static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], - "wrong!"); - - constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? 
MIterPerWarp - 2 : MIterPerWarp - 1; - const index_t iMWarp = get_warp_id() / NWarp; - - using CWarpDstr = typename WG::CWarpDstr; - using CWarpTensor = typename WG::CWarpTensor; - - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - __builtin_amdgcn_sched_barrier(0); - - // A tile in LDS - ADataType* p_a_lds_ping = static_cast(p_smem_ping); - ADataType* p_a_lds_pong = static_cast(p_smem_pong); - - constexpr auto a_lds_block_desc = - PipelinePolicy::template MakeALdsBlockDescriptor(); - - auto a_lds_block_ping = - make_tensor_view(p_a_lds_ping, a_lds_block_desc); - auto a_lds_block_pong = - make_tensor_view(p_a_lds_pong, a_lds_block_desc); - - // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeADramTileDistribution()); - - auto a_copy_lds_window_ping = - make_tile_window(a_lds_block_ping, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); - - auto a_copy_lds_window_pong = - make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); - - // ping-pong window for A LDS - auto a_warp_window_ping_tmp = - make_tile_window(a_lds_block_ping, - make_tuple(number{}, number{}), - {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - - auto a_warp_window_pong_tmp = - make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - a_warp_windows_ping; - - statically_indexed_array< - statically_indexed_array, - MIterPerWarp> - 
a_warp_windows_pong; - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; - - move_tile_window(a_warp_windows_ping(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); - - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; - - move_tile_window(a_warp_windows_pong(mIter)(kIter), - {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); - }); - }); - - // Block GEMM - auto block_weight_preshuffle = BlockWeightPreshuffle(); - // Acc register tile - auto c_block_tile = block_weight_preshuffle.MakeCBlockTile(); - - // B flat DRAM window for load - auto b_flat_distribution = - PipelinePolicy::template MakeBFlatDramTileDistribution(); - auto b_flat_dram_window = // tile_window_with_static_distribution - make_tile_window( - b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views - make_tuple(number{}, number{}), - b_flat_dram_block_window_tmp.get_window_origin(), - b_flat_distribution); - - // pingpong buffer for B - using BTypeToUse = - std::conditional_t, ADataType, BDataType>; - using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); - - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> - b_flat_dram_windows; - - statically_indexed_array, NIterPerWarp> - b_warp_tensor_ping; - - statically_indexed_array, NIterPerWarp> - b_warp_tensor_pong; - - // Prefetch A0 - auto a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // prefetch B - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * 
NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); - // move B window to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - // Prefill A0 - auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_ping, a_block_tile_tmp); - - __builtin_amdgcn_sched_barrier(0); - - // Prefetch A1 - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // initialize C - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - - block_sync_lds(); - - // preload A00,A10 from lds - statically_indexed_array{})(number<0>{}))), - m_preload> - a_warp_tensor; - - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); - }); - __builtin_amdgcn_sched_barrier(0); - - // MAIN LOOP - index_t iCounter = (num_loop - 1) / 2; - while(iCounter > 0) + template ::value && + !is_detected::value, + bool>* = nullptr, + index_t UnaryOpSize_ = 8> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + [[maybe_unused]] const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const { - // prefetch B(2i+1) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + static_assert( + std::is_same_v>, + "wrong!"); - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + 
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); - load_int4_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); + // A tile in LDS + constexpr index_t smem_size = PipelinePolicy::template GetSmemSize(); - // Prefill A(2i+1) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); - // Prefetch A(2i+2) - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + auto a_lds_blocks = generate_tuple( + [&](auto i) { + ADataType* p_a_lds = static_cast( + static_cast(static_cast(p_smem) + smem_size * i.value)); + return make_tensor_view(p_a_lds, a_lds_block_desc); + }, + number<2>{}); - // GEMM 2i - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; + constexpr auto a_lds_load_tile_distr = make_static_tile_distribution( + BlockWeightPreshuffle::MakeABlockDistributionEncode()); + auto&& windows_result = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_blocks, a_lds_load_tile_distr); + auto&& a_copy_dram_window = windows_result.template get<0>(); + auto&& a_lds_windows = windows_result.template get<1>(); + auto a_copy_lds_windows = generate_tuple( + [&](auto i) -> decltype(auto) { return a_lds_windows[i].template at<0>(); }, + number<2>{}); + // Block GEMM + auto block_weight_preshuffle = BlockWeightPreshuffle(); + // Acc register tile + auto c_block_tile = block_weight_preshuffle.MakeCBlockTile(); - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, 
c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + auto a_load_windows = generate_tuple( + [&](auto i) -> decltype(auto) { + return block_weight_preshuffle.MakeALoadWindows(a_copy_lds_windows[i]); + }, + number<2>{}); - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); + // B flat DRAM window for load + auto b_flat_distribution = + PipelinePolicy::template MakeBFlatDramTileDistribution(); + auto b_flat_dram_window = // tile_window_with_static_distribution + make_tile_window(b_flat_dram_block_window_tmp + .get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, + number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BFlatBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kflatKPerBlock); - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; + using BBlockTile = + decltype(make_static_distributed_tensor(b_flat_distribution)); + + ABlockTile a_global_tile; + BBlockTile b_global_tile[2]; + + // // Prefetch A0 + Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step); + + Base::template 
GlobalPrefetch( + b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); + + // Prefill A0 + Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); + + // Prefetch A1 + Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + block_sync_lds(); + + // preload A00,A10 from lds + block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]); + + __builtin_amdgcn_sched_barrier(0); + // MAIN LOOP + if constexpr(HasHotLoop) + { + index_t i_global_read = amd_wave_read_first_lane(2); + do + { { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } + Base::template GlobalPrefetch( + b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); + Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); + Base::GlobalPrefetch( + a_global_tile, a_copy_dram_window, a_dram_tile_window_step); + block_weight_preshuffle(c_block_tile, + a_load_windows[I0], + b_global_tile[0], + b_flat_distribution); - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]); + HotLoopScheduler(); + } { - block_sync_lds(); + Base::template GlobalPrefetch( + b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); + Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); + Base::GlobalPrefetch( + a_global_tile, a_copy_dram_window, a_dram_tile_window_step); + block_weight_preshuffle(c_block_tile, + a_load_windows[I1], + b_global_tile[1], + b_flat_distribution); + + block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]); + HotLoopScheduler(); } - }); - }); - // move B window to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + i_global_read += 2; + } while(i_global_read < 
num_loop); + } - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); - }); - HotLoopScheduler(); + // tail + if constexpr(TailNum == TailNumber::Even) + { + { + Base::template GlobalPrefetch( + b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); + Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); + block_weight_preshuffle( + c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution); + block_sync_lds(); + block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]); + Last2ndHotLoopScheduler(); + } + { + block_weight_preshuffle( + c_block_tile, a_load_windows[I1], b_global_tile[1], b_flat_distribution); + LastHotLoopScheduler(); + } + } + else if constexpr(TailNum == TailNumber::Odd) + { + block_weight_preshuffle( + c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution); + LastHotLoopScheduler(); + } - // Next K - - // prefetch B(2i+2) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - load_int4_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); - - // Prefill A(2i+2) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_ping, a_block_tile_tmp); - - // Prefetch A(2i+3) - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - - // GEMM 2i+1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, 
NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); - }); - // move B window to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); - }); - HotLoopScheduler(); - - iCounter--; + return c_block_tile; } - - // tail - if constexpr(TailNum == TailNumber::Even) - { - // __builtin_amdgcn_sched_barrier(0); - // prefetch B(loopK) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - load_int4_tile( 
- b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); - }); - }); - - // Prefill A(loopK) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_pong, a_block_tile_tmp); - - // GEMM loopK-1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); - }); - // TailHotLoopScheduler(); - - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); - }); - - Last2ndHotLoopScheduler(); - - // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, 
MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_pong(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); - } - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); - }); - LastHotLoopScheduler(); - } - else if constexpr(TailNum == TailNumber::Odd) - { - // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor(number{}), - b_warp_tensor_ping(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - 
merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); - } - - // barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); - }); - LastHotLoopScheduler(); - } - - return c_block_tile; - } + }; // called from universal gemm kernel template (a_dram_block_window_tmp[number<0>{}], - PassThrough, - b_flat_dram_block_window_tmp[number<0>{}], - num_loop, - p_smem_ping, - p_smem_pong); + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp[number<0>{}], + a_element_func, + b_flat_dram_block_window_tmp[number<0>{}], + num_loop, + p_smem); }; - return Base::TailHandler(RunPipeline, true, tail_number); + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } // called from general gemm kernel @@ -1066,23 +751,21 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, - void* p_smem_ping, - void* p_smem_pong) const + void* p_smem) const { - const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + const auto has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); - const auto RunPipeline = [&](auto bool_val, auto tail_num_) { - (void)bool_val; // Suppress unused parameter warning - constexpr auto tail_num = tail_num_.value; + const auto 
RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr auto PassThrough = [](const ADataType& a) { return a; }; - return operator()(a_dram_block_window_tmp, - PassThrough, - b_flat_dram_block_window_tmp, - num_loop, - p_smem_ping, - p_smem_pong); + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + PassThrough, + b_flat_dram_block_window_tmp, + num_loop, + p_smem); }; - return Base::TailHandler(RunPipeline, true, tail_number); + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } // called from grouped gemm kernel @@ -1095,21 +778,19 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, TailNumber tail_number, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { - const auto RunPipeline = [&](auto bool_val, auto tail_num_) { - (void)bool_val; // Suppress unused parameter warning - constexpr auto tail_num = tail_num_.value; + const auto has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr auto PassThrough = [](const auto& x) { return x; }; - return operator()(a_dram_block_window_tmp, - PassThrough, - b_flat_dram_block_window_tmp, - num_loop, - p_smem_0, - p_smem_1); + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + PassThrough, + b_flat_dram_block_window_tmp, + num_loop, + p_smem); }; - return Base::TailHandler(RunPipeline, true, tail_number); + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } }; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index c0fbf8e5d3..7bcc9107da 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -306,6 +306,16 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl, 2>>; +using WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed = + WarpGemmImpl, + 2>>; + 
+using WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed = + WarpGemmImpl, + 2>>; + template using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index ff2ba501fe..ef31d06c9c 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -68,6 +68,19 @@ struct WarpGemmAttributeWmma { using Impl = remove_cvref_t; + // When kTransC is true and A/B types differ, we need an impl with swapped types + using TransposedImpl = + std::conditional_t, + WarpGemmAttributeWmmaImpl>, + Impl>; + using ADataType = typename Impl::ADataType; using BDataType = typename Impl::BDataType; using CDataType = typename Impl::CDataType; @@ -104,7 +117,7 @@ struct WarpGemmAttributeWmma { if constexpr(kTransC) { - Impl{}(c_vec, b_vec, a_vec, bool_constant{}); + TransposedImpl{}(c_vec, b_vec, a_vec, bool_constant{}); } else { @@ -117,7 +130,7 @@ struct WarpGemmAttributeWmma { if constexpr(kTransC) { - return Impl{}(b_vec, a_vec); + return TransposedImpl{}(b_vec, a_vec); } else { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp index 0464ffbce4..cf0efbbaae 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp @@ -22,9 +22,10 @@ struct WmmaTraits; template struct WarpGemmAttributeWmmaImpl { - using ADataType = typename Traits::ADataType; - using BDataType = typename Traits::BDataType; - using CDataType = typename Traits::CDataType; + using TraitsType = Traits; + using ADataType = typename Traits::ADataType; + using BDataType = typename Traits::BDataType; + using CDataType = typename Traits::CDataType; using AVecType = typename Traits::AVecType; using BVecType = typename 
Traits::BVecType; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp index 992f0a8783..d9d4ec9430 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp @@ -10,6 +10,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx11_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -30,6 +32,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx11_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -50,6 +54,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -70,6 +76,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp index 34c4dbe551..eace7e3956 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp @@ -10,6 +10,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx11_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -35,6 +37,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType 
wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -60,6 +64,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -80,6 +86,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) @@ -100,6 +108,8 @@ template <> struct WmmaTraits : WmmaTraitsBase { + using ArchType = gfx12_t; + template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp index 524215ddfa..e00b9d772f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp @@ -10,6 +10,8 @@ struct WmmaTraitsBase; template struct WmmaTraitsBase { + using ArchType = gfx11_t; + using ADataType = ADType; using BDataType = BDType; using CDataType = CDType; @@ -57,6 +59,8 @@ struct WmmaTraitsBase template struct WmmaTraitsBase { + using ArchType = gfx12_t; + using ADataType = ADType; using BDataType = BDType; using CDataType = CDType; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 82c6e43834..d6c21e88b5 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -100,6 +100,7 @@ template<> struct Dispatcher { using Ty template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; template<> 
struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; @@ -113,6 +114,7 @@ template<> struct Dispatcher { using Ty template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; +template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; // scale mfma based f8f6f4 diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 1e4aece0d7..696de378aa 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -3,6 +3,7 @@ #pragma once #include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" +#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" @@ -24,6 +25,8 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp" +#include 
"ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp new file mode 100644 index 0000000000..63a5151108 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp @@ -0,0 +1,282 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// BQ (scale tensor) is block distributed tensor. +// Consecutive QuantGroupSize elements of B are quantized with a separate scale. +// B is block window on block distributed tensor. 
+// C is block distributed tensor +template +struct BlockGemmWeightPreshuffleABQuantARegBRegCReg +{ + private: + template + struct GemmTraits_ + { + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using BQLayout = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using AQuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr auto Scheduler = Problem::Scheduler; + + // Threadblock GEMM tile size + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t NQPerBlock = NPerBlock / BQuantGroupSize::kN; + static constexpr index_t KQPerBlock = KPerBlock / BQuantGroupSize::kK; + static constexpr index_t AQPerBlock = KPerBlock / AQuantGroupSize::kK; + + static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + // number of warps along M and N for threadblock's GEMM problem size + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + using I0 = number<0>; + using I1 = number<1>; + + static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}), + "Error! WarpGemm's MWarp is not consistent with BlockGemmShape!"); + static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}), + "Error! WarpGemm's NWarp is not consistent with BlockGemmShape!"); + static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}), + "Error! WarpGemm's M is not consistent with BlockGemmShape!"); + static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}), + "Error! 
WarpGemm's N is not consistent with BlockGemmShape!"); + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; + + static constexpr index_t QScalesPerBlockRow = + integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); + static constexpr index_t QScalesPerWarpGemmRow = + integer_divide_ceil(WarpGemm::kK, BQuantGroupSize::kK); + + static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; + + static_assert(BQuantGroupSize::kK % WarpGemm::kK == 0, + "Error! WarpGemm::kK should be a multiple of QuantGroupSize"); + static_assert(QScalesPerWarpGemmRow == 1, + "Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK"); + static_assert(KIterPerWarp % QScalesPerBlockRow == 0, + "Error! KItersPerWarp should be a multiple of QscalesPerBlockRow"); + + static_assert(KPerBlock / BQuantGroupSize::kK > 0, + "Error! Each row of blockgemm should have a separate scale"); + + static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock, + "Error! Warps should cover all Block tile!"); + static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock, + "Error! Warps should cover all Block tile!"); + + // Currently tested combinations (A, B, BQ) + // 1. fp8, fp8, fp32 -> f32 + // 2. bf8, bf8, fp32 -> f32 + // 3. i4, fp8, (fp8/fp32) -> f32 + // 4. 
i4, bf8, (fp8/fp32) -> f32 + static_assert( + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v) && + std::is_same_v); + + static constexpr index_t InterWaveSchedulingMacClusters = 1; + + static constexpr index_t KPack = WarpGemm::kKPerThread; + static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; + static constexpr bool TransposeC = Problem::TransposeC; + }; + + public: + using Traits = GemmTraits_; + using Problem = remove_cvref_t; + using BlockPolicy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + using QuantGroupSize = remove_cvref_t; + + static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!"); + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + + static constexpr auto warp_size = get_warp_size(); + + using WG = remove_cvref_t())>; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); // 
128 / (1 * 16) = 8 + static constexpr index_t NIterPerWarp = + BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN)); // 128 / (4 * 16) = 2 + static constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // 128 / 16 = 8 + static constexpr auto MIter_2nd_last = + (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; + + static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK; + + static constexpr index_t QScalesPerBlockRow = + integer_divide_ceil(KPerBlock, QuantGroupSize::kK); // 128 / 128 = 1 + static constexpr index_t QScalesPerWarpGemmRow = + integer_divide_ceil(WG::kK, QuantGroupSize::kK); + + static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; // 8 / 1 = 8 + static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read + + static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) + ? DsReadPreload + : MIterPerWarp * KIterPerWarp; + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + return BlockGemmQuantCommon:: + MakeCBlockTile(); + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + ABlockTensor& a_warp_tensor, + BFlatBlockTensor& b_warp_tensor, + AQBlockTensor& aq_block_tensor, + BQBlockTensor& bq_block_tensor, + ABlockWindow& a_warp_windows) const + { + using CWarpDstr = typename WG::CWarpDstr; + using AccTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + statically_indexed_array, MIterPerWarp> + c_acc; + + auto zero_accumulators = [&] { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) { + c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f; + }); // make sure WG::CWarpTensor exposes a clear/zero + }); + }); + }; + static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) { + zero_accumulators(); + static_for<0, KIterPerQScale, 
1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // warp GEMM + WG{}(c_acc(mIter)(nIter), + a_warp_tensor(number{}), + b_warp_tensor(nIter)(number{})); + }); + __builtin_amdgcn_sched_barrier(0x7F6); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = + load_tile(a_warp_windows(number{})(number{})); + } + // barrier + // Could be deleted + if constexpr((mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + AQPickerCommon aq_picker(aq_block_tensor); + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + index_t reg_offset = [&]() { + if constexpr(QuantGroupSize::kN >= (NWarp * WG::kN)) + { + return (nIter * NWarp * WG::kN) / QuantGroupSize::kN * KPerBlockBQ + + kQScale; + } + else + { + return nIter * KPerBlockBQ + kQScale; + } + }(); + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_reg_f = + aq_picker.template cvt_scale_to_fp32(scale_reg); + + static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) { + float a_scale_reg_f = aq_picker.template pick(); + auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row]; + const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row]; + c_ref = c_ref + acc_val * b_scale_reg_f * a_scale_reg_f; + }); + }); + }); + }); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp 
b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 16a0835b1d..313e449c7b 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -322,6 +322,7 @@ struct BQuantBlockUniversalGemmAsBsCr constexpr index_t reg_offset = nIter; auto pull_from_lane = (__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale; + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; // cross lane ops uint32_t scale_reg_dword; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index ba67a9ee4d..004fb18e0b 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -280,12 +280,13 @@ struct QuantGemmKernel // Helper: Create Pre-shuffled Quantization Tensor Descriptor // =================================================================== template CK_TILE_DEVICE static auto - MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QK_B) + MakePreshuffledQuantTensorView(const BQDataType_* bq_ptr, index_t N, index_t QN_B, index_t QK_B) { // Step 1: Calculate base BQ tensor dimensions // ---------------------------------------------------------- @@ -304,8 +305,9 @@ struct QuantGemmKernel // ---------------------------------------------------------- // Pad the X dimension to be a multiple of block_tile_size to ensure // each thread block can process complete tiles without edge cases - const auto block_tile_size = NPerBlock * KPerBlockBQ; - const auto bq_pad0_desc = transform_tensor_descriptor( + const auto block_tile_size = NPerBlockBQ * KPerBlockBQ; + + const auto bq_pad0_desc = transform_tensor_descriptor( bq_desc, make_tuple(make_pass_through_transform(bq_y), make_right_pad_transform(bq_x, get_padding_size(bq_x, block_tile_size))), @@ 
-318,7 +320,7 @@ struct QuantGemmKernel // This separates the work into tiles that can be processed by // individual warps/waves const auto pad_bq_x = bq_pad0_desc.get_lengths()[I1]; - const auto wave_tile_size = WarpTileN * KPerBlockBQ; + const auto wave_tile_size = ((QN_B <= WarpTileN) ? (WarpTileN / QN_B) : 1) * KPerBlockBQ; const auto wave_tile_count_x = ck_tile::integer_divide_ceil(pad_bq_x, wave_tile_size); const auto bq_unmerge_pad0_desc = transform_tensor_descriptor( @@ -401,6 +403,623 @@ struct QuantGemmKernel index_t splitted_k; }; + CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr, + const QuantGemmKernelArgs& kargs, + const index_t k_size, + const index_t i_m) + { + // Step 1: Create tensor view for A + const auto& a_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, k_size), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + a_ptr, + make_tuple(k_size, kargs.M), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + }(); + + // Step 2: Create padded view + const auto& a_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window + const auto& a_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, i_m}); + } + }(); + + return a_block_window; + } + + CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr, + const QuantGemmKernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for AQ + const auto& aq_tensor_view = [&]() { + if 
constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) + { + static_assert(std::is_same_v); + const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ; + const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ; + const auto aq_desc = + make_naive_tensor_descriptor(make_tuple(aq_y, aq_x), + make_tuple(aq_x, 1), + number{}, + number<1>{}); + + const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ; + const auto aq_pad0_desc = transform_tensor_descriptor( + aq_desc, + make_tuple( + make_pass_through_transform(aq_y), + make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1]; + const auto wave_tile_size = + GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ; + const auto wave_tile_count_x = + ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size); + + const auto aq_unmerge_pad0_desc = transform_tensor_descriptor( + aq_pad0_desc, + make_tuple( + make_pass_through_transform(aq_y), + make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + const auto aq_pad1_desc = transform_tensor_descriptor( + aq_unmerge_pad0_desc, + make_tuple( + make_pass_through_transform(aq_y), + make_pass_through_transform(wave_tile_count_x), + make_right_pad_transform( + wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto pad_wave_size = + ck_tile::integer_least_multiple(wave_tile_size, get_warp_size()); + const auto aq_merge_pad1_desc = transform_tensor_descriptor( + aq_pad1_desc, + make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)), + make_pass_through_transform(pad_wave_size)), + 
make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tensor_view(aq_ptr, aq_merge_pad1_desc); + } + else if constexpr((kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) && + !PreshuffleQuant) + { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.M, kargs.QK_A), + make_tuple(kargs.stride_AQ, 1), + number{}, + number<1>{}); + } + else // Column major AQ + { + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.QK_A, kargs.M), + make_tuple(kargs.stride_AQ, 1), + number{}, + number<1>{}); + } + } + else if constexpr(kQuantType == QuantType::RowColQuant) + { + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, 0), // broadcasting over n + number<1>{}, + number<1>{}); + } + else + { + return nullptr; + } + }(); + + // Step 2: Create tile window (no padding for AQ) + const auto& aq_block_window = [&]() { + if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + constexpr auto block_m = TilePartitioner::MPerBlock; + constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0); + constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + constexpr auto tile_window_width = + ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size()); + constexpr auto tile_window_height = block_m / warp_m; + auto block_m_idx = i_m / block_m; + return make_tile_window( + aq_tensor_view, + make_tuple(number{}, number{}), + {block_m_idx * tile_window_height, 0}); + } + else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant) + { + using QuantGroupSize = remove_cvref_t; + constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + constexpr auto block_m = TilePartitioner::MPerBlock; + if constexpr(std::is_same_v) + { + return 
make_tile_window(aq_tensor_view, + make_tuple(number{}, number{}), + {i_m, 0}); + } + else // Column major AQ + { + return make_tile_window(aq_tensor_view, + make_tuple(number{}, number{}), + {0, i_m}); + } + } + else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant) + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + constexpr auto block_m = TilePartitioner::MPerBlock; + constexpr auto block_k = TilePartitioner::KPerBlock; + return make_tile_window( + aq_tensor_view, + make_tuple(number{}, number{}), + {i_m, 0}); + } + else if constexpr(kQuantType == QuantType::RowColQuant) + { + return make_tile_window(aq_tensor_view, + make_tuple(number{}, + number{}), + {i_m, i_n}); + } + else + { + return nullptr; + } + }(); + + return aq_block_window; + } + + CK_TILE_DEVICE static auto MakeBBlockWindow(const BDataType* b_ptr, + const QuantGemmKernelArgs& kargs, + const index_t k_size, + const index_t i_n) + { + // Step 1: Create tensor view for B + const auto& b_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = k_size / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(k_size, kargs.N), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + } + else + { + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) + { + 
constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = k_size / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + if constexpr(PreshuffleB) + { + index_t kFlatK = + GemmPipeline::flatKPerWarp * + (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatN = kargs.N * kargs.K / kFlatK; + return make_naive_tensor_view( + b_ptr, + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + } + else + { + if constexpr(std::is_same_v) + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, k_size / 2), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + else + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, k_size), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + } + } + }(); + + // Step 2: Create padded view (or flat view for PreshuffleB) + const auto& b_pad_view = [&]() { + if constexpr(PreshuffleB) + { + return b_tensor_view; // no padding for preshuffle + } + else if constexpr(std::is_same_v) + { + if constexpr(std::is_same_v) + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + else + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window + const auto& b_block_window = [&]() { + if constexpr(PreshuffleB) + { + return 
make_tile_window( + b_pad_view, + make_tuple(number{}, + number{}), + {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0}); + } + else + { + if constexpr(std::is_same_v) + { + if constexpr(std::is_same_v) + return make_tile_window( + b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + else + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + } + else + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {0, i_n}); + } + } + }(); + + return b_block_window; + } + + CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr, + const QuantGemmKernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for BQ + const auto& bq_tensor_view = [&]() { + if constexpr(kQuantType == QuantType::RowColQuant) + { + return make_naive_tensor_view( + bq_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(0, 1), // broadcasting over m + number<1>{}, + number<1>{}); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) + { + if constexpr(PreshuffleQuant) + { + static_assert(std::is_same_v, + "PreshuffleQuant with BQuantGrouped currently only supports " + "ColumnMajor BQ layout"); + using QuantGroupSize = remove_cvref_t; + + return MakePreshuffledQuantTensorView< + GemmPipeline::KPerBlockBQ, + GemmPipeline::NPerBlockBQ, + GemmPipeline::NPerBlock, + TilePartitioner::BlockGemmShape::WarpTile::at(I1), + GemmPipeline::GetVectorSizeBQ()>( + bq_ptr, + ck_tile::integer_divide_ceil(kargs.N, QuantGroupSize::kN), + QuantGroupSize::kN, + kargs.QK_B); + } + else + { + using QuantGroupSize = remove_cvref_t; + + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), + integer_divide_ceil(kargs.N, QuantGroupSize::kN)), + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1), + number{}, + number<1>{}); + } + else + { + static_assert(std::is_same_v); + return 
make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), + integer_divide_ceil(kargs.K, QuantGroupSize::kK)), + make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1), + number{}, + number<1>{}); + } + } + } + else if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), + make_tuple(kargs.stride_BQ, 1), + number{}, + number<1>{}); + } + else + { + return nullptr; + } + }(); + + // Step 2: Create tile window (no padding for BQ) + const auto& bq_block_window = [&]() { + if constexpr(kQuantType == QuantType::RowColQuant) + { + return make_tile_window(bq_tensor_view, + make_tuple(number{}, + number{}), + {i_m, i_n}); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) + { + using QuantGroupSize = remove_cvref_t; + if constexpr(PreshuffleQuant) + { + static_assert(std::is_same_v); + constexpr auto block_n = + TilePartitioner::NPerBlock / + QuantGroupSize::kN; // Number of N-dimension quantization groups per block + constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at( + I1); // Number of N-dimension elements per warp + constexpr auto warp_per_group = + (QuantGroupSize::kN < + warp_n) // Determine how many warps share the same scale in N-dimension + ? (warp_n / QuantGroupSize::kN) + : (QuantGroupSize::kN / warp_n); + constexpr auto bqk_per_block = + TilePartitioner::KPerBlock / + QuantGroupSize::kK; // Number of K-dimension quantization groups per block + constexpr auto + tile_window_width = // The pre-shuffled layout flattens warp_n × + // bqk_per_block scales per row, Padded up to warp_size + // to ensure coalesced memory access. 
+ ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size()); + + // Adapts based on fine vs coarse quantization granularity: + // - Fine-grained (QuantGroupSize::kN < warp_n): + // Multiple quant groups per warp → fewer rows needed per block. + // height = block_n / warp_per_group + // + // - Coarse-grained (QuantGroupSize::kN >= warp_n): + // Each row represents one quant group. + // height = block_n + constexpr auto tile_window_height = + (QuantGroupSize::kN < warp_n) ? block_n / warp_per_group : block_n; + auto block_n_idx = + i_n / TilePartitioner::NPerBlock; // Converts the global N-index (i_n) to a + // block index. + + return make_tile_window( + bq_tensor_view, + make_tuple(number{}, number{}), + {block_n_idx * tile_window_height, 0}); + } + else + { + if constexpr(std::is_same_v) + { + return make_tile_window( + bq_tensor_view, + make_tuple(number{}, + number{}), + {0, i_n / QuantGroupSize::kN}); + } + else + { + static_assert(std::is_same_v); + return make_tile_window( + bq_tensor_view, + make_tuple(number{}, + number{}), + {i_n / QuantGroupSize::kN, 0}); + } + } + } + else if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + return make_tile_window( + bq_tensor_view, + make_tuple(number{}, + number{}), + {i_n / QuantGroupSize::kN, 0}); + } + else + { + return nullptr; + } + }(); + + return bq_block_window; + } + + template + CK_TILE_DEVICE static auto MakeCBlockWindow(CDataType* c_ptr, + const QuantGemmKernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for C + const auto& c_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_C, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_C), + number<1>{}, + number<1>{}); + } 
+ }(); + + // Step 2: Create padded view + const auto& c_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return c_block_window; + } + CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs) { if(kargs.k_batch != 1) @@ -539,596 +1158,6 @@ struct QuantGemmKernel return true; } - template - CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_ptr, - const AQDataType* aq_ptr, - const BQDataType* bq_ptr, - CDataType* c_ptr, - const QuantGemmKernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) - { - - static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); - const auto& a_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - }(); - - const auto& aq_tensor_view = [&]() { - if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) - { - static_assert(std::is_same_v); - const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ; - const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ; - const auto aq_desc = - make_naive_tensor_descriptor(make_tuple(aq_y, aq_x), - make_tuple(aq_x, 1), - number{}, - number<1>{}); - - const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ; - const auto aq_pad0_desc = transform_tensor_descriptor( - aq_desc, - make_tuple( - 
make_pass_through_transform(aq_y), - make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1]; - const auto wave_tile_size = - GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ; - const auto wave_tile_count_x = - ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size); - - const auto aq_unmerge_pad0_desc = transform_tensor_descriptor( - aq_pad0_desc, - make_tuple( - make_pass_through_transform(aq_y), - make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), - make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); - - const auto aq_pad1_desc = transform_tensor_descriptor( - aq_unmerge_pad0_desc, - make_tuple( - make_pass_through_transform(aq_y), - make_pass_through_transform(wave_tile_count_x), - make_right_pad_transform( - wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - - const auto pad_wave_size = - ck_tile::integer_least_multiple(wave_tile_size, get_warp_size()); - const auto aq_merge_pad1_desc = transform_tensor_descriptor( - aq_pad1_desc, - make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)), - make_pass_through_transform(pad_wave_size)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tensor_view(aq_ptr, aq_merge_pad1_desc); - } - else if constexpr((kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::ABQuantGrouped) && - !PreshuffleQuant) - { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.M, kargs.QK_A), - make_tuple(kargs.stride_AQ, 1), - number{}, - number<1>{}); - } - else // Column major AQ - { - return make_naive_tensor_view( - aq_ptr, - 
make_tuple(kargs.QK_A, kargs.M), // Swapped dimensions - make_tuple(kargs.stride_AQ, 1), // Same stride pattern - number{}, - number<1>{}); - } - } - else if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, 0), // broadcasting over n - number<1>{}, - number<1>{}); - } - else - { - return nullptr; // TODO: use some other "empty" type for this - } - }(); - - const auto& b_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - if constexpr(GemmPipeline::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; - constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return make_tensor_view(b_ptr, b_n_k_desc); - } - else - { - return make_naive_tensor_view( - b_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.N), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - } - else - { - if constexpr(GemmPipeline::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; - constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - 
make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - return make_tensor_view(b_ptr, b_n_k_desc); - } - else - { - if constexpr(PreshuffleB) - { - index_t kFlatK = GemmPipeline::flatKPerWarp * - (splitk_batch_offset.splitted_k / - GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); - index_t kFlatN = kargs.N * kargs.K / kFlatK; - return make_naive_tensor_view( - b_ptr, - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); - } - else - { - if constexpr(std::is_same_v) - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k / 2), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - else - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - } - } - }(); - - const auto& bq_tensor_view = [&]() { - if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_naive_tensor_view( - bq_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(0, 1), // broadcasting over m - number<1>{}, - number<1>{}); - } - else if constexpr(kQuantType == QuantType::BQuantGrouped) - { - if constexpr(PreshuffleQuant) - { - static_assert(std::is_same_v, - "PreshuffleQuant with BQuantGrouped currently only supports " - "ColumnMajor BQ layout"); - - return MakePreshuffledQuantTensorView< - GemmPipeline::KPerBlockBQ, - GemmPipeline::NPerBlock, - TilePartitioner::BlockGemmShape::WarpTile::at(I1), - GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B); - } - else - { - using QuantGroupSize = remove_cvref_t; - - if constexpr(std::is_same_v) - { - // For RowMajor BQ: memory layout is [K/QuantGroupK][N/QuantGroupN] - // Dimensions: [K/QuantGroupK, N/QuantGroupN] - // Strides: [N/QuantGroupN, 1] - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), - 
integer_divide_ceil(kargs.N, QuantGroupSize::kN)), - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1), - number{}, - number<1>{}); - } - else - { - static_assert(std::is_same_v); - // For ColumnMajor BQ: memory layout is [N/QuantGroupN][K/QuantGroupK] - // Dimensions: [N/QuantGroupN, K/QuantGroupK] - // Strides: [K/QuantGroupK, 1] - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), - integer_divide_ceil(kargs.K, QuantGroupSize::kK)), - make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1), - number{}, - number<1>{}); - } - } - } - else if constexpr(kQuantType == QuantType::ABQuantGrouped) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - return make_naive_tensor_view( - bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), - make_tuple(kargs.stride_BQ, 1), - number{}, - number<1>{}); - } - else - { - return nullptr; // TODO: use some other "empty" type for this - } - }(); - - // TODO: enable vector write for C in ColMajor - const auto& c_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - c_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_C, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - c_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_C), - number<1>{}, - number<1>{}); - } - }(); - - return make_tuple( - a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - // no padding - const auto& 
aq_pad_view = [&]() { return views.at(I1); }(); - - const auto& b_flat_view = views.at(I2); // not applying any padding to flat B view - - const auto& b_pad_view = [&]() { - const auto& b_tensor_view = views.at(I2); - if constexpr(std::is_same_v) - { - if constexpr(std::is_same_v) - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - else - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - // no padding - const auto& bq_pad_view = [&]() { return views.at(I3); }(); - - // TODO vector write in for C in ColMajor - const auto& c_pad_view = [&]() { - const auto& c_tensor_view = views.at(I4); - if constexpr(std::is_same_v) - { - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - if constexpr(PreshuffleB) - { - - return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view); - } - else - { - return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view); - } - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - - const auto& a_pad_view = views.at(I0); - const auto& aq_pad_view = views.at(I1); - const auto& b_pad_view = views.at(I2); - const auto& bq_pad_view = views.at(I3); - const auto& c_pad_view = views.at(I4); - const auto& a_block_window = [&]() { - if constexpr(std::is_same_v) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } - }(); - - const auto& aq_block_window = [&]() { - if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) - { - 
static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - constexpr auto block_m = TilePartitioner::MPerBlock; - constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0); - constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto tile_window_width = - ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size()); - constexpr auto tile_window_height = block_m / warp_m; - auto block_m_idx = i_m / block_m; - return make_tile_window( - aq_pad_view, - make_tuple(number{}, number{}), - {block_m_idx * tile_window_height, 0}); - } - else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant) - { - using QuantGroupSize = remove_cvref_t; - constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto block_m = TilePartitioner::MPerBlock; - if constexpr(std::is_same_v) - { - return make_tile_window(aq_pad_view, - make_tuple(number{}, number{}), - {i_m, 0}); - } - else // Column major AQ - { - return make_tile_window(aq_pad_view, - make_tuple(number{}, number{}), - {0, i_m}); - } - } - else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant) - { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - constexpr auto block_m = TilePartitioner::MPerBlock; - constexpr auto block_k = TilePartitioner::KPerBlock; - return make_tile_window( - aq_pad_view, - make_tuple(number{}, number{}), - {i_m, 0}); - } - else if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_tile_window(aq_pad_view, - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return nullptr; // TODO: use some other "empty" type? 
- } - }(); - - const auto& b_block_window = [&]() { - if constexpr(PreshuffleB) - { - - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0}); - } - else - { - if constexpr(std::is_same_v) - { - if constexpr(std::is_same_v) - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - else - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - } - else - { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {0, i_n}); - } - } - }(); - - const auto& bq_block_window = [&]() { - if constexpr(kQuantType == QuantType::RowColQuant) - { - return make_tile_window(bq_pad_view, - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else if constexpr(kQuantType == QuantType::BQuantGrouped) - { - using QuantGroupSize = remove_cvref_t; - if constexpr(PreshuffleQuant) - { - static_assert(std::is_same_v); - constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; - constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); - constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; - constexpr auto tile_window_width = - ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size()); - constexpr auto tile_window_height = block_n / warp_n; - auto block_n_idx = i_n / block_n; - - return make_tile_window( - bq_pad_view, - make_tuple(number{}, number{}), - {block_n_idx * tile_window_height, 0}); - } - else - { - if constexpr(std::is_same_v) - { - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {0, i_n / QuantGroupSize::kN}); - } - else - { - static_assert(std::is_same_v); - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {i_n / QuantGroupSize::kN, 0}); - } - } - } - else if constexpr(kQuantType == QuantType::ABQuantGrouped) - { - static_assert(std::is_same_v); - using 
QuantGroupSize = remove_cvref_t; - return make_tile_window( - bq_pad_view, - make_tuple(number{}, - number{}), - {i_n / QuantGroupSize::kN, 0}); - } - else - { - return nullptr; // TODO: use some other "empty" type here - } - }(); - - auto c_block_window = make_tile_window( - c_pad_view, - make_tuple(number{}, number{}), - {i_m, i_n}); - - return make_tuple( - a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window); - } - /** * @brief Runs single GEMM problem cooperatively by whole workgroup. * @@ -1137,69 +1166,61 @@ struct QuantGemmKernel * @param aq_ptr input AQ pointer * @param bq_ptr input BQ pointer * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param smem_ptr The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * - * @tparam DstInMemOp Destination memory operation (default: set). 
*/ - template CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr, const AQDataType* aq_ptr, const BQDataType* bq_ptr, CDataType* c_ptr, - void* smem_ptr_0, + void* smem_ptr, const QuantGemmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_block_window = + MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); + const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. 
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = [&]() { if constexpr(kQuantType == QuantType::AQuantGrouped) { - const auto& aq_block_window = gemm_tile_windows.at(I1); - index_t m = 0; + index_t m = 0; if constexpr(PreshuffleQuant) { m = kargs.M; } return GemmPipeline{}.template operator()( - a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0, m); + a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr, m); } else if constexpr(kQuantType == QuantType::BQuantGrouped) { - const auto& bq_block_window = gemm_tile_windows.at(I3); - index_t n = 0; + index_t n = 0; if constexpr(PreshuffleQuant) { n = kargs.N; } return GemmPipeline{}.template operator()( - a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0, n); + a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr, n); } else if constexpr(kQuantType == QuantType::ABQuantGrouped) { - const auto& aq_block_window = gemm_tile_windows.at(I1); - const auto& bq_block_window = gemm_tile_windows.at(I3); - index_t m = 0; - index_t n = 0; + index_t m = 0; + index_t n = 0; if constexpr(PreshuffleQuant) { m = kargs.M; @@ -1210,7 +1231,7 @@ struct QuantGemmKernel aq_block_window, bq_block_window, num_loop, - smem_ptr_0, + smem_ptr, m, n); } @@ -1218,121 +1239,68 @@ struct QuantGemmKernel kQuantType == QuantType::TensorQuant) { return GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, smem_ptr_0); + a_block_window, b_block_window, num_loop, smem_ptr); } }(); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I4); + const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); - if constexpr(kQuantType == QuantType::ABQuantGrouped || - kQuantType == QuantType::AQuantGrouped || - kQuantType == QuantType::BQuantGrouped) + // Run Epilogue Pipeline with k_batch dispatch + if(k_batch == 1) { - EpiloguePipeline{}(c_block_window, 
c_block_tile, c_block_window, smem_ptr_0); - } - else if constexpr(kQuantType == QuantType::RowColQuant) - { - const auto& aq_block_window = gemm_tile_windows.at(I1); - const auto& bq_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}(c_block_window, - c_block_tile, - c_block_window, - smem_ptr_0, - aq_block_window, - bq_block_window); - } - else if constexpr(kQuantType == QuantType::TensorQuant) - { - // TODO: why doesn't readfirstlane work here? - // const AccDataType aq_scale = - // __builtin_amdgcn_readfirstlane(type_convert(*aq_ptr)); - // const AccDataType bq_scale = - // __builtin_amdgcn_readfirstlane(type_convert(*bq_ptr)); - const AccDataType aq_scale = type_convert(*aq_ptr); - const AccDataType bq_scale = type_convert(*bq_ptr); - EpiloguePipeline{}( - c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); - } - } - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param aq_ptr input AQ pointer - * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - * @tparam DstInMemOp Destination memory operation (default: set). 
- */ - template - CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, - const BDataType* b_ptr, - const AQDataType* aq_ptr, - const BQDataType* bq_ptr, - CDataType* c_ptr, - void* smem_ptr_0, - void* smem_ptr_1, - const QuantGemmKernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset); + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - - // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I2); - - const auto& c_block_tile = [&]() { - if constexpr(kQuantType == QuantType::BQuantGrouped) + if constexpr(kQuantType == QuantType::ABQuantGrouped || + kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::BQuantGrouped) { - const auto& bq_block_window = gemm_tile_windows.at(I3); - index_t n = 0; - if constexpr(PreshuffleQuant) - { - n = kargs.N; - } - return GemmPipeline{}.template operator()(a_block_window, - b_block_window, - bq_block_window, - num_loop, - smem_ptr_0, - smem_ptr_1, - n); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); } - else + else if constexpr(kQuantType == QuantType::RowColQuant) { - return nullptr; + EpiloguePipeline{}(c_block_window, + c_block_tile, + c_block_window, + smem_ptr, + aq_block_window, + bq_block_window); + } + else if constexpr(kQuantType == QuantType::TensorQuant) + { + const AccDataType aq_scale = 
type_convert(*aq_ptr); + const AccDataType bq_scale = type_convert(*bq_ptr); + EpiloguePipeline{}( + c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale); } - }(); - - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I4); - - if constexpr(kQuantType == QuantType::BQuantGrouped) - { - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); } else { - return; - // throw std::runtime_error("DoubleSmemBuffer Not implemented for AQuantGrouped or - // RowColQuant"); static_assert(kQuantType == QuantType::BQuantGrouped, - // "DoubleSmemBuffer Not implemented"); + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + + if constexpr(kQuantType == QuantType::ABQuantGrouped || + kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::BQuantGrouped) + { + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); + } + else if constexpr(kQuantType == QuantType::RowColQuant) + { + EpiloguePipeline{}(c_block_window, + c_block_tile, + c_block_window, + smem_ptr, + aq_block_window, + bq_block_window); + } + else if constexpr(kQuantType == QuantType::TensorQuant) + { + const AccDataType aq_scale = type_convert(*aq_ptr); + const AccDataType bq_scale = type_convert(*bq_ptr); + EpiloguePipeline{}( + c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale); + } } } @@ -1343,45 +1311,21 @@ struct QuantGemmKernel const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); const SplitKBatchOffset splitk_batch_offset(kargs); - // options - const ADataType* a_ptr = static_cast(kargs.a_ptr); - const BDataType* b_ptr = static_cast(kargs.b_ptr); + + // Apply splitk offset to input pointers + const ADataType* a_ptr = + static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; + const BDataType* b_ptr = + static_cast(kargs.b_ptr) + 
splitk_batch_offset.b_k_split_offset; const AQDataType* aq_ptr = static_cast(kargs.aq_ptr); const BQDataType* bq_ptr = static_cast(kargs.bq_ptr); CDataType* c_ptr = static_cast(kargs.c_ptr); // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; - assert(kargs.k_batch == 1); - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; - RunGemm2LDS(a_ptr, - b_ptr, - aq_ptr, - bq_ptr, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else - { - RunGemm(a_ptr, - b_ptr, - aq_ptr, - bq_ptr, - c_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + RunGemm( + a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } }; diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 7e246961cb..c9e725f5fd 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -318,21 +318,18 @@ struct QuantGroupedGemmKernel CDataType* c_ptr = static_cast(kargs.c_ptr); // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + __shared__ char smem_ptr[GetSmemSize()]; // Only for BQuantGrouped DoubleSmemBuffer is supported if constexpr(GemmPipeline::DoubleSmemBuffer == true && kQuantType == QuantType::BQuantGrouped) { - - __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; RunGemmWithPipelineSelection2LDS(a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, - smem_ptr_0, - smem_ptr_1, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -348,7 +345,7 @@ struct QuantGroupedGemmKernel aq_ptr, bq_ptr, c_ptr, - smem_ptr_0, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -361,7 +358,7 @@ struct QuantGroupedGemmKernel aq_ptr, bq_ptr, c_ptr, - smem_ptr_0, + smem_ptr, kargs, splitk_batch_offset, i_m, @@ -374,47 +371,47 @@ 
struct QuantGroupedGemmKernel CK_TILE_DEVICE static void RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr, const BDataType* b_ptr, - const AQDataType* aq_ptr, + [[maybe_unused]] const AQDataType* aq_ptr, const BQDataType* bq_ptr, CDataType* c_ptr, - void* smem_ptr_0, - void* smem_ptr_1, + void* smem_ptr, const QuantGroupedGemmKernelArgs& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) { static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped"); - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset); - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_block_window = + Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + const auto& bq_block_window = + Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); const index_t num_loop = __builtin_amdgcn_readfirstlane( TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - // Run GEMM cooperatively by whole workgroup. 
- const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I2); + // Run GEMM cooperatively by whole workgroup + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, bq_block_window, num_loop, tail_num, smem_ptr); - const auto& bq_block_window = gemm_tile_windows.at(Base::I3); - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, - b_block_window, - bq_block_window, - num_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); - - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I4); - - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + // Run Epilogue Pipeline with split_k dispatch + if(kargs.k_batch == 1) + { + auto c_block_window = Base::template MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); + } + else + { + auto c_block_window = + Base::template MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); + } } /** @@ -429,7 +426,7 @@ struct QuantGroupedGemmKernel * @param aq_ptr input AQ pointer * @param bq_ptr input BQ pointer * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param smem_ptr The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k * batch. 
@@ -443,22 +440,21 @@ struct QuantGroupedGemmKernel const AQDataType* aq_ptr, const BQDataType* bq_ptr, CDataType* c_ptr, - void* smem_ptr_0, + void* smem_ptr, const QuantGroupedGemmKernelArgs& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset); - - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I2); + // Create block windows using specialized methods + const auto& a_block_window = + Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_block_window = + Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + const auto& aq_block_window = + Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); + const auto& bq_block_window = + Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); // Get hot-loop and tail configuration const index_t num_loop = __builtin_amdgcn_readfirstlane( @@ -466,55 +462,65 @@ struct QuantGroupedGemmKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - if constexpr(kQuantType == QuantType::AQuantGrouped) - { - const auto& aq_block_window = gemm_tile_windows.at(Base::I1); - // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, - b_block_window, - aq_block_window, - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0); - - auto& c_block_window = gemm_tile_windows.at(Base::I4); - - // Run Epilogue 
Pipeline - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); - } - else if constexpr(kQuantType == QuantType::BQuantGrouped) - { - const auto& bq_block_window = gemm_tile_windows.at(Base::I3); - // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, - b_block_window, - bq_block_window, - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0); - - auto& c_block_window = gemm_tile_windows.at(Base::I4); - - // Run Epilogue Pipeline - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); - } - else - { - // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I4); - if constexpr(kQuantType == QuantType::RowColQuant) + // Run GEMM cooperatively by whole workgroup + const auto& c_block_tile = [&]() { + if constexpr(kQuantType == QuantType::AQuantGrouped) + { + return GemmPipeline{}.template operator()(a_block_window, + b_block_window, + aq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) + { + return GemmPipeline{}.template operator()(a_block_window, + b_block_window, + bq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr); + } + else if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + return GemmPipeline{}.template operator()(a_block_window, + b_block_window, + aq_block_window, + bq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr); + } + else if constexpr(kQuantType == QuantType::RowColQuant || + kQuantType == QuantType::TensorQuant) + { + return GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr); + } + }(); + + // Run Epilogue Pipeline with split_k dispatch + if(kargs.k_batch == 1) + { + auto c_block_window 
= Base::template MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + + if constexpr(kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::BQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) + { + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); + } + else if constexpr(kQuantType == QuantType::RowColQuant) { - const auto& aq_block_window = gemm_tile_windows.at(Base::I1); - const auto& bq_block_window = gemm_tile_windows.at(Base::I3); EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, - smem_ptr_0, + smem_ptr, aq_block_window, bq_block_window); } @@ -523,7 +529,36 @@ struct QuantGroupedGemmKernel const AccDataType aq_scale = type_convert(*aq_ptr); const AccDataType bq_scale = type_convert(*bq_ptr); EpiloguePipeline{}( - c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale); + } + } + else + { + auto c_block_window = + Base::template MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + + if constexpr(kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::BQuantGrouped || + kQuantType == QuantType::ABQuantGrouped) + { + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr); + } + else if constexpr(kQuantType == QuantType::RowColQuant) + { + EpiloguePipeline{}(c_block_window, + c_block_tile, + c_block_window, + smem_ptr, + aq_block_window, + bq_block_window); + } + else if constexpr(kQuantType == QuantType::TensorQuant) + { + const AccDataType aq_scale = type_convert(*aq_ptr); + const AccDataType bq_scale = type_convert(*bq_ptr); + EpiloguePipeline{}( + c_block_window, c_block_tile, c_block_window, smem_ptr, aq_scale, bq_scale); } } } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index 39f0cbdbd3..a4bba6cf76 100644 --- 
a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -48,7 +48,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; - constexpr index_t VecLoadSize = GetVectorSizeBQ(); constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; using WarpTile = typename Problem::BlockGemmShape::WarpTile; @@ -68,7 +67,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC BlockSize, NPerBlock / WarpGemm::kN, ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()), - VecLoadSize, + Problem::BQuantGroupSize::kN, + Problem::BQuantGroupSize::kK, BQLayout, PreshuffleQuant>; return TileEncodingPattern::make_2d_static_tile_distribution(); @@ -83,6 +83,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC KPerBlockBQ, // Logical K dimension NPerBlockBQ, // Logical N dimension Problem::BQuantGroupSize::kN, + Problem::BQuantGroupSize::kK, BQLayout>; return TileEncodingPattern::make_2d_static_tile_distribution(); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index b43066cdc5..13d400d5fc 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -65,8 +65,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(); } static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } @@ -300,9 +302,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public 
BaseGemmPipelineAgBgCrCompV3{}), - 0) + (PreshuffleQuant) + ? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, NPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), + 0) : is_bq_row_major ? make_array(KPerBlockBQ, 0) : make_array(0, KPerBlockBQ); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index 0ec8942426..34f815ed27 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -192,6 +192,7 @@ template struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern @@ -208,31 +209,6 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding static_assert(num_warps == MWarps * NWarps * KWarps); static_assert(KWarps == 1); - /// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales) - /// - /// This function determines the optimal thread distribution pattern for loading and applying - /// quantization scales to the B matrix based on the quantization group size (NPerQ) relative - /// to warp dimensions. - /// - /// Three distinct distribution patterns are handled: - /// - /// 1. Fine-grained quantization (NPerQ < WarpGemm::kN): - /// - Multiple quantization groups exist within a single warp's N-dimension - /// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp) - /// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast - /// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp - /// - /// 2. 
Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps): - /// - Each warp handles exactly one quantization scale - /// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN - /// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 - /// - /// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps): - /// - Quantization group spans multiple warps - /// - All warps share the same scale value - /// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale - /// - /// @return A static tile distribution encoding for the BQ scale tensor CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { // Preshuffle only supported for ColumnMajor currently @@ -241,22 +217,136 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding if constexpr(PreshuffleQuant) { - // ColumnMajor only for preshuffle - constexpr index_t X1 = warp_size; - constexpr index_t X0 = NPerTile / warp_size; - constexpr index_t Y1 = NWarps; - constexpr index_t Y0 = KPerTile / Y1; + // ============================================================================= + // PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION + // ============================================================================= + // For pre-shuffled quantization, the BQ scale tensor has been reorganized + // (pre-shuffled) to optimize memory access patterns during dequantization. + // + // Tile Dimensions: + // - K-axis (Y in encoding): Corresponds to the K-dimension iteration + // - N-axis (X in encoding): Flattened scale index combining N and K groups + // + // The encoding distributes work across threads such that each thread loads + // the correct pre-shuffled scale for its corresponding B-matrix elements. 
+ // ============================================================================= + if constexpr(NPerQ <= WarpGemm::kN) + { + // ========================================================================= + // CASE 1: Fine-grained Quantization (NPerQ <= WarpGemm::kN) + // ========================================================================= + // Multiple quantization scales exist within a single warp's N-dimension. + // Each warp processes multiple scales: WarpGemm::kN / NPerQ scales per warp. + // + // Example: NPerQ=8, WarpGemm::kN=16, KPerQ=128, BlockGemmShape::kK=256 + // → 2 scales per warp in N, 2 K-groups per block + constexpr auto N1 = BlockGemmShape::kK / + KPerQ; // Number of K-dimension quantization groups per block, + // Each K-group of KPerQ elements shares the same scale. + constexpr auto N0 = + WarpGemm::kN / NPerQ; // Number of scales per warp in N-dimension, Since NPerQ + // <= WarpGemm::kN, each warp handles multiple scales. + constexpr auto N2 = 1; // Elements per thread + constexpr auto NR1 = NPerQ; // Elements sharing the same scale in N-dimension + constexpr auto NR0 = + warp_size / + (N0 * N1 * N2 * NR1); // Interleave factor to ensure full warp utilization + constexpr auto K1 = NWarps; // Number of warps distributed along this dimension + constexpr auto K0 = KPerTile / K1; // Iterations per warp to cover the K-tile + constexpr auto KR = 1; // No replication in K-dimension - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2>>, - tuple, sequence<1>>, - sequence<1, 2>, - sequence<0, 0>>{}); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0, 2, 0>>, + tuple, sequence<1, 0, 2, 1, 3>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else if constexpr(NPerQ < WarpGemm::kN * NWarps) + { + // ========================================================================= + // CASE 2: Medium-grained Quantization (WarpGemm::kN < NPerQ 
< WarpGemm::kN * + // NWarps) + // ========================================================================= + // Each warp handles exactly one quantization scale in N-dimension. + // Some warps share the same scale (KR > 1 creates warp grouping). + // + // Example: NPerQ=32, WarpGemm::kN=16, NWarps=4 + // → KR=2 (2 warps share same scale), K1=2 (2 unique scale groups) + + constexpr auto KR = NPerQ / WarpGemm::kN; // Number of warps sharing the same scale + constexpr auto K1 = NWarps / KR; // Number of distinct warp groups (unique scales) + constexpr auto K0 = KPerTile / K1; // Iterations to cover K-tile per warp group + constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups + constexpr auto N0 = 1; // Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN) + constexpr auto N2 = 1; // Elements per thread + constexpr auto NR1 = NPerQ; // Scale broadcast factor (full NPerQ) + constexpr auto NR0 = + warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0, 2>>, + tuple, sequence<1, 0, 2, 1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else + { + // ========================================================================= + // CASE 3: Coarse-grained Quantization (NPerQ >= WarpGemm::kN * NWarps) + // ========================================================================= + // The quantization group spans ALL warps in N-dimension. + // All warps share the same scale value for their N-tiles. 
+ // + // Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 + // → 128 >= 16*4=64, so all 4 warps use the same scale + constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups + constexpr auto N0 = 1; // Minimal (1) since scale is shared across N + constexpr auto N2 = 1; // Elements per thread + constexpr auto NR1 = 32; // Fixed broadcast size + constexpr auto NR0 = + warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 2, 0, 2>>, + tuple, sequence<2, 0, 3, 1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } } else { + /// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales) + /// + /// This function determines the optimal thread distribution pattern for loading and + /// applying quantization scales to the B matrix based on the quantization group size + /// (NPerQ) relative to warp dimensions. + /// + /// Three distinct distribution patterns are handled: + /// + /// 1. Fine-grained quantization (NPerQ < WarpGemm::kN): + /// - Multiple quantization groups exist within a single warp's N-dimension + /// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp) + /// - Distribution includes explicit replication factor (XR = NPerQ) for scale + /// broadcast + /// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp + /// + /// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps): + /// - Each warp handles exactly one quantization scale + /// - Scales are distributed across warps with replication factor XR = NPerQ / + /// WarpGemm::kN + /// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4 + /// + /// 3. 
Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps): + /// - Quantization group spans multiple warps + /// - All warps share the same scale value + /// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale + /// + /// @return A static tile distribution encoding for the BQ scale tensor if constexpr(NPerQ < WarpGemm::kN) { // Case 1: Fine-grained - multiple quantization scales within a single warp diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp new file mode 100755 index 0000000000..80e41cad45 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp @@ -0,0 +1,120 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { + +struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelineAgBgCrPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ() + { + using AQDataType = remove_cvref_t; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockAQ = KPerBlock / Problem::AQuantGroupSize::kK; + + return GetABQGlobalVectorLoadSize(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeAQDramTileDistribution() + { + return GemmAQuantPipelineAgBgCrDefaultPolicy::MakeAQDramTileDistribution(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ() + { + using BQDataType = 
remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; + + return GetABQGlobalVectorLoadSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution() + { + return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution(); + } + + // as UniversalWeightPreshufflePipelineAgBgCrPolicy's MakeBFlatDramTileDistribution is changed; + // move original UniversalWeightPreshufflePipelineAgBgCrPolicy's implementation to here + // temporarily + template + CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + constexpr index_t KBPerLoad = GetKBPerLoad(); +#if defined(__gfx11__) + constexpr index_t KRepeatInWave = 2; +#else + constexpr index_t KRepeatInWave = 1; +#endif + constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim + constexpr index_t KWavePerBlk = 1; + constexpr index_t KRepeat = 1; + static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); + + constexpr index_t NBPerLoad = 1; + constexpr index_t NThdPerWave = 1; + constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp + constexpr index_t NRepeat = 1; + + constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // ? 
+ tuple, // second direction + sequence>, // first direction + // wave in blk, // thd in wave + // // + tuple, sequence<0, 1, 2>>, // which direction + tuple, sequence<1, 2, 2>>, // which index + // + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffleBQuant() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + using BTypeToUse = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + + using WarpGemm = WarpGemmDispatcher; + + // TODO : Use a custom block policy for AsBrCr + using BlockGemmPolicy = + BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy; + return BlockGemmWeightPreshuffleABQuantARegBRegCReg{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp new file mode 100644 index 0000000000..0f3951ffcc --- /dev/null +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -0,0 +1,611 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +template +struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV2 +{ + using Base = WeightPreshufflePipelineAGmemBGmemCRegV2; + using ADataType = remove_cvref_t; + using AQDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BQDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + using AQuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using BQLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockWeightPreshuffle = remove_cvref_t< + decltype(PipelinePolicy::template GetBlockWeightPreshuffleBQuant())>; + + static constexpr auto config = + BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + using Base::kKPerBlock; + using Base::kMPerBlock; + using Base::kNPerBlock; + + using Base::KIterPerWarp; + using Base::MIterPerWarp; + using Base::NIterPerWarp; + + using Base::BlockSize; + + using Base::kPadK; + using Base::kPadM; + using Base::kPadN; + + using Base::I0; + using Base::I1; + using Base::I2; + + using Base::MWarp; + using Base::NWarp; + + using Base::KPerBlockPerIter; + using Base::MPerBlockPerIter; + + using Base::flatKPerWarp; + using Base::flatNPerWarp; + + using Base::m_preload; + + static constexpr index_t VectorLoadSize = Problem::VectorLoadSize; + static constexpr index_t 
KPerBlockAQ = + integer_divide_ceil(BlockGemmShape::kK, AQuantGroupSize::kK); + static constexpr index_t KPerBlockBQ = + integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK); + static constexpr index_t QScalesPerBlockRow = + integer_divide_ceil(kKPerBlock, BQuantGroupSize::kK); + static constexpr index_t GetVectorSizeAQ() + { + return PipelinePolicy::template GetVectorSizeAQ(); + } + static constexpr index_t GetVectorSizeBQ() + { + return PipelinePolicy::template GetVectorSizeBQ(); + } + static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1); + return concat('_', "bquant_pipeline_AgBgCrV2_preshuffleB", + concat('x', kMPerBlock, kNPerBlock, kKPerBlock), + BlockSize, + concat('x', WaveNumM, WaveNumN), + concat('x', Base::GetVectorSizeA(), Base::GetVectorSizeB(), GetVectorSizeAQ(), GetVectorSizeBQ()), + concat('x', kPadM, kPadN, kPadK), AQuantGroupSize::GetName(), BQuantGroupSize::GetName()); + // clang-format on + } + + template + CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + { + // Estimated number of VMEM vector loads for A per block: + // total A bytes / (threads per block * vector width) + constexpr index_t Aload_inst = + (kMPerBlock * kKPerBlock * sizeof(ADataType)) / BlockSize / VectorLoadSize; + // Estimated number of VMEM vector loads for B per block: + // total B bytes / (threads per block * vector width) + constexpr index_t Bload_inst = + (kKPerBlock * kNPerBlock * sizeof(BDataType)) / BlockSize / VectorLoadSize; + + // Estimated number of VMEM loads for B's quant data (e.g. scales / zp). + // First ceil-divide by quant group size (how many elements share one scale), + // then by vector width to get an approximate number of vector loads. 
+ constexpr index_t BQload_inst = ck_tile::integer_divide_ceil( + ck_tile::integer_divide_ceil(kKPerBlock * kNPerBlock * sizeof(BQDataType), + BQuantGroupSize::kK * BQuantGroupSize::kK), + VectorLoadSize); + + // ToDo: Hardcoded, need to change in future. How many instruction emit per iteration + constexpr index_t kLdsInstCycle = 8; + // Total VMEM load instructions (A + B + quant data) + constexpr index_t buffer_load_inst = Aload_inst + Bload_inst + BQload_inst; + // Approximate number of LDS reads per block + constexpr index_t ds_read_inst = kMPerBlock / kLdsInstCycle; + // Approximate number of LDS writes per block + // (e.g., writing A from VMEM into LDS once per A load) + constexpr index_t ds_write_inst = Aload_inst; + // Number of MFMA instructions per wave for one block tile: + constexpr index_t mfma_inst = (kMPerBlock / WG::kM) * (kNPerBlock / WG::kN); + // How often (in MFMA units) we should insert DS (LDS) operations. + constexpr index_t ds_rep = mfma_inst / (ds_read_inst + ds_write_inst); + // How often (in MFMA units) we should insert VMEM buffer loads. + // buffer_load_rep ≈ "MFMA per VMEM_READ", clamped so that one buffer_load + // is assumed to cover at most 4 MFMA instructions. + constexpr index_t buffer_load_rep = + min(mfma_inst / buffer_load_inst, 4); // 1 buffer_load cover 4 mfma + + static_for<0, nloop, 1>{}([&](auto) { + static_for<0, mfma_inst, 1>{}([&](auto i_inst) { + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA + + // Insert LDS read/write groups periodically based on ds_rep. + // The % pattern staggers READ and WRITE so they don't collapse + // into the same cycle in the model. 
+ if constexpr(ds_rep > 0 && i_inst % ds_rep == 0) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_READ, 1, 0); // DS read + } + if constexpr(ds_rep > 0 && i_inst % ds_rep == 1) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::DS_WRITE, 1, 0); // DS write + } + + if constexpr(buffer_load_rep > 0 && i_inst % buffer_load_rep == 0) + { + if constexpr(ds_write_inst > 0) + { + __builtin_amdgcn_sched_group_barrier( + LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read + } + } + // Always mark some VALU work in the loop to reflect auxiliary scalar + // or vector ALU instructions that coexist with MFMA (Blockscale calculation). + __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::VALU, 2, 0); // VALU + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + + static constexpr bool PreshuffleB = Problem::PreshuffleB; + static constexpr auto TailNum = Problem::TailNum; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t m, + index_t n, + index_t num_loop, + void* p_smem) const + { + (void)m; + (void)n; + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "A/B/BQ Dram block window should have the same data type as appropriate " + "([A|B|BQ]DataType) defined in Problem definition!"); + + constexpr bool is_a_col_major = std::is_same_v; + static_assert(!is_a_col_major, "A must be row major (col major not supported yet)"); + + constexpr bool is_bq_col_major = std::is_same_v; + static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); + + constexpr bool is_b_row_major = std::is_same_v; + static_assert(!is_b_row_major, "B must be col major (row major not supported yet)"); + + const index_t iMWarp = get_warp_id() / NWarp; + 
// Double-Buffering (loop_count=2) for full load/compute overlap. + const index_t loop_count = 2; + + __builtin_amdgcn_sched_barrier(0); + + // A tile in LDS + constexpr index_t smem_size = PipelinePolicy::template GetSmemSize(); + ADataType* p_a_lds_ping = static_cast(p_smem); + ADataType* p_a_lds_pong = + reinterpret_cast(static_cast(p_smem) + smem_size); + + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); + + auto a_lds_block_ping = + make_tensor_view(p_a_lds_ping, a_lds_block_desc); + auto a_lds_block_pong = + make_tensor_view(p_a_lds_pong, a_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeADramTileDistribution()); + + auto a_copy_lds_window_ping = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + + auto a_copy_lds_window_pong = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + + // ping-pong window for A LDS + auto a_warp_window_ping_tmp = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + auto a_warp_window_pong_tmp = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows_ping; + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows_pong; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows_ping(mIter)(kIter) = 
a_warp_window_ping_tmp; + + move_tile_window(a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + + move_tile_window(a_warp_windows_pong(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Block GEMM + auto block_weight_preshuffle = BlockWeightPreshuffle(); + // Acc register tile + auto c_block_tile = block_weight_preshuffle.MakeCBlockTile(); + + // B flat DRAM window for load + auto b_flat_distribution = + PipelinePolicy::template MakeBFlatDramTileDistribution(); + auto b_flat_dram_window = // tile_window_with_static_distribution + make_tile_window( + b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); + + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; + using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); + + // pingpong buffer for B + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_flat_dram_windows; + + statically_indexed_array, NIterPerWarp> + b_warp_tensor_ping; + + statically_indexed_array, NIterPerWarp> + b_warp_tensor_pong; + + auto aq_copy_dram_window = + make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(), + aq_dram_block_window_tmp.get_window_lengths(), + aq_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeAQDramTileDistribution()); + // BQ DRAM window for load + auto bq_copy_dram_window = + make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(), + bq_dram_block_window_tmp.get_window_lengths(), + bq_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeBQDramTileDistribution()); + + // Prefetch A0 + auto 
a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // prefetch B + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + + load_int4_tile( + b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + }); + }); + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + // Strictly not needed given type deduction, but helps with readability + using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); + using AQBlockTile = + decltype(make_static_distributed_tensor(AQBlockTileDistr{})); + using BQBlockTileDistr = decltype(bq_copy_dram_window.get_tile_distribution()); + using BQBlockTile = + decltype(make_static_distributed_tensor(BQBlockTileDistr{})); + + // Load tile 0 for BQ data directly into registers for block tile + AQBlockTile aq_block_tile, aq_block_tile_2; + BQBlockTile bq_block_tile, bq_block_tile_2; + aq_block_tile = load_tile(aq_copy_dram_window); + bq_block_tile = load_tile(bq_copy_dram_window); + // move BQ to tile 1 + move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + // Prefill A0 + auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + + __builtin_amdgcn_sched_barrier(0); + + // Prefetch A1 + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + block_sync_lds(); + + // preload A00,A10 from lds + statically_indexed_array{})(number<0>{}))), + m_preload> + 
a_warp_tensor; + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor(loadIter) = + load_tile(a_warp_windows_ping(number{})(number{})); + }); + __builtin_amdgcn_sched_barrier(0); + + // MAIN LOOP + index_t iCounter = (num_loop - 1) / loop_count; + + while(iCounter > 0) + { + __builtin_amdgcn_sched_barrier(0); + // Prefill A(2i+1) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // Prefetch A(2i+2) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_ping, + aq_block_tile, + bq_block_tile, + a_warp_windows_ping); + // prefetch B(2i+1) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + load_int4_tile( + b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + }); + }); + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + aq_block_tile_2 = load_tile(aq_copy_dram_window); + move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); + bq_block_tile_2 = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor(loadIter) = + load_tile(a_warp_windows_pong(number{})(number{})); + }); + + // Next K + + // prefetch B(2i+2) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = 
b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + load_int4_tile( + b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + }); + }); + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + aq_block_tile = load_tile(aq_copy_dram_window); + move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + + // Prefill A(2i+2) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + + // Prefetch A(2i+3) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i+1 + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_pong, + aq_block_tile_2, + bq_block_tile_2, + a_warp_windows_pong); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor(loadIter) = + load_tile(a_warp_windows_ping(number{})(number{})); + }); + iCounter--; + HotLoopScheduler(); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // prefetch B(loopK) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * flatNPerWarp, kIter * flatKPerWarp}); + + load_int4_tile( + b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + }); + }); + aq_block_tile_2 = load_tile(aq_copy_dram_window); + bq_block_tile_2 = load_tile(bq_copy_dram_window); + + // Prefill A(loopK) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // GEMM 
loopK-1 + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_ping, + aq_block_tile, + bq_block_tile, + a_warp_windows_ping); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor(loadIter) = + load_tile(a_warp_windows_pong(number{})(number{})); + }); + + // GEMM loopK + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_pong, + aq_block_tile_2, + bq_block_tile_2, + a_warp_windows_pong); + HotLoopScheduler(); + } + else if constexpr(TailNum == TailNumber::Odd) + { + // GEMM loopK + block_weight_preshuffle(c_block_tile, + a_warp_tensor, + b_warp_tensor_ping, + aq_block_tile, + bq_block_tile, + a_warp_windows_ping); + Base::LastHotLoopScheduler(); + } + + return c_block_tile; + } + + // Replace lines 485-526 with a single optimized operator: + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t num_loop, + void* p_smem, + index_t m = 0, + index_t n = 0) const // Default value for non-preshuffle case + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_flat_dram_block_window_tmp, + aq_dram_block_window_tmp, + bq_dram_block_window_tmp, + m, + n, + num_loop, + p_smem); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const AQDramBlockWindowTmp& aq_dram_block_window_tmp, + const BQDramBlockWindowTmp& bq_dram_block_window_tmp, + index_t num_loop, + TailNumber tail_number, + void* p_smem, + index_t n = 0) const + { + const auto RunPipeline = [&](auto bool_val, auto tail_num_) { + (void)bool_val; // Suppress unused parameter warning + constexpr 
auto tail_num = tail_num_.value; + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_flat_dram_block_window_tmp, + aq_dram_block_window_tmp, + bq_dram_block_window_tmp, + n, // dummy value, won't be used + num_loop, + p_smem); + }; + return Base::TailHandler(RunPipeline, true, tail_number); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp index b155297054..b7dc0bd616 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp @@ -29,6 +29,48 @@ struct GemmWPQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipelin return GemmBQuantPipelineAgBgCrDefaultPolicy::MakeBQDramTileDistribution(); } + // as UniversalWeightPreshufflePipelineAgBgCrPolicy's MakeBFlatDramTileDistribution is changed; + // move original UniversalWeightPreshufflePipelineAgBgCrPolicy's implementation to here + // temporarily + template + CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + constexpr index_t KBPerLoad = GetKBPerLoad(); +#if defined(__gfx11__) + constexpr index_t KRepeatInWave = 2; +#else + constexpr index_t KRepeatInWave = 1; +#endif + constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim + constexpr index_t KWavePerBlk = 1; + constexpr index_t KRepeat = 1; + static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); + + constexpr index_t NBPerLoad = 1; + constexpr index_t NThdPerWave = 1; + constexpr index_t NWavePerBlk = 
TileShape::BlockWarps::at(number<1>{}); // N_Warp + constexpr index_t NRepeat = 1; + + constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // ? + tuple, // second direction + sequence>, // first direction + // wave in blk, // thd in wave + // // + tuple, sequence<0, 1, 2>>, // which direction + tuple, sequence<1, 2, 2>>, // which index + // + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + template CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffleBQuant() { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index 18b236c29b..e4de7e4211 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -71,6 +71,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant; static constexpr index_t VectorLoadSize = Problem::VectorLoadSize; + static constexpr index_t NPerBlockBQ = + integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN); static constexpr index_t KPerBlockBQ = integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK); static constexpr index_t QScalesPerBlockRow = @@ -184,8 +186,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV const BQDramBlockWindowTmp& bq_dram_block_window_tmp, index_t n, index_t num_loop, - void* p_smem_ping, - void* p_smem_pong) const + void* p_smem) const { static_assert( std::is_same_v> && @@ -210,8 +211,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV __builtin_amdgcn_sched_barrier(0); // A tile in LDS - ADataType* p_a_lds_ping = static_cast(p_smem_ping); - ADataType* p_a_lds_pong = 
static_cast(p_smem_pong); + constexpr index_t smem_size = PipelinePolicy::template GetSmemSize(); + ADataType* p_a_lds_ping = static_cast(p_smem); + ADataType* p_a_lds_pong = + reinterpret_cast(static_cast(p_smem) + smem_size); constexpr auto a_lds_block_desc = PipelinePolicy::template MakeALdsBlockDescriptor(); @@ -351,8 +354,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV if constexpr(PreshuffleQuant) { move_tile_window(bq_copy_dram_window, - {ck_tile::integer_least_multiple(n, kNPerBlock) / - BlockGemmShape::WarpTile::at(number<1>{}), + {((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, kNPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), 0}); } else @@ -426,8 +431,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV if constexpr(PreshuffleQuant) { move_tile_window(bq_copy_dram_window, - {ck_tile::integer_least_multiple(n, kNPerBlock) / - BlockGemmShape::WarpTile::at(number<1>{}), + {((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, kNPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), 0}); } else @@ -461,8 +468,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV if constexpr(PreshuffleQuant) { move_tile_window(bq_copy_dram_window, - {ck_tile::integer_least_multiple(n, kNPerBlock) / - BlockGemmShape::WarpTile::at(number<1>{}), + {((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{})) + ? 
ck_tile::integer_divide_ceil(n, QuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, kNPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), 0}); } else @@ -561,9 +570,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, const BQDramBlockWindowTmp& bq_dram_block_window_tmp, index_t num_loop, - void* p_smem_ping, - void* p_smem_pong, - index_t n = 0) const // Default value for non-preshuffle case + void* p_smem, + index_t n = 0) const { return operator()( a_dram_block_window_tmp, @@ -572,8 +580,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV bq_dram_block_window_tmp, n, num_loop, - p_smem_ping, - p_smem_pong); + p_smem); } template (a_ptr, kargs.a_grid_descs_m_k[group_id]); + + // Step 2: Create padded view + const auto& a_pad_view = pad_tensor_view( + a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + auto a_block_window = make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), + {i_m, i_k}); + + return a_block_window; + } + + CK_TILE_DEVICE static auto + MakeBBlockWindow(const InDataType* b_ptr, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t group_id, + const index_t i_n, + const index_t i_k) + { + // Step 1: Create tensor view for B (Weight tensor) + const auto& b_tensor_view = + make_tensor_view(b_ptr, kargs.b_grid_descs_n_k[group_id]); + + // Step 2: Create padded view + const auto& b_pad_view = pad_tensor_view( + b_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + auto b_block_window = make_tile_window( + b_pad_view, + make_tuple(number{}, number{}), + {i_k, i_n}); + + return b_block_window; + } + + CK_TILE_DEVICE static auto + MakeDBlockWindows(const std::array& ds_ptr, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t group_id, + const index_t i_m, + const index_t i_n) + 
{ + // Create D tensor block windows + const auto ds_block_window = generate_tuple( + [&](auto i) { + // Step 1: Create tensor view for D + const auto& d_tensor_view = make_tensor_view( + static_cast(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]); + + // Step 2: Create padded view + const auto& d_pad_view = + pad_tensor_view(d_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + + // Step 3: Create tile window + return make_tile_window(d_pad_view, + make_tuple(number{}, + number{}), + {i_m, i_n}); + }, + number{}); + + return ds_block_window; + } + + template + CK_TILE_DEVICE static auto + MakeCBlockWindow(WeiDataType* c_ptr, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t group_id, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for C (Input tensor) + const auto& c_tensor_view = make_tensor_view( + c_ptr, kargs.c_grid_descs_m_n[group_id]); + + // Step 2: Create padded view + const auto& c_pad_view = pad_tensor_view( + c_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return c_block_window; + } + CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs) { @@ -895,38 +1006,49 @@ struct GroupedConvolutionBackwardDataKernel const index_t block_idx_k, const index_t group_id) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k); + const auto& b_block_window = + MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k); + const auto& d_block_window = + 
MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k)); const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - auto gemm_tile_windows = - MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); - // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + // Run Epilogue Pipeline with k_batch dispatch + if(k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, group_id, block_idx_m, block_idx_n); + + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, group_id, block_idx_m, block_idx_n); + + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } /** * @brief Runs single GEMM problem cooperatively by whole workgroup. * - * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. + * @note RunGemm2LDS in with two shared memory buffers using the ping pong buffer mechanism. 
* * @param a_ptr input A pointer * @param b_ptr input B pointer @@ -951,23 +1073,19 @@ struct GroupedConvolutionBackwardDataKernel const index_t block_idx_k, const index_t group_id) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k); + const auto& b_block_window = + MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k); + const auto& d_block_window = + MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k)); const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - auto gemm_tile_windows = - MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); // Run GEMM cooperatively by whole workgroup. 
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, @@ -976,11 +1094,27 @@ struct GroupedConvolutionBackwardDataKernel smem_ptr_0, smem_ptr_1); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + // Run Epilogue Pipeline with k_batch dispatch + if(k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, group_id, block_idx_m, block_idx_n); + + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, group_id, block_idx_m, block_idx_n); + + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized& kargs, @@ -1066,8 +1200,7 @@ struct GroupedConvolutionBackwardDataKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm2LDS(a_ptr, @@ -1086,8 +1219,7 @@ struct GroupedConvolutionBackwardDataKernel } else { - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm(a_ptr, diff --git 
a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 4b7ad72ffc..6bcd05e9ba 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -518,25 +518,6 @@ struct GroupedConvolutionBackwardWeightKernel return false; } -#if defined(__gfx11__) - if constexpr(EpiloguePipeline::MemoryOperation != ck_tile::memory_operation_enum::set) - { - return false; - } -#endif - - if constexpr(EpiloguePipeline_::MemoryOperation == memory_operation_enum::atomic_add) - { - if(kargs.k_batch == 1) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Atomic add epilogue only supports k_batch > 1."); - } - return false; - } - } - if constexpr(!std::is_same_v && !std::is_same_v) { @@ -704,29 +685,31 @@ struct GroupedConvolutionBackwardWeightKernel template CK_TILE_DEVICE static auto - MakeGemmTensorViews(const OutDataType* a_ptr, - const InDataType* b_ptr, - const std::array& ds_ptr, - WeiDataType* c_ptr, - const GroupedConvBwdWeightKernelArgsSpecialized& kargs) + MakeCBlockWindow(WeiDataType* c_ptr, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_n) { - static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); - static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!"); - const auto& a_tensor_view = [&]() { - return make_tensor_view(a_ptr, - kargs.a_grid_desc_k_m); // A: out - }(); + const auto& c_tensor_view = + make_tensor_view(c_ptr, kargs.c_grid_desc_m_n); - const auto& b_tensor_view = [&]() { - return make_tensor_view(b_ptr, - kargs.b_grid_desc_k_n); // B: in - }(); + const auto& c_pad_view = pad_tensor_view( + c_tensor_view, + make_tuple(number{}, number{}), + 
sequence{}); - const auto& c_tensor_view = [&]() { - return make_tensor_view(c_ptr, - kargs.c_grid_desc_m_n); - }(); + return make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {block_idx_m, block_idx_n}); + } + CK_TILE_DEVICE static auto + MakeDBlockWindows(const std::array& ds_ptr, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { const auto& ds_tensor_view = generate_tuple( [&](auto i) { static_assert(std::is_same_v, OutLayout>, @@ -741,30 +724,7 @@ struct GroupedConvolutionBackwardWeightKernel }, number{}); - return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& b_pad_view = [&]() { - const auto& b_tensor_view = views.at(I1); - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& ds_tensor_view = views.at(I2); - const auto& ds_pad_view = generate_tuple( + const auto& ds_pad_view = generate_tuple( [&](auto i) { return pad_tensor_view(ds_tensor_view[i], make_tuple(number{}, @@ -773,67 +733,58 @@ struct GroupedConvolutionBackwardWeightKernel }, number{}); - const auto& c_pad_view = [&]() { - const auto& c_tensor_view = views.at(I3); - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view); - } - - /** - * @brief Create views to the data that each workgroup will process. 
- * - * @param views padded views of A, B, D and C tensors - * @param i_m block m-index - * @param i_n block n-index - * @param i_k block k-index - * - * @return tuple of tile windows for A, B, D and C tensors - */ - template - CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, - const index_t i_m, - const index_t i_n, - const index_t i_k) - { - const auto& a_pad_view = views.at(I0); - const auto& b_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& c_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_k, i_m}); - }(); - - const auto& b_block_window = [&]() { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_k, i_n}); - }(); - - const auto ds_block_window = generate_tuple( + return generate_tuple( [&](auto i) { return make_tile_window(ds_pad_view[i], make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); }, number{}); + } - auto c_block_window = make_tile_window( - c_pad_view, - make_tuple(number{}, number{}), - {i_m, i_n}); + CK_TILE_DEVICE static auto + MakeBBlockWindow(const InDataType* b_ptr, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t block_idx_n, + const index_t block_idx_k) + { + static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!"); + const auto& b_tensor_view = + make_tensor_view(b_ptr, kargs.b_grid_desc_k_n); - return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window); + const auto& b_pad_view = + pad_tensor_view(b_tensor_view, + make_tuple(number{} * kargs.k_batch, + number{}), + sequence{}); + + return make_tile_window( + b_pad_view, + make_tuple(number{}, number{}), + {block_idx_k, block_idx_n}); + } + + CK_TILE_DEVICE static auto + MakeABlockWindow(const OutDataType* a_ptr, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_k) 
+ { + static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); + const auto& a_tensor_view = + make_tensor_view(a_ptr, kargs.a_grid_desc_k_m); + + const auto& a_pad_view = + pad_tensor_view(a_tensor_view, + make_tuple(number{} * kargs.k_batch, + number{}), + sequence{}); + + return make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), + {block_idx_k, block_idx_m}); } /** @@ -859,28 +810,30 @@ struct GroupedConvolutionBackwardWeightKernel const index_t block_idx_n, const index_t block_idx_k) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, kargs); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); + // Create block windows using helper methods + const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k); + const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k); + const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); // Run GEMM cooperatively by whole workgroup. 
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + // Run Epilogue Pipeline with k_batch dispatching + if(kargs.k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } /** @@ -910,27 +863,33 @@ struct GroupedConvolutionBackwardWeightKernel const index_t block_idx_n, const index_t block_idx_k) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, kargs); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); + // Create block windows using helper methods + const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k); + const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k); + const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); // Run GEMM cooperatively by whole workgroup. 
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + // Run Epilogue Pipeline with k_batch dispatching + if(kargs.k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { +#if defined(__gfx11__) + return; +#endif + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const @@ -960,12 +919,6 @@ struct GroupedConvolutionBackwardWeightKernel CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const { -#if defined(__gfx11__) - if constexpr(EpiloguePipeline::MemoryOperation != ck_tile::memory_operation_enum::set) - { - return; - } -#endif if constexpr(GroupedConvTraitsType_::ExplicitGemm) { CallExplicitGemm(kargs); @@ -1001,9 +954,7 @@ struct GroupedConvolutionBackwardWeightKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm2LDS(a_ptr, @@ -1021,9 +972,7 @@ struct GroupedConvolutionBackwardWeightKernel } else { - if constexpr(!(EpiloguePipeline::MemoryOperation == - 
memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm(a_ptr, diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 0f143d7ff7..1b81bce34a 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -794,34 +794,53 @@ struct GroupedConvolutionForwardKernel return true; } - template + template CK_TILE_DEVICE static auto - MakeGemmTensorViews(const InDataType* a_ptr, - const WeiDataType* b_ptr, - const std::array& ds_ptr, - OutDataType* c_ptr, - const ADescType& a_desc, - const BDescType& b_desc, - const CDescType& c_desc) + MakeABlockWindow(const InDataType* a_ptr, const ADescType& a_desc, const index_t block_idx_m) { - static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); - static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!"); - const auto& a_tensor_view = [&]() { - return make_tensor_view(a_ptr, a_desc); - }(); + // Step 1: Create tensor view + const auto& a_tensor_view = make_tensor_view(a_ptr, a_desc); - const auto& b_tensor_view = [&]() { - return make_tensor_view(b_ptr, b_desc); - }(); + // Step 2: Create padded view + const auto& a_pad_view = pad_tensor_view( + a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); - // TODO: enable vector write for C in ColMajor - const auto& c_tensor_view = [&]() { - return make_tensor_view(c_ptr, c_desc); - }(); + // Step 3: Create tile window + return make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), + {block_idx_m, 0}); + } + template + CK_TILE_DEVICE static auto + MakeBBlockWindow(const WeiDataType* b_ptr, const BDescType& b_desc, const index_t block_idx_n) + { 
+ // Step 1: Create tensor view + const auto& b_tensor_view = make_tensor_view(b_ptr, b_desc); + + // Step 2: Create padded view + const auto& b_pad_view = pad_tensor_view( + b_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + return make_tile_window( + b_pad_view, + make_tuple(number{}, number{}), + {block_idx_n, 0}); + } + + template + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const CDescType& c_desc, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { static_assert(std::is_same_v, OutLayout>, @@ -836,30 +855,8 @@ struct GroupedConvolutionForwardKernel }, number{}); - return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& b_pad_view = [&]() { - const auto& b_tensor_view = views.at(I1); - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& ds_tensor_view = views.at(I2); - const auto& ds_pad_view = generate_tuple( + // Step 2: Create padded views + const auto& ds_pad_view = generate_tuple( [&](auto i) { return pad_tensor_view(ds_tensor_view[i], make_tuple(number{}, @@ -868,55 +865,38 @@ struct GroupedConvolutionForwardKernel }, number{}); - const auto& c_pad_view = [&]() { - const auto& c_tensor_view = views.at(I3); - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view); - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const 
auto& a_pad_view = views.at(I0); - const auto& b_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& c_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - }(); - - const auto& b_block_window = [&]() { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - }(); - - const auto ds_block_window = generate_tuple( + // Step 3: Create tile windows + return generate_tuple( [&](auto i) { return make_tile_window(ds_pad_view[i], make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); }, number{}); + } - auto c_block_window = make_tile_window( + template + CK_TILE_DEVICE static auto MakeCBlockWindow(OutDataType* c_ptr, + const CDescType& c_desc, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor view + const auto& c_tensor_view = + make_tensor_view(c_ptr, c_desc); + + // Step 2: Create padded view + const auto& c_pad_view = pad_tensor_view( + c_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + return make_tile_window( c_pad_view, make_tuple(number{}, number{}), - {i_m, i_n}); - - return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window); + {block_idx_m, block_idx_n}); } /** @@ -931,6 +911,7 @@ struct GroupedConvolutionForwardKernel * @param b_desc Weight tensor B descriptor * @param c_desc Output tensor C descriptor * @param gemm_k The GEMM K dimension + * @param k_batch The K batch parameter for split-K * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. 
* @@ -945,34 +926,41 @@ struct GroupedConvolutionForwardKernel const BDescType& b_desc, const CDescType& c_desc, const index_t gemm_k, + const index_t k_batch, const index_t block_idx_m, const index_t block_idx_n, const CDElementwise& elfunc) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = MakeABlockWindow(a_ptr, a_desc, block_idx_m); + const auto& b_block_window = MakeBBlockWindow(b_ptr, b_desc, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, c_desc, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k)); // Run GEMM cooperatively by whole workgroup. 
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + // Run Epilogue Pipeline with k_batch dispatching + if(k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, c_desc, block_idx_m, block_idx_n); - EpiloguePipeline{elfunc} - .template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{elfunc} + .template operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, c_desc, block_idx_m, block_idx_n); + + EpiloguePipeline{elfunc} + .template operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } } /** @@ -990,6 +978,7 @@ struct GroupedConvolutionForwardKernel * @param b_desc Weight tensor B descriptor * @param c_desc Output tensor C descriptor * @param gemm_k The GEMM K dimension + * @param k_batch The K batch parameter for split-K * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. 
* @@ -1005,33 +994,41 @@ struct GroupedConvolutionForwardKernel const BDescType& b_desc, const CDescType& c_desc, const index_t gemm_k, + const index_t k_batch, const index_t block_idx_m, const index_t block_idx_n, const CDElementwise& elfunc) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = MakeABlockWindow(a_ptr, a_desc, block_idx_m); + const auto& b_block_window = MakeBBlockWindow(b_ptr, b_desc, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, c_desc, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k)); // Run GEMM cooperatively by whole workgroup. 
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + // Run Epilogue Pipeline with k_batch dispatching + if(k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, c_desc, block_idx_m, block_idx_n); - EpiloguePipeline{elfunc} - .template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{elfunc} + .template operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, c_desc, block_idx_m, block_idx_n); + + EpiloguePipeline{elfunc} + .template operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } } CK_TILE_DEVICE void CallExplicitGemm(GroupedConvFwdKernelArgsSpecialized& kargs) const @@ -1185,9 +1182,7 @@ struct GroupedConvolutionForwardKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm2LDS(a_ptr, @@ -1200,6 +1195,7 @@ struct GroupedConvolutionForwardKernel b_desc, c_desc, kargs.GemmK, + kargs.k_batch, i_m, i_n, kargs.elfunc); @@ -1207,9 +1203,7 @@ struct GroupedConvolutionForwardKernel } else { - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm(a_ptr, @@ -1221,6 +1215,7 @@ struct 
GroupedConvolutionForwardKernel b_desc, c_desc, kargs.GemmK, + kargs.k_batch, i_m, i_n, kargs.elfunc); diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 197c9d6e1d..93cd7fa063 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -9,6 +9,7 @@ add_subdirectory(grouped_gemm) add_subdirectory(grouped_gemm_preshuffle) add_subdirectory(grouped_gemm_multi_d) add_subdirectory(grouped_gemm_quant) +add_subdirectory(grouped_gemm_abquant) add_subdirectory(gemm_multi_d) add_subdirectory(gemm_multi_abd) add_subdirectory(gemm_streamk) diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 77eb416532..37005cccc1 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -99,62 +99,47 @@ class TestCkTileBatchedGemm : public ::testing::Test scheduler>; using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::BatchedGemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp index 9b90110c07..0572115201 100644 --- a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp @@ -120,8 +120,8 @@ using SimpleCShuffleEpilogueProblem = MPerXdl, NPerXdl, KPerXdl, - false, // isCTransposed, - memory_operation_enum::set>; + false // isCTransposed + >; template auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc index 6e7c086e55..5239b2d888 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc +++ 
b/test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc @@ -31,7 +31,14 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM) if constexpr(std::is_same_v) { - EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + if(M * sizeof(typename TestFixture::ADataType) % 4 == 0) // oob fit dword + { + this->Run(M, N, K); + } + else + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } } else { @@ -84,7 +91,14 @@ TYPED_TEST(TEST_SUITE_NAME, MidLargeM) } else { - EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + if(M * sizeof(typename TestFixture::ADataType) % 4 == 0) // oob fit dword + { + this->Run(M, N, K); + } + else + { + EXPECT_THROW((this->Run(M, N, K)), std::runtime_error); + } } } else @@ -103,18 +117,7 @@ TYPED_TEST(TEST_SUITE_NAME, PaddK) for(int M : Ms) { - if constexpr(std::is_same_v) - { -#if defined(ARCH_GFX12) || defined(ARCH_GFX11) - this->Run(M, N, K); -#else - EXPECT_THROW(this->Run(M, N, K), std::runtime_error); -#endif - } - else - { - this->Run(M, N, K); - } + this->Run(M, N, K); } } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index e949ed45e6..8dc2e88430 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -182,74 +182,58 @@ class TestCkTileGemmPipeline : public ::testing::Test using GemmPipeline = typename GemmPipelineTypeSelector::pipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + const auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = 
Kernel::GridSize(args.M, args.N, args.k_batch); - } - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " - << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " - << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + const dim3 blocks = Kernel::BlockSize(); + dim3 grids; + if constexpr(Persistent) { - Run(ck_tile::integral_constant{}); + grids = Kernel::MaxOccupancyGridSize(s); } else { - Run(ck_tile::integral_constant{}); + grids = Kernel::GridSize(args.M, args.N, args.k_batch); } + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y + << ", " << blocks.z << "}" << std::endl; + } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index f89aea1c17..2dad8be205 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -39,6 +39,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_padding PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_abquant_preshuffle + test_gemm_quant_abquant_preshuffle_2d.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + # AQuant tests 
add_gtest_executable(test_tile_gemm_quant_aquant_prefill test_gemm_quant_aquant_prefill.cpp ) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp new file mode 100644 index 0000000000..793c9bd1df --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantPreshuffleBTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleBTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 7d82958acf..3798cc4443 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -356,8 +356,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase>; + transpose_c>>; using 
Kernel = ck_tile::QuantGemmKernel; - using BaseGemmPipeline = - std::conditional_t, - ck_tile::BaseGemmPipelineAgBgCrCompV3>; + using BaseGemmPipeline = std::conditional_t< + PreshuffleB == true, + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, + ck_tile::BaseGemmPipelineAgBgCrCompV3>; const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile; const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); @@ -928,8 +926,8 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase; using GemmPipeline = - std::conditional_t, + std::conditional_t, ck_tile::ABQuantGemmPipelineAgBgCrCompV3>; using GemmEpilogue = ck_tile::CShuffleEpilogue< @@ -949,7 +947,6 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase>; + transpose_c>>; using Kernel = ck_tile::QuantGemmKernel>; + transpose_c>>; using Kernel = ck_tile::QuantGemmKernel; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; - using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem>; + using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = + std::conditional_t; - using GemmEpilogue = std:: - conditional_t; + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernelMultiABD; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! 
Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - std::cout << "Run without SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - std::cout << "Run using SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp index 8217f5a3d9..6a6806641a 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -170,88 +170,69 @@ class TestCkTileGemmMultiD : public ::testing::Test using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; - using DefaultGemmEpilogue = 
ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem>; + using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = + std::conditional_t; - using GemmEpilogue = std:: - conditional_t; + using Kernel = ck_tile::GemmKernelMultiD; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernelMultiD; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - std::cout << "Run without SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - std::cout << "Run using SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp index 540109a999..237dc24c3b 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -105,71 +105,60 @@ class TestCkTileStreamK : public ::testing::Test NumWaveGroup, preshuffle>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - // We create the GEMM pipeline without specifying has_hot_loop or tail_num. - // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K - // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K - // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - // For initial testing, we will just test with one pipeline. - // More extensive testing is coming later and will test other pipelines. - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + // We create the GEMM pipeline without specifying has_hot_loop or tail_num. 
+ // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K + // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K + // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + // For initial testing, we will just test with one pipeline. + // More extensive testing is coming later and will test other pipelines. + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - UniversalGemmProblem::TransposeC, - memory_operation>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + UniversalGemmProblem::TransposeC>>; - using Kernel = ck_tile::StreamKKernel; + using Kernel = ck_tile::StreamKKernel; - auto kargs = Kernel::MakeKernelArgs(args); - const auto workspace_size = Kernel::GetWorkSpaceSize(kargs); - ck_tile::DeviceMem workspace_data(workspace_size); - workspace_data.SetZero(); - kargs.workspace_ptr = workspace_data.GetDeviceBuffer(); + auto kargs = Kernel::MakeKernelArgs(args); + const auto workspace_size = Kernel::GetWorkSpaceSize(kargs); + ck_tile::DeviceMem workspace_data(workspace_size); + workspace_data.SetZero(); + kargs.workspace_ptr = workspace_data.GetDeviceBuffer(); - if(!Kernel::IsSupportedArgument(kargs)) - { - EXPECT_TRUE(false); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + EXPECT_TRUE(false); + } - dim3 grid_dims = 
Kernel::GridSize(kargs.tile_partitioner); - dim3 block_dims = Kernel::BlockSize(); + dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner); + dim3 block_dims = Kernel::BlockSize(); - ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grid_dims, block_dims, 0, kargs)); + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grid_dims, block_dims, 0, kargs)); - return kargs.tile_partitioner.estimate_num_wgs_per_tile(); - }; - - return Run(ck_tile::integral_constant{}); + return kargs.tile_partitioner.estimate_num_wgs_per_tile(); } public: diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 7c085b5098..875684ce08 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -180,68 +180,52 @@ class TestCkTileGemmPipeline : public ::testing::Test using GemmPipeline = typename GemmPipelineTypeSelector::pipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " - << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " - << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + dim3 grids; + if constexpr(Persistent) { - Run(ck_tile::integral_constant{}); + grids = Kernel::MaxOccupancyGridSize(s); } else { - Run(ck_tile::integral_constant{}); + grids = Kernel::GridSize(args.M, args.N, args.k_batch); } + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y + << ", " << blocks.z << "}" << std::endl; + } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp index bdce90e385..237641a000 100644 --- a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp +++ b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp @@ -42,8 +42,7 @@ template + index_t NDimSpatial = 2> struct BuildKernel { using GemmShape = TileGemmShape< @@ -123,7 +122,6 @@ struct BuildKernel ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, ConvTraits::FixedGemmParams::TransposeC, - MemOp, ConvConfig::NumWaveGroups, ConvTraits::FixedGemmParams::FixedVectorSize, ConvTraits::VectorSizeC>; @@ -212,26 +210,6 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, InvalidKBatchLessThanOne) EXPECT_FALSE(Kernel::IsSupportedArgument(kargs)); } 
-TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreaterThanOne) -{ - using Kernel = typename BuildKernel::type; - - // k_batch = 1 should fail with atomic_add - auto host_args_kbatch_1 = create_2d_host_args(1); - auto kargs_1 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_1); - EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_1)); - - // k_batch = 2 should pass - auto host_args_kbatch_2 = create_2d_host_args(2); - auto kargs_2 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_2); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_2)); -} - TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, K0KBatchLimitation) { using Kernel = typename BuildKernel; using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); - const dim3 grids = Kernel::GridSize(gemm_descs); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); - ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << 
grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + if(s.log_level_ > 0) { - std::cout << "Run without SplitK" << std::endl; - Run(ck_tile::integral_constant{}); - } - else - { - std::cout << "Run using SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { constexpr bool TransposeC = false; constexpr bool DoubleSmemBuffer = false; @@ -212,50 +193,47 @@ class TestCkTileGroupedGemm : public ::testing::Test CLayout, TransposeC>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - constexpr auto memory_operation = memory_operation_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. 
+ using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + ck_tile::ignore = ck_tile::launch_kernel(s, ck_tile::make_kernel( Kernel{}, @@ -264,19 +242,6 @@ class TestCkTileGroupedGemm : public ::testing::Test 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), num_groups)); - }; - - if(splitk) - { - Run(ck_tile::integral_constant{}); - } - else - { - - Run(ck_tile::integral_constant{}); - } } auto calculate_rtol_atol(const ck_tile::index_t K, @@ -422,8 +387,7 @@ class TestCkTileGroupedGemm : public ::testing::Test { // Generate kernel arguments std::vector> kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - const bool splitk = gemm_descs[0].k_batch > 1; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); for(const auto& arg : gemm_descs) { kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr}, @@ -448,10 +412,10 @@ class 
TestCkTileGroupedGemm : public ::testing::Test stream.stream_id_)); #if CK_TILE_USE_WMMA invoke_grouped_gemm_persistent( - stream, group_count, kargs_ptr, splitk); + stream, group_count, kargs_ptr); #else invoke_grouped_gemm_persistent( - stream, group_count, kargs_ptr, splitk); + stream, group_count, kargs_ptr); #endif } else diff --git a/test/ck_tile/grouped_gemm_abquant/CMakeLists.txt b/test/ck_tile/grouped_gemm_abquant/CMakeLists.txt new file mode 100644 index 0000000000..e735aa8e9a --- /dev/null +++ b/test/ck_tile/grouped_gemm_abquant/CMakeLists.txt @@ -0,0 +1,16 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +if(GPU_TARGETS MATCHES "gfx94|gfx95") + add_gtest_executable(test_ck_tile_grouped_gemm_abquant_1x1x128 test_grouped_gemm_abquant_1x1x128.cpp) + target_compile_options(test_ck_tile_grouped_gemm_abquant_1x1x128 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_ck_tile_grouped_gemm_abquant_1x128x128 test_grouped_gemm_abquant_1x128x128.cpp) + target_compile_options(test_ck_tile_grouped_gemm_abquant_1x128x128 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() + diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp new file mode 100644 index 0000000000..06b0068cb7 --- /dev/null +++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x128x128.cpp @@ -0,0 +1,47 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_abquant_util.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; + +// AQuant group size is fixed at 1x1x128 +using AQuantGroupSize = ck_tile::QuantGroupShape>; +// BQuant group size: 1x128x128 +using BQuantGroupSize_1x128x128 = ck_tile::QuantGroupShape>; + +// clang-format off +using KernelTypes_ABQuant_1x128x128 = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, AQuantGroupSize, BQuantGroupSize, Persistent + + // FP8 variants + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True>, + + // BF8 variants + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x128x128, True> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmABQuant_1x128x128, KernelTypes_ABQuant_1x128x128); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmABQuant_1x128x128 +#include 
"test_grouped_gemm_abquant_ut_cases.inc" +#undef TEST_CLASS_NAME diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp new file mode 100644 index 0000000000..7704eda169 --- /dev/null +++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_1x1x128.cpp @@ -0,0 +1,47 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_abquant_util.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using True = ck_tile::bool_constant; +using False = ck_tile::bool_constant; + +// AQuant group size is fixed at 1x1x128 +using AQuantGroupSize = ck_tile::QuantGroupShape>; +// BQuant group size: 1x1x128 +using BQuantGroupSize_1x1x128 = ck_tile::QuantGroupShape>; + +// clang-format off +using KernelTypes_ABQuant_1x1x128 = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, AQuantGroupSize, BQuantGroupSize, Persistent + + // FP8 variants + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>, + std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>, + std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>, + std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True>, + + // BF8 variants 
+ std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, False>, + std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, AQuantGroupSize, BQuantGroupSize_1x1x128, True> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmABQuant_1x1x128, KernelTypes_ABQuant_1x1x128); + +#define TEST_CLASS_NAME TestCkTileGroupedGemmABQuant_1x1x128 +#include "test_grouped_gemm_abquant_ut_cases.inc" +#undef TEST_CLASS_NAME diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc new file mode 100644 index 0000000000..48574ab977 --- /dev/null +++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_ut_cases.inc @@ -0,0 +1,87 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +TYPED_TEST(TEST_CLASS_NAME, Basic) +{ + const int group_count = 6; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + std::vector stride_AQs; + std::vector stride_BQs; + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 128 * i); + + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); + stride_AQs.push_back(0); + stride_BQs.push_back(0); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count); +} + +// No Hot Loop Test Case, this is to test the correctness of the kernel when there is no hot loop +// Using 256x256x128 to match the test kernel's tile size (M_Tile=128, N_Tile=128, K_Tile=128) +TYPED_TEST(TEST_CLASS_NAME, SmallUniform) +{ + const int group_count = 2; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + std::vector stride_AQs; + std::vector stride_BQs; + for(int i = 0; 
i < group_count; i++) + { + Ms.push_back(256); + Ns.push_back(256); + Ks.push_back(256); + + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); + stride_AQs.push_back(0); + stride_BQs.push_back(0); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count); +} + +TYPED_TEST(TEST_CLASS_NAME, OddTail) +{ + const int group_count = 2; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + std::vector stride_AQs; + std::vector stride_BQs; + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256); + Ns.push_back(256); + Ks.push_back(128); + + stride_As.push_back(0); + stride_Bs.push_back(0); + stride_Cs.push_back(0); + stride_AQs.push_back(0); + stride_BQs.push_back(0); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count); +} diff --git a/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp new file mode 100644 index 0000000000..c7ed6f5472 --- /dev/null +++ b/test/ck_tile/grouped_gemm_abquant/test_grouped_gemm_abquant_util.hpp @@ -0,0 +1,530 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#pragma once +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm_quant.hpp" + +template +class TestCkTileGroupedGemmABQuant : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using AQDataType = std::tuple_element_t<4, Tuple>; + using BDataType = std::tuple_element_t<5, Tuple>; + using BQDataType = std::tuple_element_t<6, Tuple>; + using AccDataType = std::tuple_element_t<7, Tuple>; + using CDataType = std::tuple_element_t<8, Tuple>; + using AQuantGroupSize = std::tuple_element_t<9, Tuple>; + using BQuantGroupSize = std::tuple_element_t<10, Tuple>; + static constexpr bool Persistent = std::tuple_element_t<11, Tuple>::value; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using AQLayout = Row; + using BQLayout = Col; + + static constexpr auto QuantMode = ck_tile::QuantType::ABQuantGrouped; + + struct GemmConfig + { + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(ADataType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + 
ck_tile::get_k_warp_tile(); + + static constexpr bool PreshuffleB = false; + static constexpr bool TransposeC = false; + static constexpr bool DoubleSmemBuffer = false; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + + static constexpr bool IsPersistent = Persistent; + }; + + using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs; + + std::size_t get_workspace_size(const std::vector& gemm_descs) + { + return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg); + } + + template + static constexpr inline auto is_row_major(Layout layout_) + { + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; + } + + auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) + { + using ComputeType = + std::conditional_t; + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); + } + + template + float invoke_grouped_gemm_abquant(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) + { + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile:: + TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = 
ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * Config::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * Config::K_Tile; + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = Config::Scheduler; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Config::M_Warp, + Config::N_Warp, + Config::M_Warp_Tile, + Config::N_Warp_Tile, + Config::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } + + return ave_time = ck_tile::launch_kernel( + s, + 
ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; + + return ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } + + template + void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr) + { + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + + using QuantGemmProblem = ck_tile::GemmABQuantPipelineProblem; + + using GemmPipeline = ck_tile::ABQuantGemmPipelineAgBgCrCompV3; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Config::M_Warp, + Config::N_Warp, + Config::M_Warp_Tile, + Config::N_Warp_Tile, + Config::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + } + + public: + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + std::vector& stride_As, + std::vector& stride_Bs, + 
std::vector& stride_Cs, + std::vector& stride_AQs, + std::vector& stride_BQs, + const int group_count = 8) + { + ck_tile::index_t AQK, BQK; + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> c_m_n_tensors; + std::vector> aq_tensors; + std::vector> bq_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + c_m_n_tensors.reserve(group_count); + aq_tensors.reserve(group_count); + bq_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> c_m_n_dev_buf; + std::vector> aq_dev_buf; + std::vector> bq_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + c_m_n_dev_buf.reserve(group_count); + aq_dev_buf.reserve(group_count); + bq_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + AQK = K / AQuantGroupSize::kK; + BQK = K / BQuantGroupSize::kK; + + if(K % AQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by AQuantGroupSize::kK for ABQuantGrouped mode"); + } + if(K % BQuantGroupSize::kK != 0) + { + throw std::runtime_error( + "K must be divisible by BQuantGroupSize::kK for ABQuantGrouped mode"); + } + + stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(ALayout{})); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(BLayout{})); + stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{})); + stride_AQs[i] = + ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(AQLayout{})); + stride_BQs[i] = + ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(BQLayout{})); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(ALayout{})))); + 
b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(BLayout{})))); + c_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); + aq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(AQLayout{})))); + bq_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(BQK, N, stride_BQs[i], is_row_major(BQLayout{})))); + + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc + << " c_m_n: " << c_m_n_tensors[i].mDesc << " aq: " << aq_tensors[i].mDesc + << " bq: " << bq_tensors[i].mDesc << std::endl; + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(aq_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(bq_tensors[i]); + + a_m_k_dev_buf.push_back(std::make_unique( + a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back(std::make_unique( + b_k_n_tensors[i].get_element_space_size_in_bytes())); + c_m_n_dev_buf.push_back(std::make_unique( + c_m_n_tensors[i].get_element_space_size_in_bytes())); + aq_dev_buf.push_back(std::make_unique( + aq_tensors[i].get_element_space_size_in_bytes())); + bq_dev_buf.push_back(std::make_unique( + bq_tensors[i].get_element_space_size_in_bytes())); + + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + aq_dev_buf[i]->ToDevice(aq_tensors[i].data()); + bq_dev_buf[i]->ToDevice(bq_tensors[i].data()); + c_m_n_dev_buf[i]->SetZero(); + c_m_n_tensors[i].SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer(); 
+ const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer(); + + gemm_descs.push_back({p_a, + p_b, + p_c, + p_aq, + p_bq, + 1, // k_batch + M, + N, + K, + AQK, + BQK, + stride_As[i], + stride_Bs[i], + stride_Cs[i], + stride_AQs[i], + stride_BQs[i]}); + } + + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + + if constexpr(Persistent) + { + std::vector kargs; + for(const auto& arg : gemm_descs) + { + kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr, + arg.b_ptr, + arg.aq_ptr, + arg.bq_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + arg.QK_A, + arg.QK_B, + arg.stride_A, + arg.stride_B, + arg.stride_E, + arg.stride_AQ, + arg.stride_BQ, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, false, 1}; + ck_tile::hip_check_error( + hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr); + } + else + { + const auto stream = ck_tile::stream_config{nullptr, false, 1}; + invoke_grouped_gemm_abquant(gemm_descs, stream, kargs_ptr); + } + + // Copy results back to host for validation + for(int i = 0; i < group_count; i++) + { + c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data()); + } + + bool pass{true}; + for(int i = 0; i < group_count; ++i) + { + ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( + Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_abquant( + a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value); + pass &= + ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: 
Incorrect results! in group [" + std::to_string(i) + "]", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "gemm[" << i + << "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + + EXPECT_TRUE(pass); + } +}; + +// Aliases for split test files +template +using TestCkTileGroupedGemmABQuant_1x1x128 = TestCkTileGroupedGemmABQuant; + +template +using TestCkTileGroupedGemmABQuant_1x128x128 = TestCkTileGroupedGemmABQuant; diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp index b065df6f8a..c6e311a65c 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp @@ -96,7 +96,7 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test const ck_tile::stream_config& s, void* kargs_ptr) { - + EXPECT_TRUE(gemm_descs[0].k_batch == 1); using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, ck_tile::sequence, @@ -134,74 +134,56 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test ck_tile::GemmPipelineAgBgCrCompV3, ck_tile::GemmPipelineAgBgCrCompV4>>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = 
Kernel::MakeKargs(gemm_descs); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); - const dim3 grids = Kernel::GridSize(gemm_descs); - const dim3 blocks = Kernel::BlockSize(); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } - - ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -218,78 +200,58 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test BLayout, ELayout>; - float ave_time{0}; + // We create the GEMM pipeline without specifying hotloop or tailnumber. 
+ // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmPipeline = std::conditional_t< + Config::Pipeline_ == (PipelineType::Memory), + ck_tile::GemmPipelineAgBgCrMem, + std::conditional_t, + ck_tile::GemmPipelineAgBgCrCompV4>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = std::conditional_t< - Config::Pipeline_ == (PipelineType::Memory), - ck_tile::GemmPipelineAgBgCrMem, - std::conditional_t, - ck_tile::GemmPipelineAgBgCrCompV4>>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } - - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; - }; - if(!splitk) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); - } - else - { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " 
with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } public: @@ -445,8 +407,7 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test if constexpr(Config::Persistent_) { std::vector> kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - const bool splitk = gemm_descs[0].k_batch > 1; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); for(const auto& arg : gemm_descs) { kargs.emplace_back( @@ -471,7 +432,7 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test hipMemcpyHostToDevice, stream.stream_id_)); - invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr, splitk); + invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr); } else { diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index a7189e7865..e588ad2cc1 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -127,59 +127,44 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); - const dim3 grids = Kernel::GridSize(gemm_descs); - const dim3 blocks = Kernel::BlockSize(); + using GemmEpilogue = 
ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); - ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) - { - Run(ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); - } + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } private: @@ -226,59 +211,45 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test ck_tile::GemmPipelineScheduler::Default>; using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); - const dim3 grids = Kernel::GridSize(gemm_descs); - const dim3 blocks = Kernel::BlockSize(); - ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + using GemmEpilogue = 
ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(gemm_descs[0].k_batch == 1) - { - Run(ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); - } + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } struct BShuffleGemmConfig diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp index b73221ac28..3d52bca9e0 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp @@ -148,10 +148,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test float ave_time{0}; const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; using QuantGemmProblem = std::conditional_t< 
UseGroupedQuant, @@ -217,8 +216,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test GroupedGemKernelParam::M_Warp_Tile, GroupedGemKernelParam::N_Warp_Tile, GroupedGemKernelParam::K_Warp_Tile, - QuantGemmProblem::TransposeC, - memory_operation>>; + QuantGemmProblem::TransposeC>>; using Kernel = ck_tile::QuantGroupedGemmKernel; - using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || + QuantType == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - constexpr auto memory_operation = memory_operation_.value; - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. 
+ using GemmPipeline = std::conditional_t< + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + ck_tile::GemmPipelineAgBgCrCompV3>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || - QuantType == ck_tile::QuantType::BQuantGrouped; - using QuantGemmProblem = std::conditional_t< - UseGroupedQuant, - std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>, - ck_tile::GemmRowColTensorQuantPipelineProblem>; - - using GemmPipeline = std::conditional_t< - UseGroupedQuant, - std::conditional_t< - QuantType == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, - ck_tile::GemmPipelineAgBgCrCompV3>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::QuantGroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + ck_tile::ignore = ck_tile::launch_kernel(s, ck_tile::make_kernel( Kernel{}, @@ -388,10 +379,6 @@ class 
TestCkTileGroupedGemmQuant : public ::testing::Test 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), num_groups)); - }; - - Run(ck_tile::integral_constant{}); } template diff --git a/test/ck_tile/moe_sorting/test_moe_sorting_util.hpp b/test/ck_tile/moe_sorting/test_moe_sorting_util.hpp index 37377755ea..de06669063 100644 --- a/test/ck_tile/moe_sorting/test_moe_sorting_util.hpp +++ b/test/ck_tile/moe_sorting/test_moe_sorting_util.hpp @@ -236,13 +236,13 @@ class TestCkTileMoeSorting : public ::testing::Test if(moe_buf_bytes > 0) { #if MOE_SORTING_FMOE_2D_BUF - printf("moe_buf:%lu(%d,%d), ", + printf("moe_buf:%" PRIu64 "(%d,%d), ", static_cast(moe_buf_bytes), moe_buf_interm_dim, moe_buf_elem_bytes); #else - printf("moe_buf:%lu, ", static_cast(moe_buf_bytes)); + printf("moe_buf:%" PRIu64 ", ", static_cast(moe_buf_bytes)); #endif } diff --git a/test/ck_tile/utility/test_fill.cpp b/test/ck_tile/utility/test_fill.cpp index 3633f8bbff..f67dee9757 100644 --- a/test/ck_tile/utility/test_fill.cpp +++ b/test/ck_tile/utility/test_fill.cpp @@ -26,6 +26,7 @@ using TestTypes = ::testing::Types; TYPED_TEST_SUITE(FillUniformDistributionTest, TestTypes); // Test that multiple runs with the same seed produce identical results +#ifndef _WIN32 TYPED_TEST(FillUniformDistributionTest, ConsistencyWithSameSeed) { using T = TypeParam; @@ -53,6 +54,7 @@ TYPED_TEST(FillUniformDistributionTest, ConsistencyWithSameSeed) << "First and second fill should be identical"; } } +#endif // Test consistency across different data sizes (which affects threading) TYPED_TEST(FillUniformDistributionTest, ConsistencyAcrossSizes) diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 27ca805c2e..81a9b08b70 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -719,8 +719,8 @@ struct SelectedKernel {{ elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]: 
instance_code += f""" - // Kernel type - using GemmKernel = ck_tile::GemmKernel; + // Kernel type + using GemmKernel = ck_tile::GemmKernel; // Kernel arguments auto kargs = GemmKernel::MakeKernelArgs(args); @@ -802,8 +802,8 @@ struct SelectedKernel {{ ck_tile::tuple<>, // DsLayout CLayout, ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, // kM_ - TilePartitioner::NPerBlock, // kN_ + TileM, // kM_ + TileN, // kN_ WarpPerBlock_M, // MWave_ WarpPerBlock_N, // NWave_ WarpTileM, // MPerXdl_ diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py index 2225619fad..bea46de067 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -481,8 +481,6 @@ struct SelectedKernel {{ GemmUniversalTraits>; static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{ - const auto Run = [&](const auto memory_operation_) {{ - constexpr auto memory_operation = memory_operation_.value; constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; // kNumWaveGroups_ using GemmEpilogue = ck_tile::CShuffleEpilogue; @@ -558,30 +555,12 @@ struct SelectedKernel {{ workspace_data.SetZero(); }} }}; - - + // Launch kernel - float ave_time = ck_tile::launch_kernel_time_mask( + return ck_tile::launch_kernel_time_mask( stream, reset_data_buffers, ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); - return ave_time; - - // ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile(); - // return std::make_tuple(ave_time, num_wgs_per_tile); - }}; - - - if constexpr(ck_tile::StreamKReductionStrategy::Atomic == reduction_strategy) - {{ - return Run(ck_tile::integral_constant{{}}); - }} - else // We are using ck_tile::StreamKReductionStrategy::Reduction - {{ - return 
Run(ck_tile::integral_constant{{}}); - }} }} }}; """