From a1037bfc3c3477179d2a205b6ee29e08036b42f6 Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Thu, 11 Dec 2025 07:16:06 +0000 Subject: [PATCH] Merge commit '6d25525adc2344d5b62b12b9ffddee50f89cd0ff' into develop --- .github/scripts/therock_configure_ci.py | 3 + .pre-commit-config.yaml | 12 +- include/ck_tile/core.hpp | 3 +- include/ck_tile/host.hpp | 3 +- include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp | 3 +- include/ck_tile/ops/batched_contraction.hpp | 3 +- include/ck_tile/ops/batched_transpose.hpp | 3 +- include/ck_tile/ops/common.hpp | 3 +- include/ck_tile/ops/elementwise.hpp | 3 +- include/ck_tile/ops/epilogue.hpp | 3 +- include/ck_tile/ops/flatmm.hpp | 3 +- include/ck_tile/ops/fmha.hpp | 3 +- include/ck_tile/ops/fused_moe.hpp | 3 +- include/ck_tile/ops/gemm.hpp | 3 +- include/ck_tile/ops/gemm_quant.hpp | 3 +- include/ck_tile/ops/grouped_convolution.hpp | 3 +- include/ck_tile/ops/image_to_column.hpp | 3 +- include/ck_tile/ops/layernorm2d.hpp | 3 +- include/ck_tile/ops/norm_reduce.hpp | 3 +- include/ck_tile/ops/permute.hpp | 3 +- include/ck_tile/ops/pooling.hpp | 3 +- include/ck_tile/ops/reduce.hpp | 3 +- include/ck_tile/ops/rmsnorm2d.hpp | 3 +- include/ck_tile/ops/smoothquant.hpp | 3 +- include/ck_tile/ops/softmax.hpp | 3 +- include/ck_tile/ops/topk.hpp | 3 +- include/ck_tile/ops/topk_softmax.hpp | 3 +- include/ck_tile/remod.py | 5 +- script/check_copyright_year.sh | 70 +++- script/update_amd_copyright_headers.py | 295 +++++++++++++++++ .../ck_tile/core/arch/mma/test_amdgcn_mma.cpp | 4 +- test/ck_tile/core/arch/test_arch.cpp | 4 +- tile_engine/include/utility/validation.hpp | 2 +- tile_engine/ops/gemm_streamk/CMakeLists.txt | 3 + .../gemm_streamk/gemm_streamk_benchmark.hpp | 2 +- .../gemm_streamk_benchmark_single.cpp | 2 +- .../ops/gemm_streamk/gemm_streamk_common.hpp | 2 +- .../gemm_streamk_instance_builder.py | 3 + .../gemm_streamk/gemm_streamk_profiler.hpp | 2 +- .../gemm_streamk_validation_utils.py | 2 +- 
.../01_naive_gemm/TILE_DISTRIBUTION.md | 312 ++++++++++++++++++ ...ice_gemm_block_policy_agmem_bgmem_creg.hpp | 12 +- ...ce_gemm_host_pipeline_agmem_bgmem_creg.hpp | 2 +- .../ck_tile/01_naive_gemm/practice_gemm.cpp | 34 +- .../ck_tile/01_naive_gemm/practice_gemm.hpp | 7 +- 45 files changed, 755 insertions(+), 98 deletions(-) create mode 100644 script/update_amd_copyright_headers.py create mode 100644 tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md diff --git a/.github/scripts/therock_configure_ci.py b/.github/scripts/therock_configure_ci.py index 860b6bf875..c892941fc6 100644 --- a/.github/scripts/therock_configure_ci.py +++ b/.github/scripts/therock_configure_ci.py @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import fnmatch import json import os diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 04ebc6b45a..71e7ccdb81 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,12 +20,12 @@ repos: )$ - repo: local hooks: - # - id: copyright-year-checker - # name: copyright-year-checker - # entry: script/check_copyright_year.sh - # verbose: false - # language: script - # types: [c++] + - id: copyright-header-checker + name: Check copyright headers + entry: script/check_copyright_year.sh + verbose: false + language: script + types_or: [c++, python, shell, cmake] - id: remove-exec-bit name: Remove executable bit from non-executable files entry: script/remove_exec_bit.sh diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 5c05e9b6ee..d28d29a0ef 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/core/algorithm/cluster_descriptor.hpp" diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index c769e3e247..b543fd84e9 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/host/arg_parser.hpp" diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index 6c0972e10a..00234b20cf 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp" diff --git a/include/ck_tile/ops/batched_contraction.hpp b/include/ck_tile/ops/batched_contraction.hpp index 2232ec1261..45fa52e505 100644 --- a/include/ck_tile/ops/batched_contraction.hpp +++ b/include/ck_tile/ops/batched_contraction.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp" diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index 5822d7b91b..b23e45c233 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp" diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index eff2d625b3..94243e674f 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/common/generic_2d_block_shape.hpp" diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 7f2303932e..5752703ab6 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp" diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index ec5a8ef445..555402b53a 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 7ef2fd5433..2d3a819e80 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 5b87a821c9..20714397c9 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index 71721f3408..e6802e82dc 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index ec2d2488c8..d518a15b7e 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 3e16d937cb..7dc5b40286 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp" diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index 23a72d79e9..6743e46613 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp" diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp index 2307b05190..1d33ebf39d 100644 --- a/include/ck_tile/ops/image_to_column.hpp +++ b/include/ck_tile/ops/image_to_column.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp" diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index 9ce22137bf..ebb20aebf4 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp" diff --git a/include/ck_tile/ops/norm_reduce.hpp b/include/ck_tile/ops/norm_reduce.hpp index aa074b7f9f..469a98c256 100644 --- a/include/ck_tile/ops/norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. 
All rights reserved. - #pragma once #include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp" diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp index 46512c57fe..88a3d8a137 100644 --- a/include/ck_tile/ops/permute.hpp +++ b/include/ck_tile/ops/permute.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp" diff --git a/include/ck_tile/ops/pooling.hpp b/include/ck_tile/ops/pooling.hpp index 084b498203..3e44122afa 100644 --- a/include/ck_tile/ops/pooling.hpp +++ b/include/ck_tile/ops/pooling.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/pooling/kernel/pool_kernel.hpp" diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index d628e9c945..57f3f3c80a 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/reduce/block/block_reduce.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index 00afcf4aed..ad23a708b7 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp" diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp index 1aa14c69e1..13372f3289 100644 --- a/include/ck_tile/ops/smoothquant.hpp +++ b/include/ck_tile/ops/smoothquant.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp" diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index d559dc15e2..9cf3e08319 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp index 040c6b8ddc..090ad0919f 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index d9657a9764..7afce1708b 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp" diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py index aeec7bd471..51f3941233 100644 --- a/include/ck_tile/remod.py +++ b/include/ck_tile/remod.py @@ -1,7 +1,6 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -from datetime import datetime import pathlib from pathlib import Path import subprocess @@ -13,8 +12,8 @@ OPS = "ops" OPS_COMMON = "common" # common header will be duplicated into ops/* other module IGNORED_DIRS = ["utility", "ref"] -HEADER_COMMON = f"""// SPDX-License-Identifier: MIT -// Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n +HEADER_COMMON = """// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT """ diff --git a/script/check_copyright_year.sh b/script/check_copyright_year.sh index 1b63c6b711..48c050c76b 100755 --- a/script/check_copyright_year.sh +++ b/script/check_copyright_year.sh @@ -2,18 +2,70 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT +# This script checks if files have the correct copyright header template. +# It supports .hpp, .cpp, .inc, .py, .sh, and .cmake files. +# +# Usage: ./check_copyright_year.sh ... -current_year=$(date +%Y) exit_code=0 -for file in $@; do - if grep -q "Copyright (c)" $file - then - if ! grep -q "Copyright (c).*$current_year" $file - then - echo "ERROR: File $file has a copyright notice without the current year ($current_year)." - exit_code=1 - fi +# Expected copyright header lines (without comment characters) +COPYRIGHT_LINE="Copyright (c) Advanced Micro Devices, Inc., or its affiliates." 
+SPDX_LINE="SPDX-License-Identifier: MIT" + +check_file() { + local file=$1 + local basename="${file##*/}" + local ext="${file##*.}" + local comment_char + + # Determine comment character based on filename or extension + if [[ "$basename" == "CMakeLists.txt" ]]; then + comment_char="#" + else + case "$ext" in + cpp|hpp|inc) + comment_char="//" + ;; + py|sh|cmake) + comment_char="#" + ;; + *) + # Skip files with unsupported extensions + return 0 + ;; + esac + fi + + # Build expected header patterns + expected_copyright="$comment_char $COPYRIGHT_LINE" + expected_spdx="$comment_char $SPDX_LINE" + + # Check if file contains both required lines + if ! grep -qF "$expected_copyright" "$file"; then + echo "ERROR: File $file is missing the correct copyright header line." + echo " Expected: $expected_copyright" + return 1 + fi + + if ! grep -qF "$expected_spdx" "$file"; then + echo "ERROR: File $file is missing the correct SPDX license identifier line." + echo " Expected: $expected_spdx" + return 1 + fi + + return 0 +} + +# Process each file provided as argument +for file in "$@"; do + # Skip if file doesn't exist or is a directory + if [[ ! -f "$file" ]]; then + continue + fi + + if ! check_file "$file"; then + exit_code=1 fi done diff --git a/script/update_amd_copyright_headers.py b/script/update_amd_copyright_headers.py new file mode 100644 index 0000000000..489b774e97 --- /dev/null +++ b/script/update_amd_copyright_headers.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Purpose: + Normalize and enforce AMD two-line copyright + SPDX headers across files. + +Target files: + - C/C++-style: .cpp, .hpp, .inc -> uses "//" comment style + - Hash-style: .py, .cmake, .sh, and CMakeLists.txt -> uses "#" style + +Header formats inserted (top of file, followed by exactly one blank line): + C/C++ : + // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+ // SPDX-License-Identifier: MIT + + Hash : + + +Shebang special case (hash-style only): + - If line 1 starts with "#!", keep shebang, then a blank line, then the + two hash-style header lines, then a blank line. + +Removal rules: + - Remove any comment lines (anywhere in file) containing the keywords + "copyright" or "spdx" (case-insensitive). Blank lines are preserved. + - Remove long-form MIT license block comment when: + a) The file starts with the block (absolute top), OR + b) The block appears immediately after the AMD header position + (i.e., when remainder at insertion point begins with "/*" and + the first content line is "* The MIT License (MIT)"). + +Blank-line normalization: + - Enforce exactly ONE blank line immediately after the AMD header. + (Drop only the leading blank lines at the insertion point before + re-inserting the header.) + - Do not change blank lines between other non-copyright comments. + +Preservation: + - Preserve original newline style: CRLF (\r\n) vs LF (\n). + - Preserve UTF-8 BOM if present. + - Do not modify non-comment code lines. + +Idempotency: + - Running this script multiple times does not further modify files. 
+""" + +from __future__ import annotations +import re +import sys +from pathlib import Path +from typing import List, Tuple + +AMD_CPP_HEADER_TEXT = [ + "// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.", + "// SPDX-License-Identifier: MIT", +] +AMD_HASH_HEADER_TEXT = [ + "# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.", + "# SPDX-License-Identifier: MIT", +] + +CPP_EXTS = {".cpp", ".hpp", ".inc"} +HASH_EXTS = {".py", ".cmake", ".sh"} + +# --- Encoding helpers ------------------------------------------------------- + + +def has_bom(raw: bytes) -> bool: + return raw.startswith(b"\xef\xbb\xbf") + + +def decode_text(raw: bytes) -> str: + return raw.decode("utf-8-sig", errors="replace") + + +def encode_text(text: str, bom: bool) -> bytes: + data = text.encode("utf-8") + return (b"\xef\xbb\xbf" + data) if bom else data + + +# --- Newline detection ------------------------------------------------------ + + +def detect_newline_sequence(raw: bytes) -> str: + if b"\r\n" in raw: + return "\r\n" + elif b"\n" in raw: + return "\n" + else: + return "\n" + + +# --- Utilities -------------------------------------------------------------- + + +def is_comment_line(line: str, style: str) -> bool: + stripped = line.lstrip() + if style == "cpp": + return ( + stripped.startswith("//") + or stripped.startswith("/*") + or stripped.startswith("*") + or stripped.startswith("*/") + ) + elif style == "hash": + return stripped.startswith("#") + return False + + +def has_keywords(line: str) -> bool: + lower_line = line.lower() + return ("copyright" in lower_line) or ("spdx" in lower_line) + + +# --- MIT License banner detection ------------------------------ +MIT_C_FIRST_LINE_RE = re.compile(r"^\s*\*\s*The MIT License \(MIT\)") +MIT_HASH_FIRST_LINE_RE = re.compile(r"^\s*#\s*The MIT License \(MIT\)") + + +def remove_top_mit_block(lines: List[str]) -> Tuple[List[str], bool]: + """ + Unified MIT banner removal at the top of 'lines'. 
+ Supports: + - C-style block starting with '/*' and ending with '*/'; removes only if + a line within the block matches MIT_C_FIRST_LINE_RE. + - Hash-style banner: contiguous top run of lines starting with '#'; + removes only if any line in that run matches MIT_HASH_FIRST_LINE_RE. + Returns (new_lines, removed_flag). Preserves EOLs. + """ + if not lines: + return lines, False + + first = lines[0].lstrip() + + # C-style block + if first.startswith("/*"): + end_idx, saw_mit = None, False + for i, line in enumerate(lines[1:], 1): + if not saw_mit and MIT_C_FIRST_LINE_RE.match(line): + saw_mit = True + s = line.lstrip() + if s.startswith("*/") or s.rstrip().endswith("*/"): + end_idx = i + 1 + break + if end_idx is not None and saw_mit: + return lines[end_idx:], True + return lines, False + + # Hash-style contiguous banner + if first.startswith("#"): + end_idx, saw_mit = 0, False + for i, line in enumerate(lines): + if line.lstrip().startswith("#"): + if not saw_mit and MIT_HASH_FIRST_LINE_RE.match(line): + saw_mit = True + end_idx = i + 1 + else: + break + if saw_mit: + return lines[end_idx:], True + return lines, False + + return lines, False + + +# --- Removal + normalization helpers --------------------------------------- + + +def remove_keyword_comment_lines_globally(lines: List[str], style: str) -> List[str]: + """Remove comment lines containing keywords anywhere in the file. 
+ **Do not** remove blank lines; preserve all other lines as-is.""" + out: List[str] = [] + for line in lines: + if is_comment_line(line, style) and has_keywords(line): + continue + out.append(line) + return out + + +def drop_leading_blank_lines(lines: List[str]) -> List[str]: + """Drop only the leading blank lines at the start of the given list.""" + i = 0 + while i < len(lines) and lines[i].strip() == "": + i += 1 + return lines[i:] + + +# --- Header builder --------------------------------------------------------- + + +def build_header_lines(style: str, nl: str) -> List[str]: + base = AMD_CPP_HEADER_TEXT if style == "cpp" else AMD_HASH_HEADER_TEXT + return [base[0] + nl, base[1] + nl, nl] # header + exactly one blank + + +# --- Main transforms -------------------------------------------------------- + + +def process_cpp(text: str, nl: str) -> str: + lines = text.splitlines(True) + + # Remove MIT block if it is at the *absolute* top + lines, _ = remove_top_mit_block(lines) + + # Remove keyworded comment lines globally (blank lines preserved) + lines = remove_keyword_comment_lines_globally(lines, style="cpp") + + # Normalize insertion point and remove MIT block if it appears *after header* + lines = drop_leading_blank_lines(lines) + lines, _ = remove_top_mit_block(lines) + + # Prepend AMD header (guarantee exactly one blank after) + return "".join(build_header_lines("cpp", nl) + lines) + + +def process_hash(text: str, nl: str) -> str: + lines = text.splitlines(True) + if not lines: + return "".join(build_header_lines("hash", nl)) + + shebang = lines[0].startswith("#!") + + if shebang: + remainder = remove_keyword_comment_lines_globally(lines[1:], style="hash") + remainder = drop_leading_blank_lines(remainder) + remainder, _ = remove_top_mit_block(remainder) # remove MIT block after header + new_top = [lines[0], nl] + build_header_lines("hash", nl) + return "".join(new_top + remainder) + else: + remainder = remove_keyword_comment_lines_globally(lines, style="hash") 
+ remainder = drop_leading_blank_lines(remainder) + remainder, _ = remove_top_mit_block(remainder) # remove MIT block after header + return "".join(build_header_lines("hash", nl) + remainder) + + +# --- File processing & CLI -------------------------------------------------- + + +def process_file(path: Path) -> bool: + name = path.name + suffix = path.suffix.lower() + if suffix in CPP_EXTS: + style = "cpp" + elif suffix in HASH_EXTS or name == "CMakeLists.txt": + style = "hash" + else: + return False + + raw = path.read_bytes() + bom = has_bom(raw) + nl = detect_newline_sequence(raw) + text = decode_text(raw) + + updated = process_cpp(text, nl) if style == "cpp" else process_hash(text, nl) + if updated != text: + path.write_bytes(encode_text(updated, bom)) + return True + return False + + +def main(argv: List[str]) -> int: + if len(argv) < 2: + print(__doc__) + return 2 + changed = 0 + skipped = 0 + errors: List[str] = [] + for arg in argv[1:]: + p = Path(arg) + try: + if not p.exists(): + errors.append(f"Not found: {p}") + continue + if p.is_dir(): + errors.append(f"Is a directory (pass specific files): {p}") + continue + if process_file(p): + changed += 1 + print(f"Updated: {p}") + else: + skipped += 1 + print(f"Skipped (no change needed or unsupported type): {p}") + except Exception as e: + errors.append(f"Error processing {p}: {e}") + print(f"\nSummary: {changed} updated, {skipped} skipped, {len(errors)} errors") + for msg in errors: + print(f" - {msg}") + return 0 if not errors else 1 + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp index 4121e199e2..c7093e3477 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #include #include diff --git a/test/ck_tile/core/arch/test_arch.cpp b/test/ck_tile/core/arch/test_arch.cpp index 2d553c1595..f015d3ce0a 100644 --- a/test/ck_tile/core/arch/test_arch.cpp +++ b/test/ck_tile/core/arch/test_arch.cpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #include #include "ck_tile/core/arch/arch.hpp" diff --git a/tile_engine/include/utility/validation.hpp b/tile_engine/include/utility/validation.hpp index dc57e6cc6a..f10f37fbaa 100644 --- a/tile_engine/include/utility/validation.hpp +++ b/tile_engine/include/utility/validation.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c), Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_streamk/CMakeLists.txt b/tile_engine/ops/gemm_streamk/CMakeLists.txt index acfd78edc5..c692a6d247 100644 --- a/tile_engine/ops/gemm_streamk/CMakeLists.txt +++ b/tile_engine/ops/gemm_streamk/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + set(GEMM_STREAMK_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)") set(GEMM_STREAMK_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)") set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp index fa8a019be5..45beb0acce 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp index 5e88dc486a..9dbba04082 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp index 15a3c91964..2708ac2e56 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py index 6aebc54564..2225619fad 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -1,4 +1,7 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + import os import json diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp index 256e0b9ca4..0541116522 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py b/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py index 2288d7752f..bef3cdfe85 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py @@ -1,6 +1,6 @@ #!/usr/bin/env python +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. """ Validation utilities for GEMM kernel generation. diff --git a/tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md b/tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md new file mode 100644 index 0000000000..275d1a1c12 --- /dev/null +++ b/tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md @@ -0,0 +1,312 @@ +# Tile Distribution: Mapping Threads to Data + +## Overview + +**Tile Distribution** describes how each thread in a thread block maps to elements of a block tile. 
It defines the hierarchical pattern of data distribution across threads, warps, and thread blocks. + +## The Problem + +Given a block tile of size `MPerBlock × KPerBlock` (e.g., 256×32), we need to determine: +- Which threads load which elements. +- How the threads are organized into warps. +- The number of times each warp repeats its pattern. +- The number of elements each thread can load in a single vector instruction. + +--- + +## Bottom-Up Construction Approach + +### Step 1: Determine K Dimension Layout + +**Start with the innermost dimension (K) for memory coalescing:** + +```cpp +constexpr index_t K1 = 16 / sizeof(ADataType); // Elements per thread (vector load) +constexpr index_t K0 = kKPerBlock / K1; // Threads needed in K dimension +``` + +**Example (with fp16):** +- `K1 = 16 / 2 = 8` → Each thread loads 8 fp16 elements in a single vector instruction +- `kKPerBlock = 32` +- `K0 = 32 / 8 = 4` → We need 4 threads along K to cover the entire K dimension + +**Visual:** +``` +K dimension (32 elements): +Thread 0: [0-7] Thread 1: [8-15] Thread 2: [16-23] Thread 3: [24-31] + K1=8 K1=8 K1=8 K1=8 +├──────────────────────────────────────────────────────────────┤ + K0=4 threads +``` + +--- + +### Step 2: Determine M Dimension Layout + +**Now partition the M dimension hierarchically:** + +#### Level 1: Threads per Warp in M (M2) + +```cpp +constexpr index_t M2 = get_warp_size() / K0; +``` + +- Warp size = 64 threads +- K dimension already uses `K0 = 4` threads per row +- `M2 = 64 / 4 = 16` → Each warp can have 16 threads in M dimension + +**Visual (Single Warp):** +``` + K dimension (4 threads) + ┌─────┬─────┬─────┬─────┐ + 0 │ T0 │ T1 │ T2 │ T3 │ + 1 │ T4 │ T5 │ T6 │ T7 │ + 2 │ T8 │ T9 │ T10 │ T11 │ +M 3 │ T12 │ T13 │ T14 │ T15 │ ← 16 rows + ...│ ... │ ... │ ... │ ... 
│ (M2=16) + 15 │ T60 │ T61 │ T62 │ T63 │ + └─────┴─────┴─────┴─────┘ + One Warp = 64 threads +``` + +#### Level 2: Warps per Block (M1) + +```cpp +constexpr index_t M1 = kBlockSize / get_warp_size(); +``` + +- `kBlockSize = 256` threads per block +- `M1 = 256 / 64 = 4` → We have 4 warps per block + +**Visual (4 Warps):** +``` + Warp 0 (rows 0-15) + Warp 1 (rows 16-31) + Warp 2 (rows 32-47) + Warp 3 (rows 48-63) + ↑ + M1 = 4 warps cover 64 rows total +``` + +#### Level 3: Repetitions (M0) + +```cpp +constexpr index_t M0 = kMPerBlock / (M2 * M1); +``` + +- `kMPerBlock = 256` rows to cover +- `M2 * M1 = 16 * 4 = 64` rows covered by all warps +- `M0 = 256 / 64 = 4` → Each warp must repeat its pattern 4 times + +**Visual (Complete Block):** +``` +┌──────────────┐ +│ Iteration 0 │ ← Warp 0: rows 0-15, Warp 1: rows 16-31, ... +│ (rows 0-63) │ +├──────────────┤ +│ Iteration 1 │ ← Warp 0: rows 64-79, Warp 1: rows 80-95, ... +│ (rows 64-127)│ +├──────────────┤ +│ Iteration 2 │ ← Warp 0: rows 128-143, Warp 1: rows 144-159, ... +│(rows 128-191)│ +├──────────────┤ +│ Iteration 3 │ ← Warp 0: rows 192-207, Warp 1: rows 208-223, ... +│(rows 192-255)│ +└──────────────┘ + M0 = 4 iterations +``` + +--- + +## The Tile Distribution Encoding + +Now we can construct the distribution: + +```cpp +tile_distribution_encoding< + sequence<1>, // [1] Replication + tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, // [2] Hierarchy + tuple<sequence<1>, sequence<1, 2>>, // [3] Parallelism + tuple<sequence<1>, sequence<2, 0>>, // [3] Parallelism + sequence<1, 2>, // [4] Yield + sequence<0, 1> // [4] Yield +> +``` + +### [1] Replication: `sequence<1>` + +Defines how many times warp patterns are replicated: +- `1` = Each warp has a unique pattern (no replication) +- `2` = Warp 0 and Warp 1 do the same thing, Warp 2 and Warp 3 do the same thing +- `4` = All warps do the same thing + +In our case: `1` means no replication (each warp is independent). 
+ +--- + +### [2] Hierarchy: The Multi-Level Structure + +```cpp +tuple<sequence<M0, M1, M2>, sequence<K0, K1>> + └───────┬──────────┘ └──────┬────────┘ + M dimension K dimension +``` + +**Concrete values:** +- M hierarchy: `sequence<4, 4, 16>` = (4 repetitions, 4 warps, 16 threads/warp) +- K hierarchy: `sequence<4, 8>` = (4 threads, 8 elements/thread) + +--- + +### [3] Parallelism: Addressing the Hierarchy + +**The key insight:** Read the tuples **vertically** to understand indexing! + +```cpp +tuple<sequence<1>, sequence<1, 2>> +tuple<sequence<1>, sequence<2, 0>> +``` + +#### Reading Pattern + +**Column 1 (parallel dimension 0 → warp index within the block):** +``` +sequence<1> → Address hierarchy index 1,1 → M1 (warps/block in M dimension) +sequence<1> +``` + +**Column 2 (parallel dimension 1 → thread index within the warp):** +``` +sequence<1, 2> +sequence<2, 0> +``` +[1,2] M2=threads/warp in M dimension +[2,0] K0=threads/warp in K dimension + +--- + +### [4] Yield Sequences: Output Ordering + +```cpp +sequence<1, 2> +sequence<0, 1> + +[1,0] means M0=repetitions/warp in M dimension +[2,1] means K1=elements/thread in K dimension +``` +--- + +## Complete Example: Thread 25 in Warp 0 + +Let's trace where **Thread 25** in **Warp 0** reads data: + +### Thread Coordinates +- Thread ID in warp: 25 +- Warp ID in block: 0 + +### Decompose Thread 25 +``` +Thread 25 in a 2D layout (M2=16, K0=4): +Row index: 25 / 4 = 6 +Col index: 25 % 4 = 1 +``` + +### M Position (Row) +``` +M0 iteration: 0 (first iteration) +M1 warp: 0 (warp 0) +M2 thread: 6 (row 6 in warp, 0-based) +→ M position = 0*64 + 0*16 + 6 = 6 +``` + +### K Position (Column) +``` +K0 thread: 1 (column group 1) +K1 elements: 8 (will load 8 consecutive elements) +→ K position = 1*8 + [0-7] = elements 8-15 +``` + +**Result:** Thread 25 in Warp 0 loads **row 6, columns 8-15** (8 elements). + +--- + +## Why This Matters + +### 1. **Memory Coalescing** +- Consecutive threads access consecutive memory → efficient global memory access +- K dimension uses K1=8 for vectorized loads + +### 2. 
**Warp Efficiency** +- All 64 threads in a warp are utilized +- Natural 2D layout: 16 threads (M) × 4 threads (K) = 64 threads + +### 3. **Scalability** +- M0 repetitions allow handling larger tiles +- Same pattern scales to different sizes + +### 4. **Register Allocation** +- Each thread knows exactly how many elements it will hold +- Compiler can allocate registers optimally + +--- + +## Summary Table + +| Parameter | Value | Meaning | +|-----------|-------|---------| +| **K1** | 8 | Elements per thread (vector width) | +| **K0** | 4 | Threads along K per row | +| **M2** | 16 | Threads along M per warp | +| **M1** | 4 | Warps per block | +| **M0** | 4 | Repetitions of warp pattern | +| **Total Threads** | 256 | M1×M2×K0 = 4×16×4 (M1 warps × 64 threads per warp) | +| **Total Elements** | 8192 | 256×32 (MPerBlock × KPerBlock) | +| **Elements/Thread** | 32 | M0×K1 = 4×8 | + +--- + +## Visualization: Complete Thread Block + +``` +Block Tile: 256×32 (row labels show the first row of each warp's slice) + + K dimension (32 elements) + ├─────────────────────┤ + ┌──────────────────────┐ ┐ + 0 │ Warp 0 │ │ + 16 │ Warp 1 │ │ Iteration 0 + 32 │ Warp 2 │ │ (M0=0) + 48 │ Warp 3 │ ┘ + ├──────────────────────┤ ┐ + 64 │ Warp 0 │ │ + 80 │ Warp 1 │ │ Iteration 1 + 96 │ Warp 2 │ │ (M0=1) + 112 │ Warp 3 │ ┘ + ├──────────────────────┤ ┐ + 128 │ Warp 0 │ │ + 144 │ Warp 1 │ │ Iteration 2 + 160 │ Warp 2 │ │ (M0=2) + 176 │ Warp 3 │ ┘ + ├──────────────────────┤ ┐ + 192 │ Warp 0 │ │ + 208 │ Warp 1 │ │ Iteration 3 + 224 │ Warp 2 │ │ (M0=3) + 240 │ Warp 3 │ ┘ + 256 └──────────────────────┘ + +Each warp processes 16 rows × 32 cols = 512 elements +Each iteration processes 64 rows × 32 cols = 2048 elements +Total: 4 iterations × 2048 = 8192 elements ✓ +``` + +--- + +## Key Takeaways + +1. **Bottom-up construction**: Start from vector width (K1), build up through thread/warp/block hierarchy +2. **Vertical reading**: The two parallelism tuples (and the two yield sequences) are read column-wise to address hierarchy levels +3. 
**Replication controls redundancy**: How many warps share the same pattern +4. **Hierarchy encodes structure**: The multi-level sequence defines the complete mapping + +This design enables CK to achieve maximum GPU performance through optimal thread-to-data mapping! + diff --git a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp index 2921bce8bf..a3ed982488 100644 --- a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp +++ b/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp @@ -98,12 +98,12 @@ struct PracticeGemmBlockPolicy constexpr index_t M0 = kMPerBlock / (M2 * M1); return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + tile_distribution_encoding, // replication + tuple, sequence>, // hierarchy + tuple, sequence<1, 2>>, // parallelism + tuple, sequence<2, 0>>, // parallelism + sequence<1, 2>, // yield + sequence<0, 1>>{}); // yield } template diff --git a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp index dd72f08d99..15c1743a86 100644 --- a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp +++ b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp @@ -24,7 +24,7 @@ struct PracticeGemmHostPipeline template CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram, const BDRAMTensorView& b_dram, - CDRAMTensorView& c_dram_ref) const + CDRAMTensorView& c_dram) const { // Size of the entire problem diff --git a/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp b/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp 
index 4f0bc13dd5..7635c9376b 100644 --- a/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp +++ b/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp @@ -6,7 +6,7 @@ #include "practice_gemm.hpp" #include "reference_gemm.hpp" -int main() +int main(int argc, char* argv[]) { // TODO: GemmTypeConfig using ADataType = ck_tile::half_t; @@ -14,11 +14,22 @@ int main() using CDataType = float; using AccDataType = float; - // ArgParser - ck_tile::index_t M = 512; - ck_tile::index_t N = 256; - ck_tile::index_t K = 64; - ck_tile::index_t verification = 1; + // Setup simple argument parser for M, N, K + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "512", "m dimension") + .insert("n", "256", "n dimension") + .insert("k", "64", "k dimension") + .insert("v", "1", "verification: 0=off, 1=on"); + + auto result = arg_parser.parse(argc, argv); + if(!result) + return -1; + + // Get problem dimensions from command line + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + ck_tile::index_t verification = arg_parser.get_int("v"); ck_tile::index_t stride_a = K; ck_tile::index_t stride_b = K; @@ -61,9 +72,6 @@ int main() ck_tile::DeviceMem c_device(c_host); // TODO: BlockTileConfig - // constexpr ck_tile::index_t warpSize = 64; - constexpr ck_tile::index_t kBlockSize = 256; - using BlockTile = ck_tile::sequence<256, 128, 32>; using WaveTile = ck_tile::sequence<16, 16, 16>; @@ -77,11 +85,13 @@ int main() ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) * ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N); - std::cout << "kGridSize: " << kGridSize << std::endl; + std::cout << "Total number of thread blocks: " << kGridSize << std::endl; constexpr ck_tile::index_t kBlockPerCU = 1; // 1 block per CU - std::cout << "kBlockSize: " << kBlockSize << std::endl; - std::cout << "kBlockPerCU: " << kBlockPerCU << std::endl; + // Block size is now 
derived from the shape configuration + constexpr ck_tile::index_t kBlockSize = PracticeGemmShape::kBlockSize; + std::cout << "Number of threads per block: " << kBlockSize << std::endl; + std::cout << "Number of blocks per compute unit: " << kBlockPerCU << std::endl; using gemm_kernel = ck_tile::PracticeGemmKernel; diff --git a/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp b/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp index 850e6ae3b3..91d7fae90c 100644 --- a/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp +++ b/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp @@ -24,6 +24,10 @@ struct PracticeGemmShape static constexpr index_t WaveTile_N = WaveTile::at(number<1>{}); static constexpr index_t WaveTile_K = WaveTile::at(number<2>{}); + // Thread block configuration + static constexpr index_t kWarpSize = 64; // AMD GPU warp size (also called wavefront) + static constexpr index_t kBlockSize = 256; // Total threads per block (4 warps × 64 threads) + CK_TILE_HOST static std::string GetName() { // clang-format off @@ -40,7 +44,8 @@ struct PracticeGemmKernel using Problem = remove_cvref_t; using Policy = remove_cvref_t; - static constexpr index_t kBlockSize = 256; + // Derive block size from the shape configuration + static constexpr index_t kBlockSize = Problem::Shape::kBlockSize; CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a, const typename Problem::BDataType* p_b,