Merge commit '6d25525adc2344d5b62b12b9ffddee50f89cd0ff' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-11 07:16:06 +00:00
parent 72cc7dfc77
commit a1037bfc3c
45 changed files with 755 additions and 98 deletions

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import fnmatch
import json
import os

View File

@@ -20,12 +20,12 @@ repos:
)$
- repo: local
hooks:
# - id: copyright-year-checker
# name: copyright-year-checker
# entry: script/check_copyright_year.sh
# verbose: false
# language: script
# types: [c++]
- id: copyright-header-checker
name: Check copyright headers
entry: script/check_copyright_year.sh
verbose: false
language: script
types_or: [c++, python, shell, cmake]
- id: remove-exec-bit
name: Remove executable bit from non-executable files
entry: script/remove_exec_bit.sh

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/host/arg_parser.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/reduce/block/block_reduce.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"

View File

@@ -1,7 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
from datetime import datetime
import pathlib
from pathlib import Path
import subprocess
@@ -13,8 +12,8 @@ OPS = "ops"
OPS_COMMON = "common" # common header will be duplicated into ops/* other module
IGNORED_DIRS = ["utility", "ref"]
HEADER_COMMON = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n
HEADER_COMMON = """// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
"""

View File

@@ -2,18 +2,70 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# This script checks if files have the correct copyright header template.
# It supports .hpp, .cpp, .inc, .py, .sh, and .cmake files.
#
# Usage: ./check_copyright_year.sh <file1> <file2> ...
current_year=$(date +%Y)
exit_code=0
for file in $@; do
if grep -q "Copyright (c)" $file
then
if ! grep -q "Copyright (c).*$current_year" $file
then
echo "ERROR: File $file has a copyright notice without the current year ($current_year)."
exit_code=1
fi
# Expected copyright header lines (without comment characters)
COPYRIGHT_LINE="Copyright (c) Advanced Micro Devices, Inc., or its affiliates."
SPDX_LINE="SPDX-License-Identifier: MIT"
check_file() {
local file=$1
local basename="${file##*/}"
local ext="${file##*.}"
local comment_char
# Determine comment character based on filename or extension
if [[ "$basename" == "CMakeLists.txt" ]]; then
comment_char="#"
else
case "$ext" in
cpp|hpp|inc)
comment_char="//"
;;
py|sh|cmake)
comment_char="#"
;;
*)
# Skip files with unsupported extensions
return 0
;;
esac
fi
# Build expected header patterns
expected_copyright="$comment_char $COPYRIGHT_LINE"
expected_spdx="$comment_char $SPDX_LINE"
# Check if file contains both required lines
if ! grep -qF "$expected_copyright" "$file"; then
echo "ERROR: File $file is missing the correct copyright header line."
echo " Expected: $expected_copyright"
return 1
fi
if ! grep -qF "$expected_spdx" "$file"; then
echo "ERROR: File $file is missing the correct SPDX license identifier line."
echo " Expected: $expected_spdx"
return 1
fi
return 0
}
# Process each file provided as argument
for file in "$@"; do
# Skip if file doesn't exist or is a directory
if [[ ! -f "$file" ]]; then
continue
fi
if ! check_file "$file"; then
exit_code=1
fi
done

View File

@@ -0,0 +1,295 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Purpose:
Normalize and enforce AMD two-line copyright + SPDX headers across files.
Target files:
- C/C++-style: .cpp, .hpp, .inc -> uses "//" comment style
- Hash-style: .py, .cmake, .sh, and CMakeLists.txt -> uses "#" style
Header formats inserted (top of file, followed by exactly one blank line):
C/C++ :
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
<blank>
Hash :
<blank>
Shebang special case (hash-style only):
- If line 1 starts with "#!", keep shebang, then a blank line, then the
two hash-style header lines, then a blank line.
Removal rules:
- Remove any comment lines (anywhere in file) containing the keywords
"copyright" or "spdx" (case-insensitive). Blank lines are preserved.
- Remove long-form MIT license block comment when:
a) The file starts with the block (absolute top), OR
b) The block appears immediately after the AMD header position
(i.e., when remainder at insertion point begins with "/*" and
the first content line is "* The MIT License (MIT)").
Blank-line normalization:
- Enforce exactly ONE blank line immediately after the AMD header.
(Drop only the leading blank lines at the insertion point before
re-inserting the header.)
- Do not change blank lines between other non-copyright comments.
Preservation:
- Preserve original newline style: CRLF (\r\n) vs LF (\n).
- Preserve UTF-8 BOM if present.
- Do not modify non-comment code lines.
Idempotency:
- Running this script multiple times does not further modify files.
"""
from __future__ import annotations
import re
import sys
from pathlib import Path
from typing import List, Tuple
AMD_CPP_HEADER_TEXT = [
"// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.",
"// SPDX-License-Identifier: MIT",
]
AMD_HASH_HEADER_TEXT = [
"# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.",
"# SPDX-License-Identifier: MIT",
]
CPP_EXTS = {".cpp", ".hpp", ".inc"}
HASH_EXTS = {".py", ".cmake", ".sh"}
# --- Encoding helpers -------------------------------------------------------
def has_bom(raw: bytes) -> bool:
return raw.startswith(b"\xef\xbb\xbf")
def decode_text(raw: bytes) -> str:
return raw.decode("utf-8-sig", errors="replace")
def encode_text(text: str, bom: bool) -> bytes:
data = text.encode("utf-8")
return (b"\xef\xbb\xbf" + data) if bom else data
# --- Newline detection ------------------------------------------------------
def detect_newline_sequence(raw: bytes) -> str:
if b"\r\n" in raw:
return "\r\n"
elif b"\n" in raw:
return "\n"
else:
return "\n"
# --- Utilities --------------------------------------------------------------
def is_comment_line(line: str, style: str) -> bool:
stripped = line.lstrip()
if style == "cpp":
return (
stripped.startswith("//")
or stripped.startswith("/*")
or stripped.startswith("*")
or stripped.startswith("*/")
)
elif style == "hash":
return stripped.startswith("#")
return False
def has_keywords(line: str) -> bool:
lower_line = line.lower()
return ("copyright" in lower_line) or ("spdx" in lower_line)
# --- MIT License banner detection ------------------------------
MIT_C_FIRST_LINE_RE = re.compile(r"^\s*\*\s*The MIT License \(MIT\)")
MIT_HASH_FIRST_LINE_RE = re.compile(r"^\s*#\s*The MIT License \(MIT\)")
def remove_top_mit_block(lines: List[str]) -> Tuple[List[str], bool]:
"""
Unified MIT banner removal at the top of 'lines'.
Supports:
- C-style block starting with '/*' and ending with '*/'; removes only if
a line within the block matches MIT_C_FIRST_LINE_RE.
- Hash-style banner: contiguous top run of lines starting with '#';
removes only if any line in that run matches MIT_HASH_FIRST_LINE_RE.
Returns (new_lines, removed_flag). Preserves EOLs.
"""
if not lines:
return lines, False
first = lines[0].lstrip()
# C-style block
if first.startswith("/*"):
end_idx, saw_mit = None, False
for i, line in enumerate(lines[1:], 1):
if not saw_mit and MIT_C_FIRST_LINE_RE.match(line):
saw_mit = True
s = line.lstrip()
if s.startswith("*/") or s.rstrip().endswith("*/"):
end_idx = i + 1
break
if end_idx is not None and saw_mit:
return lines[end_idx:], True
return lines, False
# Hash-style contiguous banner
if first.startswith("#"):
end_idx, saw_mit = 0, False
for i, line in enumerate(lines):
if line.lstrip().startswith("#"):
if not saw_mit and MIT_HASH_FIRST_LINE_RE.match(line):
saw_mit = True
end_idx = i + 1
else:
break
if saw_mit:
return lines[end_idx:], True
return lines, False
return lines, False
# --- Removal + normalization helpers ---------------------------------------
def remove_keyword_comment_lines_globally(lines: List[str], style: str) -> List[str]:
"""Remove comment lines containing keywords anywhere in the file.
**Do not** remove blank lines; preserve all other lines as-is."""
out: List[str] = []
for line in lines:
if is_comment_line(line, style) and has_keywords(line):
continue
out.append(line)
return out
def drop_leading_blank_lines(lines: List[str]) -> List[str]:
"""Drop only the leading blank lines at the start of the given list."""
i = 0
while i < len(lines) and lines[i].strip() == "":
i += 1
return lines[i:]
# --- Header builder ---------------------------------------------------------
def build_header_lines(style: str, nl: str) -> List[str]:
base = AMD_CPP_HEADER_TEXT if style == "cpp" else AMD_HASH_HEADER_TEXT
return [base[0] + nl, base[1] + nl, nl] # header + exactly one blank
# --- Main transforms --------------------------------------------------------
def process_cpp(text: str, nl: str) -> str:
lines = text.splitlines(True)
# Remove MIT block if it is at the *absolute* top
lines, _ = remove_top_mit_block(lines)
# Remove keyworded comment lines globally (blank lines preserved)
lines = remove_keyword_comment_lines_globally(lines, style="cpp")
# Normalize insertion point and remove MIT block if it appears *after header*
lines = drop_leading_blank_lines(lines)
lines, _ = remove_top_mit_block(lines)
# Prepend AMD header (guarantee exactly one blank after)
return "".join(build_header_lines("cpp", nl) + lines)
def process_hash(text: str, nl: str) -> str:
lines = text.splitlines(True)
if not lines:
return "".join(build_header_lines("hash", nl))
shebang = lines[0].startswith("#!")
if shebang:
remainder = remove_keyword_comment_lines_globally(lines[1:], style="hash")
remainder = drop_leading_blank_lines(remainder)
remainder, _ = remove_top_mit_block(remainder) # remove MIT block after header
new_top = [lines[0], nl] + build_header_lines("hash", nl)
return "".join(new_top + remainder)
else:
remainder = remove_keyword_comment_lines_globally(lines, style="hash")
remainder = drop_leading_blank_lines(remainder)
remainder, _ = remove_top_mit_block(remainder) # remove MIT block after header
return "".join(build_header_lines("hash", nl) + remainder)
# --- File processing & CLI --------------------------------------------------
def process_file(path: Path) -> bool:
name = path.name
suffix = path.suffix.lower()
if suffix in CPP_EXTS:
style = "cpp"
elif suffix in HASH_EXTS or name == "CMakeLists.txt":
style = "hash"
else:
return False
raw = path.read_bytes()
bom = has_bom(raw)
nl = detect_newline_sequence(raw)
text = decode_text(raw)
updated = process_cpp(text, nl) if style == "cpp" else process_hash(text, nl)
if updated != text:
path.write_bytes(encode_text(updated, bom))
return True
return False
def main(argv: List[str]) -> int:
if len(argv) < 2:
print(__doc__)
return 2
changed = 0
skipped = 0
errors: List[str] = []
for arg in argv[1:]:
p = Path(arg)
try:
if not p.exists():
errors.append(f"Not found: {p}")
continue
if p.is_dir():
errors.append(f"Is a directory (pass specific files): {p}")
continue
if process_file(p):
changed += 1
print(f"Updated: {p}")
else:
skipped += 1
print(f"Skipped (no change needed or unsupported type): {p}")
except Exception as e:
errors.append(f"Error processing {p}: {e}")
print(f"\nSummary: {changed} updated, {skipped} skipped, {len(errors)} errors")
for msg in errors:
print(f" - {msg}")
return 0 if not errors else 1
if __name__ == "__main__":
raise SystemExit(main(sys.argv))

View File

@@ -1,5 +1,5 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <gtest/gtest.h>

View File

@@ -1,5 +1,5 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include "ck_tile/core/arch/arch.hpp"

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c), Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
set(GEMM_STREAMK_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)")
set(GEMM_STREAMK_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)")
set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <functional>

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,4 +1,7 @@
#!/usr/bin/env python
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
import os
import json

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
"""
Validation utilities for GEMM kernel generation.

View File

@@ -0,0 +1,312 @@
# Tile Distribution: Mapping Threads to Data
## Overview
**Tile Distribution** describes how each thread in a thread block maps to elements of a block tile. It defines the hierarchical pattern of data distribution across threads, warps, and thread blocks.
## The Problem
Given a block tile of size `MPerBlock × KPerBlock` (e.g., 256×32), we need to determine:
- Which threads load which elements.
- How the threads are organized into warps.
- The number of times each warp repeats its pattern.
- The number of elements each thread can load in a single vector instruction.
---
## Bottom-Up Construction Approach
### Step 1: Determine K Dimension Layout
**Start with the innermost dimension (K) for memory coalescing:**
```cpp
constexpr index_t K1 = 16 / sizeof(ADataType); // Elements per thread (vector load)
constexpr index_t K0 = kKPerBlock / K1; // Threads needed in K dimension
```
**Example (with fp16):**
- `K1 = 16 / 2 = 8` → Each thread loads 8 fp16 elements in a single vector instruction
- `kKPerBlock = 32`
- `K0 = 32 / 8 = 4` → We need 4 threads along K to cover the entire K dimension
**Visual:**
```
K dimension (32 elements):
Thread 0: [0-7] Thread 1: [8-15] Thread 2: [16-23] Thread 3: [24-31]
K1=8 K1=8 K1=8 K1=8
├──────────────────────────────────────────────────────────────┤
K0=4 threads
```
---
### Step 2: Determine M Dimension Layout
**Now partition the M dimension hierarchically:**
#### Level 1: Threads per Warp in M (M2)
```cpp
constexpr index_t M2 = get_warp_size() / K0;
```
- Warp size = 64 threads
- K dimension already uses `K0 = 4` threads per row
- `M2 = 64 / 4 = 16` → Each warp can have 16 threads in M dimension
**Visual (Single Warp):**
```
K dimension (4 threads)
┌─────┬─────┬─────┬─────┐
0 │ T0 │ T1 │ T2 │ T3 │
1 │ T4 │ T5 │ T6 │ T7 │
2 │ T8 │ T9 │ T10 │ T11 │
M 3 │ T12 │ T13 │ T14 │ T15 │ ← 16 rows
...│ ... │ ... │ ... │ ... │ (M2=16)
15 │ T60 │ T61 │ T62 │ T63 │
└─────┴─────┴─────┴─────┘
One Warp = 64 threads
```
#### Level 2: Warps per Block (M1)
```cpp
constexpr index_t M1 = kBlockSize / get_warp_size();
```
- `kBlockSize = 256` threads per block
- `M1 = 256 / 64 = 4` → We have 4 warps per block
**Visual (4 Warps):**
```
Warp 0 (rows 0-15)
Warp 1 (rows 16-31)
Warp 2 (rows 32-47)
Warp 3 (rows 48-63)
M1 = 4 warps cover 64 rows total
```
#### Level 3: Repetitions (M0)
```cpp
constexpr index_t M0 = kMPerBlock / (M2 * M1);
```
- `kMPerBlock = 256` rows to cover
- `M2 * M1 = 16 * 4 = 64` rows covered by all warps
- `M0 = 256 / 64 = 4` → Each warp must repeat its pattern 4 times
**Visual (Complete Block):**
```
┌──────────────┐
│ Iteration 0 │ ← Warp 0: rows 0-15, Warp 1: rows 16-31, ...
│ (rows 0-63) │
├──────────────┤
│ Iteration 1 │ ← Warp 0: rows 64-79, Warp 1: rows 80-95, ...
│ (rows 64-127)│
├──────────────┤
│ Iteration 2 │ ← Warp 0: rows 128-143, Warp 1: rows 144-159, ...
│(rows 128-191)│
├──────────────┤
│ Iteration 3 │ ← Warp 0: rows 192-207, Warp 1: rows 208-223, ...
│(rows 192-255)│
└──────────────┘
M0 = 4 iterations
```
---
## The Tile Distribution Encoding
Now we can construct the distribution:
```cpp
tile_distribution_encoding<
sequence<1>, // [1] Replication
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, // [2] Hierarchy
tuple<sequence<1>, sequence<1, 2>>, // [3] Parallelism:
tuple<sequence<1>, sequence<2, 0>>, // [3] Parallelism
sequence<1, 2>, // [4] Yield
sequence<0, 1> // [4] Yield
>
```
### [1] Replication: `sequence<1>`
Defines how many times warp patterns are replicated:
- `1` = Each warp has a unique pattern (no replication)
- `2` = Warp 0 and Warp 1 do the same thing, Warp 2 and Warp 3 do the same thing
- `4` = All warps do the same thing
In our case: `1` means no replication (each warp is independent).
---
### [2] Hierarchy: The Multi-Level Structure
```cpp
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>
M dimension K dimension
```
**Concrete values:**
- M hierarchy: `sequence<4, 4, 16>` = (4 repetitions, 4 warps, 16 threads/warp)
- K hierarchy: `sequence<4, 8>` = (4 threads, 8 elements/thread)
---
### [3] Parallelism: Addressing the Hierarchy
**The key insight:** Read the tuples **vertically** to understand indexing!
```cpp
tuple<sequence<1>, sequence<1, 2>>
tuple<sequence<1>, sequence<2, 0>>
```
#### Reading Pattern
**Column 1 (Dimension 0 = M):**
```
sequence<1> → Address hierarchy index 1,1 → M1 (warps/block in M dimension)
sequence<1>
```
**Column 2 (Dimension 1 = K):**
```
sequence<1, 2>
sequence<2, 0>
```
[1,2] M2=threads/warp in M dimension
[2,0] K0=threads/warp in K dimension
---
### [4] Yield Sequences: Output Ordering
```cpp
sequence<1, 2>
sequence<0, 1>
[1,0] means M0=repetitions/warp in M dimension
[2,1] means K1=elements/thread in K dimension
```
---
## Complete Example: Thread 25 in Warp 0
Let's trace where **Thread 25** in **Warp 0** reads data:
### Thread Coordinates
- Thread ID in warp: 25
- Warp ID in block: 0
### Decompose Thread 25
```
Thread 25 in a 2D layout (M2=16, K0=4):
Row index: 25 / 4 = 6
Col index: 25 % 4 = 1
```
### M Position (Row)
```
M0 iteration: 0 (first iteration)
M1 warp: 0 (warp 0)
M2 thread: 6 (6th row in warp)
→ M position = 0*64 + 0*16 + 6 = 6
```
### K Position (Column)
```
K0 thread: 1 (column group 1)
K1 elements: 8 (will load 8 consecutive elements)
→ K position = 1*8 + [0-7] = elements 8-15
```
**Result:** Thread 25 in Warp 0 loads **row 6, columns 8-15** (8 elements).
---
## Why This Matters
### 1. **Memory Coalescing**
- Consecutive threads access consecutive memory → efficient global memory access
- K dimension uses K1=8 for vectorized loads
### 2. **Warp Efficiency**
- All 64 threads in a warp are utilized
- Natural 2D layout: 16 threads (M) × 4 threads (K) = 64 threads
### 3. **Scalability**
- M0 repetitions allow handling larger tiles
- Same pattern scales to different sizes
### 4. **Register Allocation**
- Each thread knows exactly how many elements it will hold
- Compiler can allocate registers optimally
---
## Summary Table
| Parameter | Value | Meaning |
|-----------|-------|---------|
| **K1** | 8 | Elements per thread (vector width) |
| **K0** | 4 | Threads along K per row |
| **M2** | 16 | Threads along M per warp |
| **M1** | 4 | Warps per block |
| **M0** | 4 | Repetitions of warp pattern |
| **Total Threads** | 256 | M0×M1×M2 = 4×4×16 (actually M1×64) |
| **Total Elements** | 8192 | 256×32 (MPerBlock × KPerBlock) |
| **Elements/Thread** | 32 | M0×K1 = 4×8 |
---
## Visualization: Complete Thread Block
```
Block Tile: 256×32
K dimension (32 elements)
├─────────────────────┤
0 ┌──────────────────────┐ ┐
16 │ Warp 0 │ │
32 │ Warp 1 │ │ Iteration 0
48 │ Warp 2 │ │ (M0=0)
64 │ Warp 3 │ ┘
80 ├──────────────────────┤ ┐
96 │ Warp 0 │ │
112 │ Warp 1 │ │ Iteration 1
128 │ Warp 2 │ │ (M0=1)
144 │ Warp 3 │ ┘
160 ├──────────────────────┤ ┐
176 │ Warp 0 │ │
192 │ Warp 1 │ │ Iteration 2
208 │ Warp 2 │ │ (M0=2)
224 │ Warp 3 │ ┘
240 ├──────────────────────┤ ┐
256 │ Warp 0 │ │
│ Warp 1 │ │ Iteration 3
│ Warp 2 │ │ (M0=3)
│ Warp 3 │ ┘
└──────────────────────┘
Each warp processes 16 rows × 32 cols = 512 elements
Each iteration processes 64 rows × 32 cols = 2048 elements
Total: 4 iterations × 2048 = 8192 elements ✓
```
---
## Key Takeaways
1. **Bottom-up construction**: Start from vector width (K1), build up through thread/warp/block hierarchy
2. **Vertical reading**: The repeat and elements tuples are read column-wise to address hierarchy levels
3. **Replication controls redundancy**: How many warps share the same pattern
4. **Hierarchy encodes structure**: The multi-level sequence defines the complete mapping
This design enables CK to achieve maximum GPU performance through optimal thread-to-data mapping!

View File

@@ -98,12 +98,12 @@ struct PracticeGemmBlockPolicy
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
tile_distribution_encoding<sequence<1>, // replication
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, // hierarchy
tuple<sequence<1>, sequence<1, 2>>, // parallelism
tuple<sequence<1>, sequence<2, 0>>, // paralleism
sequence<1, 2>, // yield
sequence<0, 1>>{}); // yield
}
template <typename Problem>

View File

@@ -24,7 +24,7 @@ struct PracticeGemmHostPipeline
template <typename ADRAMTensorView, typename BDRAMTensorView, typename CDRAMTensorView>
CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram,
const BDRAMTensorView& b_dram,
CDRAMTensorView& c_dram_ref) const
CDRAMTensorView& c_dram) const
{
// Size of the entire problem

View File

@@ -6,7 +6,7 @@
#include "practice_gemm.hpp"
#include "reference_gemm.hpp"
int main()
int main(int argc, char* argv[])
{
// TODO: GemmTypeConfig
using ADataType = ck_tile::half_t;
@@ -14,11 +14,22 @@ int main()
using CDataType = float;
using AccDataType = float;
// ArgParser
ck_tile::index_t M = 512;
ck_tile::index_t N = 256;
ck_tile::index_t K = 64;
ck_tile::index_t verification = 1;
// Setup simple argument parser for M, N, K
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "512", "m dimension")
.insert("n", "256", "n dimension")
.insert("k", "64", "k dimension")
.insert("v", "1", "verification: 0=off, 1=on");
auto result = arg_parser.parse(argc, argv);
if(!result)
return -1;
// Get problem dimensions from command line
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t verification = arg_parser.get_int("v");
ck_tile::index_t stride_a = K;
ck_tile::index_t stride_b = K;
@@ -61,9 +72,6 @@ int main()
ck_tile::DeviceMem c_device(c_host);
// TODO: BlockTileConfig
// constexpr ck_tile::index_t warpSize = 64;
constexpr ck_tile::index_t kBlockSize = 256;
using BlockTile = ck_tile::sequence<256, 128, 32>;
using WaveTile = ck_tile::sequence<16, 16, 16>;
@@ -77,11 +85,13 @@ int main()
ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) *
ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N);
std::cout << "kGridSize: " << kGridSize << std::endl;
std::cout << "Total number of thread blocks: " << kGridSize << std::endl;
constexpr ck_tile::index_t kBlockPerCU = 1; // 1 block per CU
std::cout << "kBlockSize: " << kBlockSize << std::endl;
std::cout << "kBlockPerCU: " << kBlockPerCU << std::endl;
// Block size is now derived from the shape configuration
constexpr ck_tile::index_t kBlockSize = PracticeGemmShape::kBlockSize;
std::cout << "Number of threads per block: " << kBlockSize << std::endl;
std::cout << "Number of blocks per compute unit: " << kBlockPerCU << std::endl;
using gemm_kernel =
ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;

View File

@@ -24,6 +24,10 @@ struct PracticeGemmShape
static constexpr index_t WaveTile_N = WaveTile::at(number<1>{});
static constexpr index_t WaveTile_K = WaveTile::at(number<2>{});
// Thread block configuration
static constexpr index_t kWarpSize = 64; // AMD GPU warp size (also called wavefront)
static constexpr index_t kBlockSize = 256; // Total threads per block (4 warps × 64 threads)
CK_TILE_HOST static std::string GetName()
{
// clang-format off
@@ -40,7 +44,8 @@ struct PracticeGemmKernel
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
static constexpr index_t kBlockSize = 256;
// Derive block size from the shape configuration
static constexpr index_t kBlockSize = Problem::Shape::kBlockSize;
CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a,
const typename Problem::BDataType* p_b,