mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Merge commit '6d25525adc2344d5b62b12b9ffddee50f89cd0ff' into develop
This commit is contained in:
3
.github/scripts/therock_configure_ci.py
vendored
3
.github/scripts/therock_configure_ci.py
vendored
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import fnmatch
|
||||
import json
|
||||
import os
|
||||
|
||||
@@ -20,12 +20,12 @@ repos:
|
||||
)$
|
||||
- repo: local
|
||||
hooks:
|
||||
# - id: copyright-year-checker
|
||||
# name: copyright-year-checker
|
||||
# entry: script/check_copyright_year.sh
|
||||
# verbose: false
|
||||
# language: script
|
||||
# types: [c++]
|
||||
- id: copyright-header-checker
|
||||
name: Check copyright headers
|
||||
entry: script/check_copyright_year.sh
|
||||
verbose: false
|
||||
language: script
|
||||
types_or: [c++, python, shell, cmake]
|
||||
- id: remove-exec-bit
|
||||
name: Remove executable bit from non-executable files
|
||||
entry: script/remove_exec_bit.sh
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host/arg_parser.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from datetime import datetime
|
||||
import pathlib
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
@@ -13,8 +12,8 @@ OPS = "ops"
|
||||
OPS_COMMON = "common" # common header will be duplicated into ops/* other module
|
||||
|
||||
IGNORED_DIRS = ["utility", "ref"]
|
||||
HEADER_COMMON = f"""// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
HEADER_COMMON = """// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -2,18 +2,70 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# This script checks if files have the correct copyright header template.
|
||||
# It supports .hpp, .cpp, .inc, .py, .sh, and .cmake files.
|
||||
#
|
||||
# Usage: ./check_copyright_year.sh <file1> <file2> ...
|
||||
|
||||
current_year=$(date +%Y)
|
||||
exit_code=0
|
||||
|
||||
for file in $@; do
|
||||
if grep -q "Copyright (c)" $file
|
||||
then
|
||||
if ! grep -q "Copyright (c).*$current_year" $file
|
||||
then
|
||||
echo "ERROR: File $file has a copyright notice without the current year ($current_year)."
|
||||
exit_code=1
|
||||
fi
|
||||
# Expected copyright header lines (without comment characters)
|
||||
COPYRIGHT_LINE="Copyright (c) Advanced Micro Devices, Inc., or its affiliates."
|
||||
SPDX_LINE="SPDX-License-Identifier: MIT"
|
||||
|
||||
check_file() {
|
||||
local file=$1
|
||||
local basename="${file##*/}"
|
||||
local ext="${file##*.}"
|
||||
local comment_char
|
||||
|
||||
# Determine comment character based on filename or extension
|
||||
if [[ "$basename" == "CMakeLists.txt" ]]; then
|
||||
comment_char="#"
|
||||
else
|
||||
case "$ext" in
|
||||
cpp|hpp|inc)
|
||||
comment_char="//"
|
||||
;;
|
||||
py|sh|cmake)
|
||||
comment_char="#"
|
||||
;;
|
||||
*)
|
||||
# Skip files with unsupported extensions
|
||||
return 0
|
||||
;;
|
||||
esac
|
||||
fi
|
||||
|
||||
# Build expected header patterns
|
||||
expected_copyright="$comment_char $COPYRIGHT_LINE"
|
||||
expected_spdx="$comment_char $SPDX_LINE"
|
||||
|
||||
# Check if file contains both required lines
|
||||
if ! grep -qF "$expected_copyright" "$file"; then
|
||||
echo "ERROR: File $file is missing the correct copyright header line."
|
||||
echo " Expected: $expected_copyright"
|
||||
return 1
|
||||
fi
|
||||
|
||||
if ! grep -qF "$expected_spdx" "$file"; then
|
||||
echo "ERROR: File $file is missing the correct SPDX license identifier line."
|
||||
echo " Expected: $expected_spdx"
|
||||
return 1
|
||||
fi
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
# Process each file provided as argument
|
||||
for file in "$@"; do
|
||||
# Skip if file doesn't exist or is a directory
|
||||
if [[ ! -f "$file" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
if ! check_file "$file"; then
|
||||
exit_code=1
|
||||
fi
|
||||
done
|
||||
|
||||
|
||||
295
script/update_amd_copyright_headers.py
Normal file
295
script/update_amd_copyright_headers.py
Normal file
@@ -0,0 +1,295 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Purpose:
|
||||
Normalize and enforce AMD two-line copyright + SPDX headers across files.
|
||||
|
||||
Target files:
|
||||
- C/C++-style: .cpp, .hpp, .inc -> uses "//" comment style
|
||||
- Hash-style: .py, .cmake, .sh, and CMakeLists.txt -> uses "#" style
|
||||
|
||||
Header formats inserted (top of file, followed by exactly one blank line):
|
||||
C/C++ :
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
<blank>
|
||||
Hash :
|
||||
<blank>
|
||||
|
||||
Shebang special case (hash-style only):
|
||||
- If line 1 starts with "#!", keep shebang, then a blank line, then the
|
||||
two hash-style header lines, then a blank line.
|
||||
|
||||
Removal rules:
|
||||
- Remove any comment lines (anywhere in file) containing the keywords
|
||||
"copyright" or "spdx" (case-insensitive). Blank lines are preserved.
|
||||
- Remove long-form MIT license block comment when:
|
||||
a) The file starts with the block (absolute top), OR
|
||||
b) The block appears immediately after the AMD header position
|
||||
(i.e., when remainder at insertion point begins with "/*" and
|
||||
the first content line is "* The MIT License (MIT)").
|
||||
|
||||
Blank-line normalization:
|
||||
- Enforce exactly ONE blank line immediately after the AMD header.
|
||||
(Drop only the leading blank lines at the insertion point before
|
||||
re-inserting the header.)
|
||||
- Do not change blank lines between other non-copyright comments.
|
||||
|
||||
Preservation:
|
||||
- Preserve original newline style: CRLF (\r\n) vs LF (\n).
|
||||
- Preserve UTF-8 BOM if present.
|
||||
- Do not modify non-comment code lines.
|
||||
|
||||
Idempotency:
|
||||
- Running this script multiple times does not further modify files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
AMD_CPP_HEADER_TEXT = [
|
||||
"// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.",
|
||||
"// SPDX-License-Identifier: MIT",
|
||||
]
|
||||
AMD_HASH_HEADER_TEXT = [
|
||||
"# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.",
|
||||
"# SPDX-License-Identifier: MIT",
|
||||
]
|
||||
|
||||
CPP_EXTS = {".cpp", ".hpp", ".inc"}
|
||||
HASH_EXTS = {".py", ".cmake", ".sh"}
|
||||
|
||||
# --- Encoding helpers -------------------------------------------------------
|
||||
|
||||
|
||||
def has_bom(raw: bytes) -> bool:
|
||||
return raw.startswith(b"\xef\xbb\xbf")
|
||||
|
||||
|
||||
def decode_text(raw: bytes) -> str:
|
||||
return raw.decode("utf-8-sig", errors="replace")
|
||||
|
||||
|
||||
def encode_text(text: str, bom: bool) -> bytes:
|
||||
data = text.encode("utf-8")
|
||||
return (b"\xef\xbb\xbf" + data) if bom else data
|
||||
|
||||
|
||||
# --- Newline detection ------------------------------------------------------
|
||||
|
||||
|
||||
def detect_newline_sequence(raw: bytes) -> str:
|
||||
if b"\r\n" in raw:
|
||||
return "\r\n"
|
||||
elif b"\n" in raw:
|
||||
return "\n"
|
||||
else:
|
||||
return "\n"
|
||||
|
||||
|
||||
# --- Utilities --------------------------------------------------------------
|
||||
|
||||
|
||||
def is_comment_line(line: str, style: str) -> bool:
|
||||
stripped = line.lstrip()
|
||||
if style == "cpp":
|
||||
return (
|
||||
stripped.startswith("//")
|
||||
or stripped.startswith("/*")
|
||||
or stripped.startswith("*")
|
||||
or stripped.startswith("*/")
|
||||
)
|
||||
elif style == "hash":
|
||||
return stripped.startswith("#")
|
||||
return False
|
||||
|
||||
|
||||
def has_keywords(line: str) -> bool:
|
||||
lower_line = line.lower()
|
||||
return ("copyright" in lower_line) or ("spdx" in lower_line)
|
||||
|
||||
|
||||
# --- MIT License banner detection ------------------------------
|
||||
MIT_C_FIRST_LINE_RE = re.compile(r"^\s*\*\s*The MIT License \(MIT\)")
|
||||
MIT_HASH_FIRST_LINE_RE = re.compile(r"^\s*#\s*The MIT License \(MIT\)")
|
||||
|
||||
|
||||
def remove_top_mit_block(lines: List[str]) -> Tuple[List[str], bool]:
|
||||
"""
|
||||
Unified MIT banner removal at the top of 'lines'.
|
||||
Supports:
|
||||
- C-style block starting with '/*' and ending with '*/'; removes only if
|
||||
a line within the block matches MIT_C_FIRST_LINE_RE.
|
||||
- Hash-style banner: contiguous top run of lines starting with '#';
|
||||
removes only if any line in that run matches MIT_HASH_FIRST_LINE_RE.
|
||||
Returns (new_lines, removed_flag). Preserves EOLs.
|
||||
"""
|
||||
if not lines:
|
||||
return lines, False
|
||||
|
||||
first = lines[0].lstrip()
|
||||
|
||||
# C-style block
|
||||
if first.startswith("/*"):
|
||||
end_idx, saw_mit = None, False
|
||||
for i, line in enumerate(lines[1:], 1):
|
||||
if not saw_mit and MIT_C_FIRST_LINE_RE.match(line):
|
||||
saw_mit = True
|
||||
s = line.lstrip()
|
||||
if s.startswith("*/") or s.rstrip().endswith("*/"):
|
||||
end_idx = i + 1
|
||||
break
|
||||
if end_idx is not None and saw_mit:
|
||||
return lines[end_idx:], True
|
||||
return lines, False
|
||||
|
||||
# Hash-style contiguous banner
|
||||
if first.startswith("#"):
|
||||
end_idx, saw_mit = 0, False
|
||||
for i, line in enumerate(lines):
|
||||
if line.lstrip().startswith("#"):
|
||||
if not saw_mit and MIT_HASH_FIRST_LINE_RE.match(line):
|
||||
saw_mit = True
|
||||
end_idx = i + 1
|
||||
else:
|
||||
break
|
||||
if saw_mit:
|
||||
return lines[end_idx:], True
|
||||
return lines, False
|
||||
|
||||
return lines, False
|
||||
|
||||
|
||||
# --- Removal + normalization helpers ---------------------------------------
|
||||
|
||||
|
||||
def remove_keyword_comment_lines_globally(lines: List[str], style: str) -> List[str]:
|
||||
"""Remove comment lines containing keywords anywhere in the file.
|
||||
**Do not** remove blank lines; preserve all other lines as-is."""
|
||||
out: List[str] = []
|
||||
for line in lines:
|
||||
if is_comment_line(line, style) and has_keywords(line):
|
||||
continue
|
||||
out.append(line)
|
||||
return out
|
||||
|
||||
|
||||
def drop_leading_blank_lines(lines: List[str]) -> List[str]:
|
||||
"""Drop only the leading blank lines at the start of the given list."""
|
||||
i = 0
|
||||
while i < len(lines) and lines[i].strip() == "":
|
||||
i += 1
|
||||
return lines[i:]
|
||||
|
||||
|
||||
# --- Header builder ---------------------------------------------------------
|
||||
|
||||
|
||||
def build_header_lines(style: str, nl: str) -> List[str]:
|
||||
base = AMD_CPP_HEADER_TEXT if style == "cpp" else AMD_HASH_HEADER_TEXT
|
||||
return [base[0] + nl, base[1] + nl, nl] # header + exactly one blank
|
||||
|
||||
|
||||
# --- Main transforms --------------------------------------------------------
|
||||
|
||||
|
||||
def process_cpp(text: str, nl: str) -> str:
|
||||
lines = text.splitlines(True)
|
||||
|
||||
# Remove MIT block if it is at the *absolute* top
|
||||
lines, _ = remove_top_mit_block(lines)
|
||||
|
||||
# Remove keyworded comment lines globally (blank lines preserved)
|
||||
lines = remove_keyword_comment_lines_globally(lines, style="cpp")
|
||||
|
||||
# Normalize insertion point and remove MIT block if it appears *after header*
|
||||
lines = drop_leading_blank_lines(lines)
|
||||
lines, _ = remove_top_mit_block(lines)
|
||||
|
||||
# Prepend AMD header (guarantee exactly one blank after)
|
||||
return "".join(build_header_lines("cpp", nl) + lines)
|
||||
|
||||
|
||||
def process_hash(text: str, nl: str) -> str:
|
||||
lines = text.splitlines(True)
|
||||
if not lines:
|
||||
return "".join(build_header_lines("hash", nl))
|
||||
|
||||
shebang = lines[0].startswith("#!")
|
||||
|
||||
if shebang:
|
||||
remainder = remove_keyword_comment_lines_globally(lines[1:], style="hash")
|
||||
remainder = drop_leading_blank_lines(remainder)
|
||||
remainder, _ = remove_top_mit_block(remainder) # remove MIT block after header
|
||||
new_top = [lines[0], nl] + build_header_lines("hash", nl)
|
||||
return "".join(new_top + remainder)
|
||||
else:
|
||||
remainder = remove_keyword_comment_lines_globally(lines, style="hash")
|
||||
remainder = drop_leading_blank_lines(remainder)
|
||||
remainder, _ = remove_top_mit_block(remainder) # remove MIT block after header
|
||||
return "".join(build_header_lines("hash", nl) + remainder)
|
||||
|
||||
|
||||
# --- File processing & CLI --------------------------------------------------
|
||||
|
||||
|
||||
def process_file(path: Path) -> bool:
|
||||
name = path.name
|
||||
suffix = path.suffix.lower()
|
||||
if suffix in CPP_EXTS:
|
||||
style = "cpp"
|
||||
elif suffix in HASH_EXTS or name == "CMakeLists.txt":
|
||||
style = "hash"
|
||||
else:
|
||||
return False
|
||||
|
||||
raw = path.read_bytes()
|
||||
bom = has_bom(raw)
|
||||
nl = detect_newline_sequence(raw)
|
||||
text = decode_text(raw)
|
||||
|
||||
updated = process_cpp(text, nl) if style == "cpp" else process_hash(text, nl)
|
||||
if updated != text:
|
||||
path.write_bytes(encode_text(updated, bom))
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def main(argv: List[str]) -> int:
|
||||
if len(argv) < 2:
|
||||
print(__doc__)
|
||||
return 2
|
||||
changed = 0
|
||||
skipped = 0
|
||||
errors: List[str] = []
|
||||
for arg in argv[1:]:
|
||||
p = Path(arg)
|
||||
try:
|
||||
if not p.exists():
|
||||
errors.append(f"Not found: {p}")
|
||||
continue
|
||||
if p.is_dir():
|
||||
errors.append(f"Is a directory (pass specific files): {p}")
|
||||
continue
|
||||
if process_file(p):
|
||||
changed += 1
|
||||
print(f"Updated: {p}")
|
||||
else:
|
||||
skipped += 1
|
||||
print(f"Skipped (no change needed or unsupported type): {p}")
|
||||
except Exception as e:
|
||||
errors.append(f"Error processing {p}: {e}")
|
||||
print(f"\nSummary: {changed} updated, {skipped} skipped, {len(errors)} errors")
|
||||
for msg in errors:
|
||||
print(f" - {msg}")
|
||||
return 0 if not errors else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main(sys.argv))
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c), Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
set(GEMM_STREAMK_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)")
|
||||
set(GEMM_STREAMK_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)")
|
||||
set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
|
||||
import os
|
||||
import json
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
"""
|
||||
Validation utilities for GEMM kernel generation.
|
||||
|
||||
312
tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md
Normal file
312
tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md
Normal file
@@ -0,0 +1,312 @@
|
||||
# Tile Distribution: Mapping Threads to Data
|
||||
|
||||
## Overview
|
||||
|
||||
**Tile Distribution** describes how each thread in a thread block maps to elements of a block tile. It defines the hierarchical pattern of data distribution across threads, warps, and thread blocks.
|
||||
|
||||
## The Problem
|
||||
|
||||
Given a block tile of size `MPerBlock × KPerBlock` (e.g., 256×32), we need to determine:
|
||||
- Which threads load which elements.
|
||||
- How the threads are organized into warps.
|
||||
- The number of times each warp repeats its pattern.
|
||||
- The number of elements each thread can load in a single vector instruction.
|
||||
|
||||
---
|
||||
|
||||
## Bottom-Up Construction Approach
|
||||
|
||||
### Step 1: Determine K Dimension Layout
|
||||
|
||||
**Start with the innermost dimension (K) for memory coalescing:**
|
||||
|
||||
```cpp
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType); // Elements per thread (vector load)
|
||||
constexpr index_t K0 = kKPerBlock / K1; // Threads needed in K dimension
|
||||
```
|
||||
|
||||
**Example (with fp16):**
|
||||
- `K1 = 16 / 2 = 8` → Each thread loads 8 fp16 elements in a single vector instruction
|
||||
- `kKPerBlock = 32`
|
||||
- `K0 = 32 / 8 = 4` → We need 4 threads along K to cover the entire K dimension
|
||||
|
||||
**Visual:**
|
||||
```
|
||||
K dimension (32 elements):
|
||||
Thread 0: [0-7] Thread 1: [8-15] Thread 2: [16-23] Thread 3: [24-31]
|
||||
K1=8 K1=8 K1=8 K1=8
|
||||
├──────────────────────────────────────────────────────────────┤
|
||||
K0=4 threads
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Step 2: Determine M Dimension Layout
|
||||
|
||||
**Now partition the M dimension hierarchically:**
|
||||
|
||||
#### Level 1: Threads per Warp in M (M2)
|
||||
|
||||
```cpp
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
```
|
||||
|
||||
- Warp size = 64 threads
|
||||
- K dimension already uses `K0 = 4` threads per row
|
||||
- `M2 = 64 / 4 = 16` → Each warp can have 16 threads in M dimension
|
||||
|
||||
**Visual (Single Warp):**
|
||||
```
|
||||
K dimension (4 threads)
|
||||
┌─────┬─────┬─────┬─────┐
|
||||
0 │ T0 │ T1 │ T2 │ T3 │
|
||||
1 │ T4 │ T5 │ T6 │ T7 │
|
||||
2 │ T8 │ T9 │ T10 │ T11 │
|
||||
M 3 │ T12 │ T13 │ T14 │ T15 │ ← 16 rows
|
||||
...│ ... │ ... │ ... │ ... │ (M2=16)
|
||||
15 │ T60 │ T61 │ T62 │ T63 │
|
||||
└─────┴─────┴─────┴─────┘
|
||||
One Warp = 64 threads
|
||||
```
|
||||
|
||||
#### Level 2: Warps per Block (M1)
|
||||
|
||||
```cpp
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
```
|
||||
|
||||
- `kBlockSize = 256` threads per block
|
||||
- `M1 = 256 / 64 = 4` → We have 4 warps per block
|
||||
|
||||
**Visual (4 Warps):**
|
||||
```
|
||||
Warp 0 (rows 0-15)
|
||||
Warp 1 (rows 16-31)
|
||||
Warp 2 (rows 32-47)
|
||||
Warp 3 (rows 48-63)
|
||||
↑
|
||||
M1 = 4 warps cover 64 rows total
|
||||
```
|
||||
|
||||
#### Level 3: Repetitions (M0)
|
||||
|
||||
```cpp
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
```
|
||||
|
||||
- `kMPerBlock = 256` rows to cover
|
||||
- `M2 * M1 = 16 * 4 = 64` rows covered by all warps
|
||||
- `M0 = 256 / 64 = 4` → Each warp must repeat its pattern 4 times
|
||||
|
||||
**Visual (Complete Block):**
|
||||
```
|
||||
┌──────────────┐
|
||||
│ Iteration 0 │ ← Warp 0: rows 0-15, Warp 1: rows 16-31, ...
|
||||
│ (rows 0-63) │
|
||||
├──────────────┤
|
||||
│ Iteration 1 │ ← Warp 0: rows 64-79, Warp 1: rows 80-95, ...
|
||||
│ (rows 64-127)│
|
||||
├──────────────┤
|
||||
│ Iteration 2 │ ← Warp 0: rows 128-143, Warp 1: rows 144-159, ...
|
||||
│(rows 128-191)│
|
||||
├──────────────┤
|
||||
│ Iteration 3 │ ← Warp 0: rows 192-207, Warp 1: rows 208-223, ...
|
||||
│(rows 192-255)│
|
||||
└──────────────┘
|
||||
M0 = 4 iterations
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## The Tile Distribution Encoding
|
||||
|
||||
Now we can construct the distribution:
|
||||
|
||||
```cpp
|
||||
tile_distribution_encoding<
|
||||
sequence<1>, // [1] Replication
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, // [2] Hierarchy
|
||||
tuple<sequence<1>, sequence<1, 2>>, // [3] Parallelism:
|
||||
tuple<sequence<1>, sequence<2, 0>>, // [3] Parallelism
|
||||
sequence<1, 2>, // [4] Yield
|
||||
sequence<0, 1> // [4] Yield
|
||||
>
|
||||
```
|
||||
|
||||
### [1] Replication: `sequence<1>`
|
||||
|
||||
Defines how many times warp patterns are replicated:
|
||||
- `1` = Each warp has a unique pattern (no replication)
|
||||
- `2` = Warp 0 and Warp 1 do the same thing, Warp 2 and Warp 3 do the same thing
|
||||
- `4` = All warps do the same thing
|
||||
|
||||
In our case: `1` means no replication (each warp is independent).
|
||||
|
||||
---
|
||||
|
||||
### [2] Hierarchy: The Multi-Level Structure
|
||||
|
||||
```cpp
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>
|
||||
└───────┬──────────┘ └──────┬────────┘
|
||||
M dimension K dimension
|
||||
```
|
||||
|
||||
**Concrete values:**
|
||||
- M hierarchy: `sequence<4, 4, 16>` = (4 repetitions, 4 warps, 16 threads/warp)
|
||||
- K hierarchy: `sequence<4, 8>` = (4 threads, 8 elements/thread)
|
||||
|
||||
---
|
||||
|
||||
### [3] Parallelism: Addressing the Hierarchy
|
||||
|
||||
**The key insight:** Read the tuples **vertically** to understand indexing!
|
||||
|
||||
```cpp
|
||||
tuple<sequence<1>, sequence<1, 2>>
|
||||
tuple<sequence<1>, sequence<2, 0>>
|
||||
```
|
||||
|
||||
#### Reading Pattern
|
||||
|
||||
**Column 1 (Dimension 0 = M):**
|
||||
```
|
||||
sequence<1> → Address hierarchy index 1,1 → M1 (warps/block in M dimension)
|
||||
sequence<1>
|
||||
```
|
||||
|
||||
**Column 2 (Dimension 1 = K):**
|
||||
```
|
||||
sequence<1, 2>
|
||||
sequence<2, 0>
|
||||
```
|
||||
[1,2] M2=threads/warp in M dimension
|
||||
[2,0] K0=threads/warp in K dimension
|
||||
|
||||
---
|
||||
|
||||
### [4] Yield Sequences: Output Ordering
|
||||
|
||||
```cpp
|
||||
sequence<1, 2>
|
||||
sequence<0, 1>
|
||||
|
||||
[1,0] means M0=repetitions/warp in M dimension
|
||||
[2,1] means K1=elements/thread in K dimension
|
||||
```
|
||||
---
|
||||
|
||||
## Complete Example: Thread 25 in Warp 0
|
||||
|
||||
Let's trace where **Thread 25** in **Warp 0** reads data:
|
||||
|
||||
### Thread Coordinates
|
||||
- Thread ID in warp: 25
|
||||
- Warp ID in block: 0
|
||||
|
||||
### Decompose Thread 25
|
||||
```
|
||||
Thread 25 in a 2D layout (M2=16, K0=4):
|
||||
Row index: 25 / 4 = 6
|
||||
Col index: 25 % 4 = 1
|
||||
```
|
||||
|
||||
### M Position (Row)
|
||||
```
|
||||
M0 iteration: 0 (first iteration)
|
||||
M1 warp: 0 (warp 0)
|
||||
M2 thread: 6 (6th row in warp)
|
||||
→ M position = 0*64 + 0*16 + 6 = 6
|
||||
```
|
||||
|
||||
### K Position (Column)
|
||||
```
|
||||
K0 thread: 1 (column group 1)
|
||||
K1 elements: 8 (will load 8 consecutive elements)
|
||||
→ K position = 1*8 + [0-7] = elements 8-15
|
||||
```
|
||||
|
||||
**Result:** Thread 25 in Warp 0 loads **row 6, columns 8-15** (8 elements).
|
||||
|
||||
---
|
||||
|
||||
## Why This Matters
|
||||
|
||||
### 1. **Memory Coalescing**
|
||||
- Consecutive threads access consecutive memory → efficient global memory access
|
||||
- K dimension uses K1=8 for vectorized loads
|
||||
|
||||
### 2. **Warp Efficiency**
|
||||
- All 64 threads in a warp are utilized
|
||||
- Natural 2D layout: 16 threads (M) × 4 threads (K) = 64 threads
|
||||
|
||||
### 3. **Scalability**
|
||||
- M0 repetitions allow handling larger tiles
|
||||
- Same pattern scales to different sizes
|
||||
|
||||
### 4. **Register Allocation**
|
||||
- Each thread knows exactly how many elements it will hold
|
||||
- Compiler can allocate registers optimally
|
||||
|
||||
---
|
||||
|
||||
## Summary Table
|
||||
|
||||
| Parameter | Value | Meaning |
|
||||
|-----------|-------|---------|
|
||||
| **K1** | 8 | Elements per thread (vector width) |
|
||||
| **K0** | 4 | Threads along K per row |
|
||||
| **M2** | 16 | Threads along M per warp |
|
||||
| **M1** | 4 | Warps per block |
|
||||
| **M0** | 4 | Repetitions of warp pattern |
|
||||
| **Total Threads** | 256 | M0×M1×M2 = 4×4×16 (actually M1×64) |
|
||||
| **Total Elements** | 8192 | 256×32 (MPerBlock × KPerBlock) |
|
||||
| **Elements/Thread** | 32 | M0×K1 = 4×8 |
|
||||
|
||||
---
|
||||
|
||||
## Visualization: Complete Thread Block
|
||||
|
||||
```
|
||||
Block Tile: 256×32
|
||||
|
||||
K dimension (32 elements)
|
||||
├─────────────────────┤
|
||||
0 ┌──────────────────────┐ ┐
|
||||
16 │ Warp 0 │ │
|
||||
32 │ Warp 1 │ │ Iteration 0
|
||||
48 │ Warp 2 │ │ (M0=0)
|
||||
64 │ Warp 3 │ ┘
|
||||
80 ├──────────────────────┤ ┐
|
||||
96 │ Warp 0 │ │
|
||||
112 │ Warp 1 │ │ Iteration 1
|
||||
128 │ Warp 2 │ │ (M0=1)
|
||||
144 │ Warp 3 │ ┘
|
||||
160 ├──────────────────────┤ ┐
|
||||
176 │ Warp 0 │ │
|
||||
192 │ Warp 1 │ │ Iteration 2
|
||||
208 │ Warp 2 │ │ (M0=2)
|
||||
224 │ Warp 3 │ ┘
|
||||
240 ├──────────────────────┤ ┐
|
||||
256 │ Warp 0 │ │
|
||||
│ Warp 1 │ │ Iteration 3
|
||||
│ Warp 2 │ │ (M0=3)
|
||||
│ Warp 3 │ ┘
|
||||
└──────────────────────┘
|
||||
|
||||
Each warp processes 16 rows × 32 cols = 512 elements
|
||||
Each iteration processes 64 rows × 32 cols = 2048 elements
|
||||
Total: 4 iterations × 2048 = 8192 elements ✓
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Takeaways
|
||||
|
||||
1. **Bottom-up construction**: Start from vector width (K1), build up through thread/warp/block hierarchy
|
||||
2. **Vertical reading**: The repeat and elements tuples are read column-wise to address hierarchy levels
|
||||
3. **Replication controls redundancy**: How many warps share the same pattern
|
||||
4. **Hierarchy encodes structure**: The multi-level sequence defines the complete mapping
|
||||
|
||||
This design enables CK to achieve maximum GPU performance through optimal thread-to-data mapping!
|
||||
|
||||
@@ -98,12 +98,12 @@ struct PracticeGemmBlockPolicy
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
tile_distribution_encoding<sequence<1>, // replication
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, // hierarchy
|
||||
tuple<sequence<1>, sequence<1, 2>>, // parallelism
|
||||
tuple<sequence<1>, sequence<2, 0>>, // paralleism
|
||||
sequence<1, 2>, // yield
|
||||
sequence<0, 1>>{}); // yield
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -24,7 +24,7 @@ struct PracticeGemmHostPipeline
|
||||
template <typename ADRAMTensorView, typename BDRAMTensorView, typename CDRAMTensorView>
|
||||
CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram,
|
||||
const BDRAMTensorView& b_dram,
|
||||
CDRAMTensorView& c_dram_ref) const
|
||||
CDRAMTensorView& c_dram) const
|
||||
{
|
||||
|
||||
// Size of the entire problem
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "practice_gemm.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
|
||||
int main()
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// TODO: GemmTypeConfig
|
||||
using ADataType = ck_tile::half_t;
|
||||
@@ -14,11 +14,22 @@ int main()
|
||||
using CDataType = float;
|
||||
using AccDataType = float;
|
||||
|
||||
// ArgParser
|
||||
ck_tile::index_t M = 512;
|
||||
ck_tile::index_t N = 256;
|
||||
ck_tile::index_t K = 64;
|
||||
ck_tile::index_t verification = 1;
|
||||
// Setup simple argument parser for M, N, K
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "512", "m dimension")
|
||||
.insert("n", "256", "n dimension")
|
||||
.insert("k", "64", "k dimension")
|
||||
.insert("v", "1", "verification: 0=off, 1=on");
|
||||
|
||||
auto result = arg_parser.parse(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
// Get problem dimensions from command line
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
ck_tile::index_t verification = arg_parser.get_int("v");
|
||||
|
||||
ck_tile::index_t stride_a = K;
|
||||
ck_tile::index_t stride_b = K;
|
||||
@@ -61,9 +72,6 @@ int main()
|
||||
ck_tile::DeviceMem c_device(c_host);
|
||||
|
||||
// TODO: BlockTileConfig
|
||||
// constexpr ck_tile::index_t warpSize = 64;
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
|
||||
using BlockTile = ck_tile::sequence<256, 128, 32>;
|
||||
using WaveTile = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
@@ -77,11 +85,13 @@ int main()
|
||||
ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) *
|
||||
ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N);
|
||||
|
||||
std::cout << "kGridSize: " << kGridSize << std::endl;
|
||||
std::cout << "Total number of thread blocks: " << kGridSize << std::endl;
|
||||
constexpr ck_tile::index_t kBlockPerCU = 1; // 1 block per CU
|
||||
|
||||
std::cout << "kBlockSize: " << kBlockSize << std::endl;
|
||||
std::cout << "kBlockPerCU: " << kBlockPerCU << std::endl;
|
||||
// Block size is now derived from the shape configuration
|
||||
constexpr ck_tile::index_t kBlockSize = PracticeGemmShape::kBlockSize;
|
||||
std::cout << "Number of threads per block: " << kBlockSize << std::endl;
|
||||
std::cout << "Number of blocks per compute unit: " << kBlockPerCU << std::endl;
|
||||
|
||||
using gemm_kernel =
|
||||
ck_tile::PracticeGemmKernel<PracticeGemmHostProblem, PracticeGemmHostPolicy>;
|
||||
|
||||
@@ -24,6 +24,10 @@ struct PracticeGemmShape
|
||||
static constexpr index_t WaveTile_N = WaveTile::at(number<1>{});
|
||||
static constexpr index_t WaveTile_K = WaveTile::at(number<2>{});
|
||||
|
||||
// Thread block configuration
|
||||
static constexpr index_t kWarpSize = 64; // AMD GPU warp size (also called wavefront)
|
||||
static constexpr index_t kBlockSize = 256; // Total threads per block (4 warps × 64 threads)
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
@@ -40,7 +44,8 @@ struct PracticeGemmKernel
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
// Derive block size from the shape configuration
|
||||
static constexpr index_t kBlockSize = Problem::Shape::kBlockSize;
|
||||
|
||||
CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a,
|
||||
const typename Problem::BDataType* p_b,
|
||||
|
||||
Reference in New Issue
Block a user