diff --git a/CHANGELOG.md b/CHANGELOG.md index 8fd3c36255..b07e322fe1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. +* Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM. * Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM * Added a compute async pipeline in the CK TILE universal GEMM on gfx950 * Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM. diff --git a/CMakeLists.txt b/CMakeLists.txt index 45db703b82..9d0c4d79f9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + cmake_minimum_required(VERSION 3.14) if(POLICY CMP0140) # policies CMP0140 not known to CMake until 3.25 @@ -39,10 +42,12 @@ option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON) option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF) option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) +option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF) +option(FORCE_DISABLE_WMMA "Skip compiling WMMA specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF) if(CK_EXPERIMENTAL_BUILDER) add_definitions(-DCK_EXPERIMENTAL_BUILDER) - include_directories(${PROJECT_SOURCE_DIR}/experimental/builder/include) + include_directories(${PROJECT_SOURCE_DIR}/experimental/builder/include) endif() # Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8" @@ -229,12 +234,12 @@ message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}" # Cache SUPPORTED_GPU_TARGETS for debug set(SUPPORTED_GPU_TARGETS "${SUPPORTED_GPU_TARGETS}" CACHE STRING "List of supported GPU targets") -if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") +if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12" AND NOT FORCE_DISABLE_XDL) message(STATUS "Enabling XDL instances") add_definitions(-DCK_USE_XDL) set(CK_USE_XDL "ON") endif() -if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") +if ((SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND NOT FORCE_DISABLE_XDL) message(STATUS "Enabling XDL FP8 gemms on native architectures") add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") @@ -247,7 +252,7 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx10") add_definitions(-DCK_GFX1030_SUPPORT) endif() -if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") +if ((SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") AND NOT FORCE_DISABLE_WMMA) message(STATUS "Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") @@ -257,7 +262,7 @@ endif() # define the macro with the current value (0 or 1) add_definitions(-DCK_TILE_USE_WMMA=${CK_TILE_USE_WMMA}) -if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") +if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" AND NOT FORCE_DISABLE_WMMA) message(STATUS "Enabling WMMA FP8 gemms on native architectures") add_definitions(-DCK_USE_WMMA_FP8) set(CK_USE_WMMA_FP8 "ON") @@ -739,6 +744,13 @@ rocm_install(FILES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck/ ) +if(CK_EXPERIMENTAL_BUILDER) + rocm_install(DIRECTORY + ${PROJECT_SOURCE_DIR}/experimental/builder/include/ck_tile/builder + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile + ) +endif() + set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") set(CPACK_RPM_PACKAGE_LICENSE "MIT") diff --git a/Jenkinsfile b/Jenkinsfile index f3e690edd7..a2e5b3d20b 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -72,6 +72,129 @@ def sendFailureNotifications() { } } +def generateAndArchiveBuildTraceVisualization() { + try { + def buildTraceFileName = "ck_build_trace.json"; + + // Attempt to download the build trace file to check if it exists + def traceFileExists = false + try { + copyArtifacts( + projectName: env.JOB_NAME, + selector: specific(env.BUILD_NUMBER), + filter: buildTraceFileName + ) + traceFileExists = fileExists(buildTraceFileName) + } catch (Exception e) { + echo "Could not copy artifacts: ${e.getMessage()}" + traceFileExists = false + } + + sh """ + echo "post download:" + ls -la + """ + + if (traceFileExists) { + // Move the build trace file to a temporary location to preserve it during checkout + sh """ + mkdir -p /tmp/jenkins_artifacts + cp ${buildTraceFileName} /tmp/jenkins_artifacts/${buildTraceFileName} + ls -la /tmp/jenkins_artifacts/ + """ + } else { + echo "Build trace archive not found" + return + } + + // Checkout source code to get required files + checkout scm + + // Restore the build trace file after checkout + sh """ + ls -la + cp /tmp/jenkins_artifacts/${buildTraceFileName} ${buildTraceFileName} + ls -la ${buildTraceFileName} + """ + + // Pull image + def image = "ghcr.io/puppeteer/puppeteer:24.30.0" + echo "Pulling image: ${image}" + def retimage = docker.image("${image}") + retimage.pull() + + // Create a temporary workspace + sh """#!/bin/bash + ls -la + mkdir -p workspace + cp ./script/infra_helper/capture_build_trace.js ./workspace + cp ${buildTraceFileName} ./workspace/${buildTraceFileName} + chmod 777 ./workspace + ls -la ./workspace + """ + + // Run container to get snapshot + def dockerOpts = "--cap-add=SYS_ADMIN -v \"\$(pwd)/workspace:/workspace\" -e NODE_PATH=/home/pptruser/node_modules" + // Create unique image name by sanitizing job name + def sanitizedJobName = env.JOB_NAME.replaceAll(/[\/\\:*?"<>| ]/, '_') + def imageName = "perfetto_snapshot_${sanitizedJobName}_build_${env.BUILD_NUMBER}.png" + sh """ + docker run --rm ${dockerOpts} ${image} node /workspace/capture_build_trace.js + mv ./workspace/perfetto_snapshot_build.png ./workspace/${imageName} + """ + + // Archive the snapshot + sh """ + mv ./workspace/${imageName} ${imageName} + """ + archiveArtifacts "${imageName}" + + // Notify the channel + withCredentials([string(credentialsId: 'ck_ci_build_perf_webhook_url', variable: 'WEBHOOK_URL')]) { + sh ''' + # Create build trace filename with build number based on the original filename + BUILD_TRACE_WITH_NUMBER=$(echo "''' + buildTraceFileName + '''" | sed 's/.json/_''' + sanitizedJobName + '''_''' + env.BUILD_NUMBER + '''.json/') + + # Convert image to base64 + echo "Converting image to base64..." + IMAGE_BASE64=$(base64 -w 0 ''' + imageName + ''') + echo "Image base64 length: ${#IMAGE_BASE64}" + + # Convert build trace to base64 + echo "Converting build trace to base64..." + BUILD_TRACE_BASE64=$(base64 -w 0 ''' + buildTraceFileName + ''') + echo "Build trace base64 length: ${#BUILD_TRACE_BASE64}" + + # Create JSON payload with base64 data + echo "Creating JSON payload..." + { + printf '{\n' + printf ' "jobName": "%s",\n' "''' + env.JOB_NAME + '''" + printf ' "buildNumber": "%s",\n' "''' + env.BUILD_NUMBER + '''" + printf ' "jobUrl": "%s",\n' "''' + env.RUN_DISPLAY_URL + '''" + printf ' "imageName": "%s",\n' "''' + imageName + '''" + printf ' "imageData": "%s",\n' "$IMAGE_BASE64" + printf ' "buildTraceName": "%s",\n' "$BUILD_TRACE_WITH_NUMBER" + printf ' "buildTraceData": "%s"\n' "$BUILD_TRACE_BASE64" + printf '}\n' + } > webhook_payload.json + + echo "JSON payload created, size: $(wc -c < webhook_payload.json) bytes" + + curl -X POST "${WEBHOOK_URL}" \ + -H "Content-Type: application/json" \ + -d @webhook_payload.json + + # Clean up temporary file + rm -f webhook_payload.json + ''' + } + } catch (Exception e) { + echo "Throwing error exception while generating build trace visualization" + echo 'Exception occurred: ' + e.toString() + } +} + class Version { int major, minor, patch @Override @@ -1492,11 +1615,13 @@ pipeline { -D GPU_TARGETS="gfx90a" \ -D GEMM_DATATYPE="fp8;fp16" \ -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -D GEMM_STREAMK_DATATYPE="fp8;fp16" \ + -D GEMM_STREAMK_LAYOUT="rcr" \ -D GEMM_MULTI_D_DATATYPE="fp16" \ -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \ -D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \ - ninja -j64 benchmark_gemm_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \ + ninja -j64 benchmark_gemm_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \ python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ python3 ../tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """ @@ -1521,11 +1646,13 @@ pipeline { -D GPU_TARGETS="gfx942" \ -D GEMM_DATATYPE="fp8;fp16" \ -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -D GEMM_STREAMK_DATATYPE="fp8;fp16" \ + -D GEMM_STREAMK_LAYOUT="rcr" \ -D GEMM_MULTI_D_DATATYPE="fp16" \ -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8;bf16;bf8" \ -D GEMM_PRESHUFFLE_LAYOUT="rcr" .. && \ - ninja -j64 benchmark_gemm_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all && \ + ninja -j64 benchmark_gemm_all benchmark_gemm_preshuffle_all benchmark_gemm_multi_d_all benchmark_gemm_streamk_all && \ python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json && \ python3 ../tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" --warmup 5 --repeat 5 --verbose --json results.json """ @@ -1750,6 +1877,15 @@ pipeline { } } post { + always { + node(rocmnode("nogpu")) { + script { + // Simulate capture + generateAndArchiveBuildTraceVisualization() + } + cleanWs() + } + } success { script { // Report the parent stage build ck and run tests status diff --git a/client_example/01_gemm/CMakeLists.txt b/client_example/01_gemm/CMakeLists.txt index 6c4103cda8..de4c6040ca 100644 --- a/client_example/01_gemm/CMakeLists.txt +++ b/client_example/01_gemm/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_executable(client_gemm gemm.cpp) target_link_libraries(client_gemm PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt index 4ba86026b2..0b79d4a4e0 100644 --- a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt +++ b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9") add_custom_target(client_gemm_fastgelu_examples) diff --git a/client_example/03_gemm_layernorm/CMakeLists.txt b/client_example/03_gemm_layernorm/CMakeLists.txt index 8fedc84635..8d980dadbe 100644 --- a/client_example/03_gemm_layernorm/CMakeLists.txt +++ b/client_example/03_gemm_layernorm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9") add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp) target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) diff --git a/client_example/04_contraction/CMakeLists.txt b/client_example/04_contraction/CMakeLists.txt index 13c0375846..91e1ef4cf0 100644 --- a/client_example/04_contraction/CMakeLists.txt +++ b/client_example/04_contraction/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9") add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) diff --git a/client_example/05_layernorm/CMakeLists.txt b/client_example/05_layernorm/CMakeLists.txt index b7b3c830ed..b742a234a3 100644 --- a/client_example/05_layernorm/CMakeLists.txt +++ b/client_example/05_layernorm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_executable(client_layernorm2d_bwd_data layernorm2d_bwd_data.cpp) target_link_libraries(client_layernorm2d_bwd_data PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/06_softmax/CMakeLists.txt b/client_example/06_softmax/CMakeLists.txt index 24d30f475e..f77773879d 100644 --- a/client_example/06_softmax/CMakeLists.txt +++ b/client_example/06_softmax/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_executable(client_softmax4d softmax4d.cpp) target_link_libraries(client_softmax4d PRIVATE composable_kernel::device_other_operations composable_kernel::device_reduction_operations) diff --git a/client_example/07_grouped_convnd_fwd/CMakeLists.txt b/client_example/07_grouped_convnd_fwd/CMakeLists.txt index 2ea31bdf06..8153862e48 100644 --- a/client_example/07_grouped_convnd_fwd/CMakeLists.txt +++ b/client_example/07_grouped_convnd_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9") add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp) target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/08_fused_attention/CMakeLists.txt b/client_example/08_fused_attention/CMakeLists.txt index 4bcde367dc..47efec7340 100644 --- a/client_example/08_fused_attention/CMakeLists.txt +++ b/client_example/08_fused_attention/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9") add_executable(client_fused_attention fused_attention.cpp) target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) diff --git a/client_example/09_quantization/CMakeLists.txt b/client_example/09_quantization/CMakeLists.txt index d2d3a427e8..7a688243f3 100644 --- a/client_example/09_quantization/CMakeLists.txt +++ b/client_example/09_quantization/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)) add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) diff --git a/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt b/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt index 42a29a1d42..8814bd261e 100644 --- a/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt +++ b/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_executable(client_grouped_conv2d_bwd_data grouped_conv2d_bwd_data.cpp) target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt b/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt index 60a6dc1021..c38b93f4e4 100644 --- a/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt +++ b/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_executable(client_grouped_conv1d_bwd_weight_fp16 grouped_conv1d_bwd_weight_fp16.cpp) add_executable(client_grouped_conv2d_bwd_weight_fp16 grouped_conv2d_bwd_weight_fp16.cpp) add_executable(client_grouped_conv3d_bwd_weight_fp16 grouped_conv3d_bwd_weight_fp16.cpp) diff --git a/client_example/12_elementwise_normalization/CMakeLists.txt b/client_example/12_elementwise_normalization/CMakeLists.txt index 738647de59..4937272715 100644 --- a/client_example/12_elementwise_normalization/CMakeLists.txt +++ b/client_example/12_elementwise_normalization/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_executable(client_elementwise_layernorm2d elementwise_layernorm2d.cpp) target_link_libraries(client_elementwise_layernorm2d PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/13_batchnorm/CMakeLists.txt b/client_example/13_batchnorm/CMakeLists.txt index 420ea25752..243ea20f2b 100644 --- a/client_example/13_batchnorm/CMakeLists.txt +++ b/client_example/13_batchnorm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_executable(client_batchnorm_fwd_nhwc batchnorm_fwd_nhwc.cpp) add_executable(client_batchnorm_bwd_nhwc batchnorm_bwd_nhwc.cpp) add_executable(client_batchnorm_infer_nhwc batchnorm_infer_nhwc.cpp) diff --git a/client_example/14_instance_id/CMakeLists.txt b/client_example/14_instance_id/CMakeLists.txt index 6ba0e59f5a..daf3750055 100644 --- a/client_example/14_instance_id/CMakeLists.txt +++ b/client_example/14_instance_id/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_executable(client_batchnorm_fwd_instance_id batchnorm_fwd_instance_id.cpp) target_link_libraries(client_batchnorm_fwd_instance_id PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/15_convnd_bwd_data/CMakeLists.txt b/client_example/15_convnd_bwd_data/CMakeLists.txt index 8fc62bc2bb..0ad58450be 100644 --- a/client_example/15_convnd_bwd_data/CMakeLists.txt +++ b/client_example/15_convnd_bwd_data/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9") add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp) add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp) diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index 8c1372e741..ae9581809d 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp) target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt index 39bef71814..6f89d33378 100644 --- a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt +++ b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9") add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp) target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) diff --git a/client_example/18_groupnorm/CMakeLists.txt b/client_example/18_groupnorm/CMakeLists.txt index e04c26d8e7..b31419c87a 100644 --- a/client_example/18_groupnorm/CMakeLists.txt +++ b/client_example/18_groupnorm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_executable(client_groupnorm_bwd_data groupnorm_bwd_data.cpp) target_link_libraries(client_groupnorm_bwd_data PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/19_pool/CMakeLists.txt b/client_example/19_pool/CMakeLists.txt index 861c1a3257..5c239cd7cb 100644 --- a/client_example/19_pool/CMakeLists.txt +++ b/client_example/19_pool/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_executable(client_max_pool2d_fwd max_pool2d_fwd.cpp) target_link_libraries(client_max_pool2d_fwd PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/20_splitk_gemm/CMakeLists.txt b/client_example/20_splitk_gemm/CMakeLists.txt index 383c5d6630..a7f341b3aa 100644 --- a/client_example/20_splitk_gemm/CMakeLists.txt +++ b/client_example/20_splitk_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94")) add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations) diff --git a/client_example/21_grouped_gemm_bias/CMakeLists.txt b/client_example/21_grouped_gemm_bias/CMakeLists.txt index a09921e50a..f640f32d2a 100644 --- a/client_example/21_grouped_gemm_bias/CMakeLists.txt +++ b/client_example/21_grouped_gemm_bias/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9") add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations) diff --git a/client_example/22_grouped_gemm/CMakeLists.txt b/client_example/22_grouped_gemm/CMakeLists.txt index 1e1c39681e..fda1e798c4 100644 --- a/client_example/22_grouped_gemm/CMakeLists.txt +++ b/client_example/22_grouped_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9") add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp) target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations) diff --git a/client_example/23_elementwise_transpose/CMakeLists.txt b/client_example/23_elementwise_transpose/CMakeLists.txt index 6b2421d881..a14665295b 100644 --- a/client_example/23_elementwise_transpose/CMakeLists.txt +++ b/client_example/23_elementwise_transpose/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_executable(client_elementwise_transpose3d elementwise_transpose_3d.cpp) target_link_libraries(client_elementwise_transpose3d PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index 67bbdfec45..31cb082372 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9") # Fwd scaleadd scaleadd relu add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 diff --git a/client_example/25_wrapper/CMakeLists.txt b/client_example/25_wrapper/CMakeLists.txt index b1e9d20bfd..785f9eb275 100644 --- a/client_example/25_wrapper/CMakeLists.txt +++ b/client_example/25_wrapper/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp) target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations) add_executable(client_wrapper_img2col wrapper_img2col.cpp) diff --git a/client_example/26_reduce/CMakeLists.txt b/client_example/26_reduce/CMakeLists.txt index a944af5e54..b1f818f77f 100644 --- a/client_example/26_reduce/CMakeLists.txt +++ b/client_example/26_reduce/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_executable(client_reduce_nhwc_c reduce_nhwc_c.cpp) target_link_libraries(client_reduce_nhwc_c PRIVATE composable_kernel::device_reduction_operations) diff --git a/client_example/27_im2col_col2im/CMakeLists.txt b/client_example/27_im2col_col2im/CMakeLists.txt index d938d8cfb3..faead24f06 100644 --- a/client_example/27_im2col_col2im/CMakeLists.txt +++ b/client_example/27_im2col_col2im/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_executable(client_image_to_column image_to_column.cpp) target_link_libraries(client_image_to_column PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/28_gemm_mx/CMakeLists.txt b/client_example/28_gemm_mx/CMakeLists.txt index 558986bf5a..0e692fecdd 100644 --- a/client_example/28_gemm_mx/CMakeLists.txt +++ b/client_example/28_gemm_mx/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx950") add_executable(client_gemm_mx_fp8 gemm_mx_fp8.cpp) target_link_libraries(client_gemm_mx_fp8 PRIVATE composable_kernel::device_gemm_operations) diff --git a/client_example/29_gemm_add_multiply/CMakeLists.txt b/client_example/29_gemm_add_multiply/CMakeLists.txt index a683f78571..e4617ce608 100644 --- a/client_example/29_gemm_add_multiply/CMakeLists.txt +++ b/client_example/29_gemm_add_multiply/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9") add_executable(client_gemm_add_multiply gemm_add_multiply.cpp) target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations) diff --git a/client_example/30_gemm_bf16Aint8B/CMakeLists.txt b/client_example/30_gemm_bf16Aint8B/CMakeLists.txt index 5cfcb68e10..55c48e238b 100644 --- a/client_example/30_gemm_bf16Aint8B/CMakeLists.txt +++ b/client_example/30_gemm_bf16Aint8B/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf16") OR NOT DEFINED DTYPES)) add_executable(client_gemm_bias_fastgelu_bf16_i8_bf16 gemm_bias_fastgelu_xdl_bf16_i8.cpp) target_link_libraries(client_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations) diff --git a/client_example/31_grouped_gemm_bf16Aint8B/CMakeLists.txt b/client_example/31_grouped_gemm_bf16Aint8B/CMakeLists.txt index c3483ef5db..9e9751ed73 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/CMakeLists.txt +++ b/client_example/31_grouped_gemm_bf16Aint8B/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf16") OR NOT DEFINED DTYPES)) add_executable(client_grouped_gemm_bias_fastgelu_bf16_i8_bf16 grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp) target_link_libraries(client_grouped_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations) diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 21f6e652b8..2ed338d08a 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + cmake_minimum_required(VERSION 3.15) project(ck_app) add_compile_options(-std=c++20) diff --git a/cmake/ShardInstantiation.cmake b/cmake/ShardInstantiation.cmake index 48ad21d3e9..b370bb080f 100644 --- a/cmake/ShardInstantiation.cmake +++ b/cmake/ShardInstantiation.cmake @@ -35,7 +35,7 @@ function(generate_sharded_instantiations) set(GENERATED_SOURCE_FILES "") set(EXTERN_TEMPLATE_STATEMENTS "") set(CALL_STATEMENTS "") - message(STATUS "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}") + message(DEBUG "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}") set(INSTANCES "${GEN_SHARDED_INSTANCES_NAME}") diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 80429a781b..22d8e58d10 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + cmake_minimum_required(VERSION 3.16) project(composable_kernel_host) diff --git a/codegen/test/CMakeLists.txt b/codegen/test/CMakeLists.txt index 48fde531da..ad9743ff83 100644 --- a/codegen/test/CMakeLists.txt +++ b/codegen/test/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + list(APPEND CMAKE_PREFIX_PATH /opt/rocm) add_subdirectory(rtc) file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) diff --git a/codegen/test/rtc/CMakeLists.txt b/codegen/test/rtc/CMakeLists.txt index b8a60cd633..68b43d0dd9 100644 --- a/codegen/test/rtc/CMakeLists.txt +++ b/codegen/test/rtc/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + find_package(hip) file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp) add_library(ck_rtc ${RTC_SOURCES}) diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index a9ae0b2a6a..2d65368d4f 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_gemm_dl) add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index 2c20b96eee..c35afda779 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp) add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt index 35c54abac0..73f1aca2fe 100644 --- a/example/03_gemm_bias_relu/CMakeLists.txt +++ b/example/03_gemm_bias_relu/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index 562936418b..c8234bd3b3 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_gemm_add_add_fastgelu_xdl) add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 791d81e264..930819327c 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt index ef8bea1850..2039146cfa 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_convnd_fwd_reduce_xdl) add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) diff --git a/example/12_reduce/CMakeLists.txt b/example/12_reduce/CMakeLists.txt index 03381a449f..a9047813db 100644 --- a/example/12_reduce/CMakeLists.txt +++ b/example/12_reduce/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_reduce_blockwise reduce_blockwise.cpp) add_example_executable(example_reduce_threadwise_multi_d reduce_threadwise_multi_d.cpp) add_example_executable(example_reduce_multiblock_atomic_add reduce_multiblock_atomic_add.cpp) diff --git a/example/13_pool2d_fwd/CMakeLists.txt b/example/13_pool2d_fwd/CMakeLists.txt index e2a923cded..b601805a19 100644 --- a/example/13_pool2d_fwd/CMakeLists.txt +++ b/example/13_pool2d_fwd/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index b058e7b0fa..19a7a034c5 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) add_example_executable(example_gemm_wmma_quantization_int8 gemm_wmma_quantization_int8.cpp) add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 20d9bab7e1..ce41c3310f 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_grouped_gemm_xdl) add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32) @@ -34,6 +37,13 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) endif() +add_custom_target(example_grouped_gemm_wmma) +add_example_executable(example_grouped_gemm_wmma_splitk_fp16 grouped_gemm_wmma_splitk_fp16.cpp) +add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_fp16) + +add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp) +add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16) + list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp new file mode 100644 index 0000000000..e4da397c23 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/utility/ignore.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3 + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; + +// clang-format on + +#define EXAMPLE_USE_SPLITK +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp new file mode 100644 index 0000000000..d5b2205892 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp @@ -0,0 +1,71 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/utility/ignore.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3 + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; + +// clang-format on + +#define EXAMPLE_USE_SPLITK +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 0e64fbb7c6..764b533455 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -19,6 +19,10 @@ struct ProblemSize final std::vector stride_Cs; ck::index_t group_count; + +#if defined(EXAMPLE_USE_SPLITK) + ck::index_t k_batch; +#endif }; struct ExecutionConfig final @@ -177,6 +181,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co auto argument = gemm.MakeArgument( p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op); +#if defined(EXAMPLE_USE_SPLITK) + gemm.SetKBatchSize(&argument, problem_size.k_batch); +#endif + std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument); std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument); std::size_t hargs_size = gemm.GetHostKernelArgSize(&argument); @@ -285,12 +293,15 @@ bool run_grouped_gemm_example(int argc, char* argv[]) ExecutionConfig config; problem_size.group_count = 16; +#if defined(EXAMPLE_USE_SPLITK) + problem_size.k_batch = 1; +#endif if(argc == 1) { // use default cases } - else if(argc == 4 || argc == 6) + else if(argc == 4 || argc == 6 || argc == 7) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); @@ -300,6 +311,13 @@ bool run_grouped_gemm_example(int argc, char* argv[]) config.async_hargs = std::stoi(argv[4]); problem_size.group_count = std::stoi(argv[5]); } + +#if defined(EXAMPLE_USE_SPLITK) + if(argc == 7) + { + problem_size.k_batch = std::stoi(argv[6]); + } +#endif } else { @@ -307,7 +325,10 @@ bool run_grouped_gemm_example(int argc, char* argv[]) printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4: async hargs (0=n0, 1=yes)\n"); - printf("arg5: group count (default=16)"); + printf("arg5: group count (default=16)\n"); +#if defined(EXAMPLE_USE_SPLITK) + printf("arg6: k-batch count (default=1)\n"); +#endif exit(1); } diff --git a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt index 1e12c16f30..7a685afb53 100644 --- a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt +++ b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_gemm_reduce_xdl) add_custom_target(example_gemm_reduce_xdl_max) add_custom_target(example_gemm_reduce_xdl_mean_meansquare) diff --git a/example/17_convnd_bwd_data/CMakeLists.txt b/example/17_convnd_bwd_data/CMakeLists.txt index 39f9fb8ec0..8f48bd0d29 100644 --- a/example/17_convnd_bwd_data/CMakeLists.txt +++ b/example/17_convnd_bwd_data/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) if(result EQUAL 0) target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt index 1d1f255187..5d3fafe116 100644 --- a/example/18_batched_gemm_reduce/CMakeLists.txt +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_batched_gemm_reduce_xdl_fp16 batched_gemm_reduce_xdl_fp16.cpp) diff --git a/example/19_binary_elementwise/CMakeLists.txt b/example/19_binary_elementwise/CMakeLists.txt index 39646e0ab5..792de59d15 100644 --- a/example/19_binary_elementwise/CMakeLists.txt +++ b/example/19_binary_elementwise/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp) add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp) add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp) diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index 6fbaee7dba..2e381b09d3 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_grouped_conv_bwd_weight) add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt index 2eb7052e1e..4c5a335e12 100644 --- a/example/21_gemm_layernorm/CMakeLists.txt +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) diff --git a/example/22_cgemm/CMakeLists.txt b/example/22_cgemm/CMakeLists.txt index 44585b11d0..47ef20da30 100644 --- a/example/22_cgemm/CMakeLists.txt +++ b/example/22_cgemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_cgemm_xdl) add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) diff --git a/example/23_softmax/CMakeLists.txt b/example/23_softmax/CMakeLists.txt index dafe65521a..73fb589a58 100644 --- a/example/23_softmax/CMakeLists.txt +++ b/example/23_softmax/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_softmax_blockwise softmax_blockwise.cpp) \ No newline at end of file diff --git a/example/24_batched_gemm/CMakeLists.txt b/example/24_batched_gemm/CMakeLists.txt index d515720944..b43e84fa30 100644 --- a/example/24_batched_gemm/CMakeLists.txt +++ b/example/24_batched_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_batched_gemm_xdl) add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) diff --git a/example/25_gemm_bias_e_permute/CMakeLists.txt b/example/25_gemm_bias_e_permute/CMakeLists.txt index cbc3c007bc..1879066181 100644 --- a/example/25_gemm_bias_e_permute/CMakeLists.txt +++ b/example/25_gemm_bias_e_permute/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp) add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp) diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt index f3d30cea2a..4a41bc5e65 100644 --- a/example/26_contraction/CMakeLists.txt +++ b/example/26_contraction/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_contraction) add_custom_target(example_contraction_scale) add_custom_target(example_contraction_bilinear) diff --git a/example/27_layernorm2d_fwd/CMakeLists.txt b/example/27_layernorm2d_fwd/CMakeLists.txt index 639bd9c400..94cc8f3c65 100644 --- a/example/27_layernorm2d_fwd/CMakeLists.txt +++ b/example/27_layernorm2d_fwd/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_layernorm2d_fwd_fp16 layernorm2d_fwd_fp16.cpp) add_example_executable(example_layernorm2d_fwd_splitk_fp16 layernorm2d_fwd_splitk_fp16.cpp) diff --git a/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt b/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt index 44ab16894c..abe6ca2bf9 100644 --- a/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt +++ b/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp) diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index ac54aebdc2..d5d5521370 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index 7acb1a1907..4732acac76 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_grouped_conv_fwd_multiple_d) add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt b/example/31_batched_gemm_gemm/CMakeLists.txt index 811b133b44..43ddb9577b 100644 --- a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index 519f539106..7efa169a7d 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) diff --git a/example/33_multiple_reduce/CMakeLists.txt b/example/33_multiple_reduce/CMakeLists.txt index bc8c3eb04e..889782abc7 100644 --- a/example/33_multiple_reduce/CMakeLists.txt +++ b/example/33_multiple_reduce/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_dual_reduce_multiblock dual_reduce_multiblock.cpp) add_example_executable(example_dual_reduce_threadwise dual_reduce_threadwise.cpp) diff --git a/example/34_batchnorm/CMakeLists.txt b/example/34_batchnorm/CMakeLists.txt index 60824c5f4d..a49a3babe9 100644 --- a/example/34_batchnorm/CMakeLists.txt +++ b/example/34_batchnorm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_batchnorm_forward_training batchnorm_forward_training_nhwc.cpp) add_example_executable(example_batchnorm_forward_training_obsolete batchnorm_forward_training_nhwc_obsolete.cpp) add_example_executable(example_batchnorm_forward_inferring batchnorm_forward_inferring_nhwc.cpp) diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index e0476bfaad..b12393c38e 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_splitK_gemm_xdl) add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) diff --git a/example/36_sparse_embedding/CMakeLists.txt b/example/36_sparse_embedding/CMakeLists.txt index 9cbcf5540e..e4dc1a21f0 100644 --- a/example/36_sparse_embedding/CMakeLists.txt +++ b/example/36_sparse_embedding/CMakeLists.txt @@ -1 +1,4 @@ -add_example_executable(example_sparse_embedding3_forward_layernorm sparse_embedding3_forward_layernorm.cpp) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_example_executable(example_sparse_embedding3_forward_layernorm sparse_embedding3_forward_layernorm.cpp) diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt b/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt index a9be3a7108..7b6eb01413 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt +++ b/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp) diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index ce951f6353..b58bd7cb3a 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_grouped_conv_bwd_data) add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) diff --git a/example/39_permute/CMakeLists.txt b/example/39_permute/CMakeLists.txt index 8b850c89a9..16ff5b4e52 100644 --- a/example/39_permute/CMakeLists.txt +++ b/example/39_permute/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_permute) add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) diff --git a/example/40_conv2d_fwd_quantization/CMakeLists.txt b/example/40_conv2d_fwd_quantization/CMakeLists.txt index 991c1e464b..c3b05bcdd4 100644 --- a/example/40_conv2d_fwd_quantization/CMakeLists.txt +++ b/example/40_conv2d_fwd_quantization/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt index c5c5a84b67..e0fd5a1de0 100644 --- a/example/41_grouped_conv_conv_fwd/CMakeLists.txt +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) diff --git a/example/42_groupnorm_fwd/CMakeLists.txt b/example/42_groupnorm_fwd/CMakeLists.txt index 7d08baccd0..d3466d4933 100644 --- a/example/42_groupnorm_fwd/CMakeLists.txt +++ b/example/42_groupnorm_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_groupnorm_fwd_sigmoid_mul_fp16 groupnorm_fwd_sigmoid_mul_fp16.cpp) add_example_executable(example_groupnorm_fwd_splitk_fp16 groupnorm_fwd_splitk_fp16.cpp) add_example_executable(example_groupnorm_fwd_swish_fp16 groupnorm_fwd_swish_fp16.cpp) diff --git a/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt b/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt index c29f18f162..e67c333eeb 100644 --- a/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt +++ b/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp) add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp) diff --git a/example/44_elementwise_permute/CMakeLists.txt b/example/44_elementwise_permute/CMakeLists.txt index 867493465d..8e7cf753a8 100644 --- a/example/44_elementwise_permute/CMakeLists.txt +++ b/example/44_elementwise_permute/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) add_example_executable(example_elementwise_permute_4D_fp32_row elementwise_permute_4D_fp32_row.cpp) add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permute_4D_fp16_row.cpp) diff --git a/example/45_elementwise_normalization/CMakeLists.txt b/example/45_elementwise_normalization/CMakeLists.txt index 8f5b9d4d87..1931728ade 100644 --- a/example/45_elementwise_normalization/CMakeLists.txt +++ b/example/45_elementwise_normalization/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_elementwise_layernorm_blockwise elementwise_layernorm_blockwise.cpp) diff --git a/example/46_gemm_add_multiply/CMakeLists.txt b/example/46_gemm_add_multiply/CMakeLists.txt index bfe057e8da..2575b3c6ef 100644 --- a/example/46_gemm_add_multiply/CMakeLists.txt +++ b/example/46_gemm_add_multiply/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp) add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp) diff --git a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt index df1956ca62..0991c1895c 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt +++ b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute_xdl.cpp) diff --git a/example/48_pool3d_fwd/CMakeLists.txt b/example/48_pool3d_fwd/CMakeLists.txt index 492cb4d37e..f9677065bd 100644 --- a/example/48_pool3d_fwd/CMakeLists.txt +++ b/example/48_pool3d_fwd/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp) diff --git a/example/49_maxpool2d_bwd/CMakeLists.txt b/example/49_maxpool2d_bwd/CMakeLists.txt index b29cf9ccbc..7e6ee1de72 100644 --- a/example/49_maxpool2d_bwd/CMakeLists.txt +++ b/example/49_maxpool2d_bwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp) add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp) add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp) diff --git a/example/50_put_element/CMakeLists.txt b/example/50_put_element/CMakeLists.txt index 1b0020ebcf..65a9a7e14a 100644 --- a/example/50_put_element/CMakeLists.txt +++ b/example/50_put_element/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_put_element_fp16 put_element_fp16.cpp) diff --git a/example/51_avgpool3d_bwd/CMakeLists.txt b/example/51_avgpool3d_bwd/CMakeLists.txt index fef0c66835..b621b54ec9 100644 --- a/example/51_avgpool3d_bwd/CMakeLists.txt +++ b/example/51_avgpool3d_bwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_avgpool3d_bwd_bf16 avgpool3d_bwd_bf16.cpp) add_example_executable(example_avgpool3d_bwd_fp16 avgpool3d_bwd_fp16.cpp) add_example_executable(example_avgpool3d_bwd_fp32 avgpool3d_bwd_fp32.cpp) diff --git a/example/52_im2col_col2im/CMakeLists.txt b/example/52_im2col_col2im/CMakeLists.txt index 63ee1d4312..86f4e17386 100644 --- a/example/52_im2col_col2im/CMakeLists.txt +++ b/example/52_im2col_col2im/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_im2col_col2im) add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp) diff --git a/example/53_layernorm2d_bwd/CMakeLists.txt b/example/53_layernorm2d_bwd/CMakeLists.txt index a58b1109f7..52f11555c1 100644 --- a/example/53_layernorm2d_bwd/CMakeLists.txt +++ b/example/53_layernorm2d_bwd/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_layernorm2d_bwd_fp32 layernorm2d_bwd_fp32.cpp) diff --git a/example/54_groupnorm_bwd/CMakeLists.txt b/example/54_groupnorm_bwd/CMakeLists.txt index 2cb103499c..34d931862e 100644 --- a/example/54_groupnorm_bwd/CMakeLists.txt +++ b/example/54_groupnorm_bwd/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_groupnorm_bwd_fp32 groupnorm_bwd_fp32.cpp) diff --git a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt index e49056a948..4155e0a344 100644 --- a/example/59_grouped_gemm_multi_ABD/CMakeLists.txt +++ b/example/59_grouped_gemm_multi_ABD/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_grouped_gemm_xdl_multi_abd) add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16 grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp) diff --git a/example/60_gemm_multi_ABD/CMakeLists.txt b/example/60_gemm_multi_ABD/CMakeLists.txt index ffc6cec61d..c1eab1dd7d 100644 --- a/example/60_gemm_multi_ABD/CMakeLists.txt +++ b/example/60_gemm_multi_ABD/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_gemm_multi_ABD_wmma_fp16 gemm_multi_ABD_wmma_fp16.cpp) add_example_executable(example_gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8 gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp) add_example_executable(example_gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8 gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp) diff --git a/example/61_contraction_multi_ABD/CMakeLists.txt b/example/61_contraction_multi_ABD/CMakeLists.txt index 1b8bd4cad2..620a40ba59 100644 --- a/example/61_contraction_multi_ABD/CMakeLists.txt +++ b/example/61_contraction_multi_ABD/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp) diff --git a/example/62_convnd_activ/CMakeLists.txt b/example/62_convnd_activ/CMakeLists.txt index 79fafed4eb..ecdbf63d2c 100644 --- a/example/62_convnd_activ/CMakeLists.txt +++ b/example/62_convnd_activ/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_subdirectory(binary) add_subdirectory(convinvscale) add_subdirectory(convscale) diff --git a/example/62_convnd_activ/binary/CMakeLists.txt b/example/62_convnd_activ/binary/CMakeLists.txt index f23f908883..29cd190b7d 100644 --- a/example/62_convnd_activ/binary/CMakeLists.txt +++ b/example/62_convnd_activ/binary/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_convnd_activ_binary_xdl) # Bilinear residual add_example_executable(example_convnd_fwd_xdl_bilinear_residual_fp16 convnd_fwd_xdl_bilinear_residual_fp16.cpp) diff --git a/example/62_convnd_activ/convinvscale/CMakeLists.txt b/example/62_convnd_activ/convinvscale/CMakeLists.txt index c737bc00ec..9748f50e51 100644 --- a/example/62_convnd_activ/convinvscale/CMakeLists.txt +++ b/example/62_convnd_activ/convinvscale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if (NOT GPU_TARGETS MATCHES "gfx11") add_custom_target(example_convnd_activ_xdl_convinvscale) add_example_executable(example_convnd_fwd_xdl_convinvscale_fp8 convnd_fwd_xdl_convinvscale_fp8.cpp) diff --git a/example/62_convnd_activ/convscale/CMakeLists.txt b/example/62_convnd_activ/convscale/CMakeLists.txt index 8746a5ad54..705160e01d 100644 --- a/example/62_convnd_activ/convscale/CMakeLists.txt +++ b/example/62_convnd_activ/convscale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if (NOT GPU_TARGETS MATCHES "gfx11") add_custom_target(example_convnd_activ_xdl_convscale) add_example_executable(example_convnd_fwd_xdl_convscale_fp8 convnd_fwd_xdl_convscale_fp8.cpp) diff --git a/example/62_convnd_activ/convscale_add/CMakeLists.txt b/example/62_convnd_activ/convscale_add/CMakeLists.txt index 5dac630298..e8f1488eb7 100644 --- a/example/62_convnd_activ/convscale_add/CMakeLists.txt +++ b/example/62_convnd_activ/convscale_add/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if (NOT GPU_TARGETS MATCHES "gfx11") add_custom_target(example_convnd_activ_xdl_convscale_add) add_example_executable(example_convnd_fwd_xdl_convscale_add_fp8 convnd_fwd_xdl_convscale_add_fp8.cpp) diff --git a/example/62_convnd_activ/convscale_reduce/CMakeLists.txt b/example/62_convnd_activ/convscale_reduce/CMakeLists.txt index c1c64671b4..0cbf17b2ec 100644 --- a/example/62_convnd_activ/convscale_reduce/CMakeLists.txt +++ b/example/62_convnd_activ/convscale_reduce/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if (NOT GPU_TARGETS MATCHES "gfx11") add_custom_target(example_convnd_activ_xdl_convscale_reduce) add_example_executable(example_convnd_fwd_xdl_convscale_relu_amax_fp8 convnd_fwd_xdl_convscale_relu_amax_fp8.cpp) diff --git a/example/62_convnd_activ/convscale_relu/CMakeLists.txt b/example/62_convnd_activ/convscale_relu/CMakeLists.txt index 024b79e2af..307a4102a6 100644 --- a/example/62_convnd_activ/convscale_relu/CMakeLists.txt +++ b/example/62_convnd_activ/convscale_relu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if (NOT GPU_TARGETS MATCHES "gfx11") add_custom_target(example_convnd_activ_xdl_convscale_relu) add_example_executable(example_convnd_fwd_xdl_convscale_relu_fp8 convnd_fwd_xdl_convscale_relu_fp8.cpp) diff --git a/example/62_convnd_activ/dynamic_unary/CMakeLists.txt b/example/62_convnd_activ/dynamic_unary/CMakeLists.txt index 359b444dd0..9efc48f905 100644 --- a/example/62_convnd_activ/dynamic_unary/CMakeLists.txt +++ b/example/62_convnd_activ/dynamic_unary/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_convnd_activ_dynamic_unary_xdl) # Sigmoid add_example_executable(example_convnd_fwd_xdl_dynamic_sigmoid_fp16 convnd_fwd_xdl_dynamic_sigmoid_fp16.cpp) diff --git a/example/62_convnd_activ/multi_AB/CMakeLists.txt b/example/62_convnd_activ/multi_AB/CMakeLists.txt index 80a3a8f196..fc637cdbc6 100644 --- a/example/62_convnd_activ/multi_AB/CMakeLists.txt +++ b/example/62_convnd_activ/multi_AB/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_convnd_activ_multi_ab_xdl) # ScaleAdd on A and B add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp16 conv_fwd_xdl_scaleadd_ab_fp16.cpp) diff --git a/example/62_convnd_activ/unary/CMakeLists.txt b/example/62_convnd_activ/unary/CMakeLists.txt index 2b54b1f590..13d185c35a 100644 --- a/example/62_convnd_activ/unary/CMakeLists.txt +++ b/example/62_convnd_activ/unary/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_convnd_activ_unary_xdl) # Sigmoid add_example_executable(example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp) diff --git a/example/63_layernorm4d_fwd/CMakeLists.txt b/example/63_layernorm4d_fwd/CMakeLists.txt index 3f8c679ab8..542f63f070 100644 --- a/example/63_layernorm4d_fwd/CMakeLists.txt +++ b/example/63_layernorm4d_fwd/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_layernorm4d_fwd_fp16 layernorm4d_fwd_fp16.cpp) add_example_executable(example_layernorm4d_fwd_splitk_fp16 layernorm4d_fwd_splitk_fp16.cpp) diff --git a/example/64_fpAintB_gemm/CMakeLists.txt b/example/64_fpAintB_gemm/CMakeLists.txt index b3c77b3bd7..314838e7f2 100644 --- a/example/64_fpAintB_gemm/CMakeLists.txt +++ b/example/64_fpAintB_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_fpAintB_gemm_wmma) add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) add_example_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 74930d2b21..6198fd0e22 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp) @@ -16,7 +19,7 @@ add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_blockscale.cpp) add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp) -list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx11-generic gfx12-generic) +list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx11-generic gfx12-generic) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/66_complex_contraction_bilinear/CMakeLists.txt b/example/66_complex_contraction_bilinear/CMakeLists.txt index c417caf8e7..df66cf2112 100644 --- a/example/66_complex_contraction_bilinear/CMakeLists.txt +++ b/example/66_complex_contraction_bilinear/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_example_executable(example_complex_contraction_bilinear_xdl_fp32 complex_contraction_bilinear_xdl_fp32.cpp) add_example_executable(example_complex_contraction_bilinear_xdl_fp64 complex_contraction_bilinear_xdl_fp64.cpp) diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt index 6ee43aac62..62e86d7682 100644 --- a/example/67_gemm_microscaling/CMakeLists.txt +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_gemm_mx) add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp) diff --git a/example/68_gemm_add/CMakeLists.txt b/example/68_gemm_add/CMakeLists.txt index af091d32e4..9e2fd3a7cb 100644 --- a/example/68_gemm_add/CMakeLists.txt +++ b/example/68_gemm_add/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_gemm_add_xdl) add_example_executable(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp) diff --git a/example/69_gemm_add_relu/CMakeLists.txt b/example/69_gemm_add_relu/CMakeLists.txt index 9ab3ef5a45..6ba7b2b9ca 100644 --- a/example/69_gemm_add_relu/CMakeLists.txt +++ b/example/69_gemm_add_relu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(example_gemm_add_relu_xdl) add_example_executable(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 940e7bc5e6..aed19c083a 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + include_directories(BEFORE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/library/include @@ -113,7 +116,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) elseif(source_name_list MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(source_name_list MATCHES "_mx") #only build mx example for gfx950 - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) elseif(source_name_list MATCHES "_pk_i4") #only build these examples for gfx942 gfx950 and rdna3/4 message(DEBUG "trimming targets for ${FILE_NAME}") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx10-3-generic) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index ce914b92af..0c8102a70b 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) # Currently only gfx9 and gfx12 archs are supported by FMHA list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12") @@ -109,6 +112,7 @@ set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances") set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances") message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}") +# to save build time, exclude the target from "all" target of "01_fmha" directory and its ancestors add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL) target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS}) diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt index 07714f0fe2..0f2ce847e8 100644 --- a/example/ck_tile/02_layernorm2d/CMakeLists.txt +++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd") set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING "semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".") @@ -26,7 +29,7 @@ add_custom_command( set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd") message(DEBUG "adding example ${EXAMPLE_LAYERNORM2D_FWD}") -add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp) +add_executable(${EXAMPLE_LAYERNORM2D_FWD} layernorm2d_fwd.cpp) target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS}) diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index d2112a67bf..40547d0719 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,20 +1,25 @@ -add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) -add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) -add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp) -add_executable(tile_example_gemm_reduce EXCLUDE_FROM_ALL gemm_splitk_two_stage_reduce.cpp) -add_executable(tile_example_gemm_splitk_two_stage EXCLUDE_FROM_ALL gemm_splitk_two_stage.cpp) -set(EXAMPLE_GEMM_COMPILE_OPTIONS) -set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) -if(CK_USE_OCP_FP8) - list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a") + add_executable(tile_example_gemm_basic gemm_basic.cpp) + add_executable(tile_example_gemm_universal universal_gemm.cpp) + add_executable(tile_example_gemm_weight_preshuffle gemm_weight_preshuffle.cpp) + add_executable(tile_example_gemm_reduce gemm_splitk_two_stage_reduce.cpp) + add_executable(tile_example_gemm_splitk_two_stage gemm_splitk_two_stage.cpp) + set(EXAMPLE_GEMM_COMPILE_OPTIONS) + set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) + if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) + endif() + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) + list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef) + list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker) + list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps) + list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm -enable-noalias-to-md-conversion=0") + target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) + target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_gemm_splitk_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() -list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) -list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef) -list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker) -list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps) -list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm -enable-noalias-to-md-conversion=0") -target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) -target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -target_compile_options(tile_example_gemm_splitk_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index 492f94bae7..74edddb6c9 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -10,6 +10,7 @@ #include #include "ck_tile/host.hpp" +#include "ck_tile/ops/common/utils.hpp" #include "ck_tile/ops/reduce.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "gemm_utils.hpp" @@ -589,9 +590,10 @@ float invoke_gemm_splitk_two_stage(ck_tile::DeviceMem& a_m_k_dev_buf, << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C << " kbatch=" << kbatch << " WorkspaceSize=" << workspace_size << " bytes" << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name - << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name - << " B_Type=" << DataTypeTraits::name - << " C_Type=" << DataTypeTraits::name + << " C_Layout=" << CLayout::name + << " A_Type=" << ck_tile::DataTypeTraits::name + << " B_Type=" << ck_tile::DataTypeTraits::name + << " C_Type=" << ck_tile::DataTypeTraits::name << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") << " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; @@ -683,7 +685,7 @@ int run_gemm_example_with_layouts_two_stage(ck_tile::ArgParser& arg_parser, if constexpr(preshuffle) { - ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n); + ck_tile::HostTensor b_shuffle_host = ck_tile::shuffle_b(b_k_n); // shuffled buffer B for device implementation b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); } diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index bdc37e5a94..b25aec101b 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -401,63 +401,6 @@ struct GemmTypeConfig using CDataType = int32_t; }; -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - template struct PipelineTypeTraits; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index c38ce7ce83..30cb3d3476 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" +#include "ck_tile/ops/common/utils.hpp" template static constexpr inline auto is_row_major(Layout layout_) @@ -284,12 +285,12 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, if constexpr(GemmConfig::TiledMMAPermuteN) { std::cout << "Run with PermuteN" << std::endl; - return shuffle_b_permuteN(b_k_n); + return ck_tile::shuffle_b_permuteN(b_k_n); } else { std::cout << "Run without PermuteN" << std::endl; - return shuffle_b(b_k_n); + return ck_tile::shuffle_b(b_k_n); } }(); // shuffled buffer B for device implementation @@ -372,9 +373,10 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name - << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name - << " B_Type=" << DataTypeTraits::name - << " C_Type=" << DataTypeTraits::name + << " C_Layout=" << CLayout::name + << " A_Type=" << ck_tile::DataTypeTraits::name + << " B_Type=" << ck_tile::DataTypeTraits::name + << " C_Type=" << ck_tile::DataTypeTraits::name << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") << " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; @@ -442,18 +444,18 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, BDataType, CDataType, GemmConfig, - DataTypeTraits>(arg_parser.get_str("jsonfile"), - M, - N, - K, - stride_A, - stride_B, - stride_C, - persistent, - pass, - ave_time, - tflops, - gb_per_sec); + ck_tile::DataTypeTraits>(arg_parser.get_str("jsonfile"), + M, + N, + K, + stride_A, + stride_B, + stride_C, + persistent, + pass, + ave_time, + tflops, + gb_per_sec); } return pass; diff --git a/example/ck_tile/04_img2col/CMakeLists.txt b/example/ck_tile/04_img2col/CMakeLists.txt index 3864c9ed9d..74576755fb 100644 --- a/example/ck_tile/04_img2col/CMakeLists.txt +++ b/example/ck_tile/04_img2col/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" -add_executable(tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp) +add_executable(tile_example_img2col image_to_column.cpp) diff --git a/example/ck_tile/05_reduce/CMakeLists.txt b/example/ck_tile/05_reduce/CMakeLists.txt index 2f48bb85a5..715ed35394 100644 --- a/example/ck_tile/05_reduce/CMakeLists.txt +++ b/example/ck_tile/05_reduce/CMakeLists.txt @@ -1,9 +1,12 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(EXAMPLE_REDUCE "tile_example_reduce") # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" message(DEBUG "adding example ${EXAMPLE_REDUCE}") -add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL reduce.cpp) +add_executable(${EXAMPLE_REDUCE} reduce.cpp) target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) set(EXAMPLE_REDUCE_COMPILE_OPTIONS) @@ -16,4 +19,4 @@ target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTION # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global # TODO: consider codegen a makefile by us -set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/05_reduce/reduce.cpp b/example/ck_tile/05_reduce/reduce.cpp index f1509bfeef..677065c78d 100644 --- a/example/ck_tile/05_reduce/reduce.cpp +++ b/example/ck_tile/05_reduce/reduce.cpp @@ -6,21 +6,6 @@ #include "ck_tile/utility/json_dump.hpp" #include -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -145,7 +130,7 @@ bool run(const ck_tile::ArgParser& arg_parser) if(arg_parser.get_int("json") == 1) { - dump_reduce_json_results( + dump_reduce_json_results( arg_parser.get_str("jsonfile"), N, C, H, W, pass, ave_time, 0, gb_per_sec); } diff --git a/example/ck_tile/06_permute/CMakeLists.txt b/example/ck_tile/06_permute/CMakeLists.txt index 327fceb685..fbd7415c65 100644 --- a/example/ck_tile/06_permute/CMakeLists.txt +++ b/example/ck_tile/06_permute/CMakeLists.txt @@ -1,6 +1,9 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" -add_executable(tile_example_permute EXCLUDE_FROM_ALL permute.cpp) +add_executable(tile_example_permute permute.cpp) if(NOT DEFINED PERMUTE_USE_ALTERNATIVE_IMPL) # set(PERMUTE_USE_ALTERNATIVE_IMPL false) diff --git a/example/ck_tile/09_topk_softmax/CMakeLists.txt b/example/ck_tile/09_topk_softmax/CMakeLists.txt index b43b989792..cce2c53ba4 100644 --- a/example/ck_tile/09_topk_softmax/CMakeLists.txt +++ b/example/ck_tile/09_topk_softmax/CMakeLists.txt @@ -1,4 +1,7 @@ -add_executable(tile_example_topk_softmax EXCLUDE_FROM_ALL topk_softmax.cpp topk_softmax_api.cpp) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp) target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS) diff --git a/example/ck_tile/10_rmsnorm2d/CMakeLists.txt b/example/ck_tile/10_rmsnorm2d/CMakeLists.txt index 878f668f91..af64f2c742 100644 --- a/example/ck_tile/10_rmsnorm2d/CMakeLists.txt +++ b/example/ck_tile/10_rmsnorm2d/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(RMSNORM2D_FWD_KNOWN_APIS "fwd;bwd") set(RMSNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING "semicolon-separated list of APIs to generate (${RMSNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".") @@ -26,7 +29,7 @@ add_custom_command( set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd") message(DEBUG "adding ${TILE_RMSNORM2D_FWD}") -add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp) +add_executable(${TILE_RMSNORM2D_FWD} rmsnorm2d_fwd.cpp) target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS}) @@ -38,7 +41,7 @@ list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) set(EXAMPLE_RMSNORM2D_FWD "tile_example_rmsnorm2d_fwd") -add_executable(${EXAMPLE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL example_rmsnorm2d_fwd.cpp) +add_executable(${EXAMPLE_RMSNORM2D_FWD} example_rmsnorm2d_fwd.cpp) target_compile_options(${EXAMPLE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt b/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt index 7d56dd1fe3..cb8f547b9e 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt @@ -1,9 +1,12 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "tile_add_rmsnorm2d_rdquant_fwd") # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" message(DEBUG "adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}") file(GLOB INSTANCE_SRCS instances/*.cpp) -add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL add_rmsnorm2d_rdquant_fwd.cpp) +add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} add_rmsnorm2d_rdquant_fwd.cpp) target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${INSTANCE_SRCS}) @@ -15,7 +18,7 @@ list(APPEND TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS -Wno-undefined-func-t target_compile_options(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS}) set(EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD "tile_example_add_rmsnorm2d_rdquant_fwd") -add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL example_add_rmsnorm2d_rdquant_fwd.cpp) +add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} example_add_rmsnorm2d_rdquant_fwd.cpp) target_compile_options(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated diff --git a/example/ck_tile/12_smoothquant/CMakeLists.txt b/example/ck_tile/12_smoothquant/CMakeLists.txt index 52f10b8d51..c52c947913 100644 --- a/example/ck_tile/12_smoothquant/CMakeLists.txt +++ b/example/ck_tile/12_smoothquant/CMakeLists.txt @@ -1,8 +1,11 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + function (add_smoothquant_example TARGET_NAME MAIN_SRC) message(DEBUG "adding ${TARGET_NAME}") # not using add_example_executable() to add target, since we don't want this to have # to be included in "make all/install/check" - add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC}) + add_executable(${TARGET_NAME} ${MAIN_SRC}) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) foreach(source IN LISTS ARGN) diff --git a/example/ck_tile/13_moe_sorting/CMakeLists.txt b/example/ck_tile/13_moe_sorting/CMakeLists.txt index 09f3e4ac4e..ee8c6c9996 100644 --- a/example/ck_tile/13_moe_sorting/CMakeLists.txt +++ b/example/ck_tile/13_moe_sorting/CMakeLists.txt @@ -1,4 +1,7 @@ -add_executable(tile_example_moe_sorting EXCLUDE_FROM_ALL moe_sorting.cpp moe_sorting_api.cpp) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_executable(tile_example_moe_sorting moe_sorting.cpp moe_sorting_api.cpp) target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS) diff --git a/example/ck_tile/14_moe_smoothquant/CMakeLists.txt b/example/ck_tile/14_moe_smoothquant/CMakeLists.txt index 6b848bda2a..38f94d17c6 100644 --- a/example/ck_tile/14_moe_smoothquant/CMakeLists.txt +++ b/example/ck_tile/14_moe_smoothquant/CMakeLists.txt @@ -1,8 +1,11 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + function (add_moe_smoothquant_example TARGET_NAME MAIN_SRC) message(DEBUG "adding ${TARGET_NAME}") # not using add_example_executable() to add target, since we don't want this to have # to be included in "make all/install/check" - add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC}) + add_executable(${TARGET_NAME} ${MAIN_SRC}) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) foreach(source IN LISTS ARGN) @@ -22,4 +25,3 @@ endfunction(add_moe_smoothquant_example TARGET_NAME MAIN_SRC) file(GLOB INSTANCE_SRCS instances/*.cpp) add_moe_smoothquant_example(tile_example_moe_smoothquant moe_smoothquant.cpp ${INSTANCE_SRCS}) - diff --git a/example/ck_tile/15_fused_moe/CMakeLists.txt b/example/ck_tile/15_fused_moe/CMakeLists.txt index 78ec754528..a1159f5699 100644 --- a/example/ck_tile/15_fused_moe/CMakeLists.txt +++ b/example/ck_tile/15_fused_moe/CMakeLists.txt @@ -1,19 +1,22 @@ -set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe") -# not using add_example_executable() to add this target, since we don't want this to have -# to be included in "make all/install/check" -message(DEBUG "adding ${TILE_EXAPMLE_FUSED_MOE}") -file(GLOB INSTANCE_SRCS instances/*.cpp) -add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp) -target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS}) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT -set(TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS) +if(GPU_TARGETS MATCHES "gfx94|gfx95") + set(TILE_EXAMPLE_FUSED_MOE "tile_example_fused_moe") + message(DEBUG "adding ${TILE_EXAMPLE_FUSED_MOE}") + file(GLOB INSTANCE_SRCS instances/*.cpp) + add_executable(${TILE_EXAMPLE_FUSED_MOE} main.cpp) + target_include_directories(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + target_sources(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS}) -# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) -list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a -list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta -# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1) -# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + set(TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS) -target_compile_options(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS}) + # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations + list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a + list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta + # list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1) + # list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + + target_compile_options(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS}) +endif() diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 04ad882200..d80fed7e8c 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -402,22 +402,6 @@ float fused_moesorting_mp(fused_moesorting_trait t, using ms_index_t = ck_tile::index_t; using ms_weight_type = float; - auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) { - if(t.clear_workspace_inside_api) - { - if(is_local_token) - { - auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1); - k(s_); - } - else - { - auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1); - k(s_); - } - } - }; - if(a.tokens < 2048) { if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > diff --git a/example/ck_tile/16_batched_gemm/CMakeLists.txt b/example/ck_tile/16_batched_gemm/CMakeLists.txt index 78e78c6b04..3b1cc18298 100644 --- a/example/ck_tile/16_batched_gemm/CMakeLists.txt +++ b/example/ck_tile/16_batched_gemm/CMakeLists.txt @@ -1 +1,4 @@ -add_executable(tile_example_batched_gemm EXCLUDE_FROM_ALL batched_gemm.cpp) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_executable(tile_example_batched_gemm batched_gemm.cpp) diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index bbfb2df006..bf52e0c3f4 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1,12 +1,17 @@ -add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) -add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp) -add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp) -add_executable(tile_example_grouped_gemm_multi_d EXCLUDE_FROM_ALL grouped_gemm_multi_d.cpp) -set(EXAMPLE_GEMM_COMPILE_OPTIONS) -if(CK_USE_OCP_FP8) - list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +if(GPU_TARGETS MATCHES "gfx94|gfx95") + add_executable(tile_example_grouped_gemm grouped_gemm.cpp) + add_executable(tile_example_quant_grouped_gemm quant_grouped_gemm.cpp) + add_executable(tile_example_grouped_gemm_preshuffle grouped_gemm_preshuffle.cpp) + add_executable(tile_example_grouped_gemm_multi_d grouped_gemm_multi_d.cpp) + set(EXAMPLE_GEMM_COMPILE_OPTIONS) + if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) + endif() + target_compile_options(tile_example_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() -target_compile_options(tile_example_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) \ No newline at end of file diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index a620964eaf..390a54644b 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -233,7 +233,8 @@ int run_grouped_gemm_example_with_layouts(int argc, // Perform preshuffle for B tensor if constexpr(GemmConfig::Preshuffle) { - ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n_tensors[i]); + ck_tile::HostTensor b_shuffle_host = + ck_tile::shuffle_b(b_k_n_tensors[i]); b_k_n_dev_buf.push_back(std::make_unique(b_shuffle_host)); } else diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 43789750d0..0fd819d552 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(SUPPORTED_GPUS gfx908 gfx90a gfx942 gfx950) set(has_supported_gpu FALSE) @@ -9,18 +12,6 @@ foreach(gpu IN LISTS GPU_TARGETS) endforeach() if(has_supported_gpu) - add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) - add_executable(tile_example_mixed_prec_flatmm EXCLUDE_FROM_ALL mixed_prec/mixed_prec_flatmm.cpp) - add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp) - add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp) - add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp) - - include(mxgemm/mx_flatmm_instance.cmake) - mx_flatmm_instance_generate(EXAMPLE_MX_FLATMM_FILES) - message(STATUS "Generated MX FlatMM kernel files: ${EXAMPLE_MX_FLATMM_FILES}") - add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp ${EXAMPLE_MX_FLATMM_FILES}) - target_include_directories(tile_example_mx_flatmm PRIVATE mxgemm) - # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # ... because they are auto-generated set(EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template) @@ -30,11 +21,28 @@ if(has_supported_gpu) list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() + add_executable(tile_example_flatmm_basic flatmm_basic.cpp) target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) - target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) - target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) - target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) - target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) - target_compile_options(tile_example_mx_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) # TODO: 950 only -endif() + add_executable(tile_example_moe_flatmm moe_flatmm.cpp) + target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + + add_executable(tile_example_grouped_flatmm grouped_flatmm.cpp) + target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + + if (GPU_TARGETS MATCHES "gfx95") + add_executable(tile_example_mixed_prec_flatmm mixed_prec/mixed_prec_flatmm.cpp) + target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + + add_executable(tile_example_a16w4_moe_flatmm mixed_prec/a16w4_moe_flatmm.cpp) + target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + + include(mxgemm/mx_flatmm_instance.cmake) + mx_flatmm_instance_generate(EXAMPLE_MX_FLATMM_FILES) + message(STATUS "Generated MX FlatMM kernel files: ${EXAMPLE_MX_FLATMM_FILES}") + + add_executable(tile_example_mx_flatmm mxgemm/mx_flatmm.cpp ${EXAMPLE_MX_FLATMM_FILES}) + target_include_directories(tile_example_mx_flatmm PRIVATE mxgemm) + target_compile_options(tile_example_mx_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + endif() +endif() diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index 47211bdbbc..ae1fa22bb0 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -136,38 +136,6 @@ struct GemmBasicTypeConfig using CDataType = ck_tile::half_t; }; -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - template struct is_8bit_type : std::bool_constant || std::is_same_v> diff --git a/example/ck_tile/18_flatmm/moe_flatmm.hpp b/example/ck_tile/18_flatmm/moe_flatmm.hpp index b464aaa73a..47d969fadb 100644 --- a/example/ck_tile/18_flatmm/moe_flatmm.hpp +++ b/example/ck_tile/18_flatmm/moe_flatmm.hpp @@ -134,38 +134,6 @@ struct GemmBasicTypeConfig using CDataType = ck_tile::half_t; }; -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - template struct is_8bit_type : std::bool_constant || std::is_same_v> diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index 8d3fd146bc..0134465347 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -158,7 +158,7 @@ auto create_args(int argc, char* argv[]) .insert("stride_c", "0", "Tensor C stride") .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") .insert( - "mx_prec", "fp4xfp4", "data type for activation and weight, support: fp6xfp6, fp8xfp8") + "mx_prec", "fp4xfp4", "data type for activation and weight, support: fp4xfp4, fp8xfp8") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp index 9c12509d59..f177ef04ca 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp @@ -75,7 +75,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, HasHotLoop, TailNum>; - using MXFlatmmPipeline = ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1; + using MXFlatmmPipeline = ck_tile::MXFlatmmPipelineAGmemBGmemCRegV1; using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner b_shuffle_host = - shuffle_b(b_k_n_tensor); + ck_tile::shuffle_b(b_k_n_tensor); std::unique_ptr a_m_k_dev_buf( std::make_unique(a_m_k_tensor.get_element_space_size_in_bytes())); @@ -431,7 +431,7 @@ int run_masked_grouped_flatmm_example_with_layouts( assert(N % N_Warp_Tile == 0 && "N must be divisible by N_Warp_Tile for contiguous grouped gemm"); ck_tile::HostTensor b_shuffle_host = - shuffle_b(b_k_n_tensor); + ck_tile::shuffle_b(b_k_n_tensor); std::unique_ptr a_m_k_dev_buf( std::make_unique(a_m_k_tensor.get_element_space_size_in_bytes())); diff --git a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc index f5259ea87b..c58ddc2584 100644 --- a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc @@ -302,10 +302,6 @@ int run_moe_gemm_example_with_layouts(int argc, static_cast(per_token_scale_dev_buf.GetDeviceBuffer()), static_cast(per_channel_scale_dev_buf.GetDeviceBuffer())); - const float max_accumulated_value = - *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol( - K, 1 /*kbatch*/, max_accumulated_value); c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data()); const float rtol = std::is_same_v && IsInputGemm ? 1e-3 : 1e-2; diff --git a/example/ck_tile/19_gemm_multi_d/CMakeLists.txt b/example/ck_tile/19_gemm_multi_d/CMakeLists.txt index 4ecfec7ccf..16b167fa39 100644 --- a/example/ck_tile/19_gemm_multi_d/CMakeLists.txt +++ b/example/ck_tile/19_gemm_multi_d/CMakeLists.txt @@ -1,4 +1,7 @@ -add_executable(tile_example_gemm_multi_d_fp16 EXCLUDE_FROM_ALL gemm_multi_d_fp16.cpp) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_executable(tile_example_gemm_multi_d_fp16 gemm_multi_d_fp16.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/example/ck_tile/20_grouped_convolution/CMakeLists.txt b/example/ck_tile/20_grouped_convolution/CMakeLists.txt index ed2a2a0dd6..7fcca37bd9 100644 --- a/example/ck_tile/20_grouped_convolution/CMakeLists.txt +++ b/example/ck_tile/20_grouped_convolution/CMakeLists.txt @@ -1,20 +1,25 @@ -set(EXAMPLE_CONV_COMPILE_OPTIONS) -list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT -add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp) -target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS}) +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a") + set(EXAMPLE_CONV_COMPILE_OPTIONS) + list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) -add_executable(tile_example_grouped_conv_fwd_large_tensor EXCLUDE_FROM_ALL grouped_convolution_forward_large_tensor.cpp) -target_compile_options(tile_example_grouped_conv_fwd_large_tensor PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS}) + add_executable(tile_example_grouped_conv_fwd grouped_convolution_forward.cpp) + target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS}) -add_executable(tile_example_grouped_conv_fwd_bias_clamp EXCLUDE_FROM_ALL grouped_convolution_forward_bias_clamp.cpp) -target_compile_options(tile_example_grouped_conv_fwd_bias_clamp PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_executable(tile_example_grouped_conv_fwd_large_tensor grouped_convolution_forward_large_tensor.cpp) + target_compile_options(tile_example_grouped_conv_fwd_large_tensor PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS}) -add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp) -target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS}) + add_executable(tile_example_grouped_conv_fwd_bias_clamp grouped_convolution_forward_bias_clamp.cpp) + target_compile_options(tile_example_grouped_conv_fwd_bias_clamp PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -add_executable(tile_example_grouped_conv_bwd_weight_two_stage EXCLUDE_FROM_ALL grouped_convolution_backward_weight_two_stage.cpp) -target_compile_options(tile_example_grouped_conv_bwd_weight_two_stage PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS}) + add_executable(tile_example_grouped_conv_bwd_weight grouped_convolution_backward_weight.cpp) + target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS}) -add_executable(tile_example_grouped_conv_bwd_data EXCLUDE_FROM_ALL grouped_convolution_backward_data.cpp) -target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS}) + add_executable(tile_example_grouped_conv_bwd_weight_two_stage grouped_convolution_backward_weight_two_stage.cpp) + target_compile_options(tile_example_grouped_conv_bwd_weight_two_stage PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS}) + + add_executable(tile_example_grouped_conv_bwd_data grouped_convolution_backward_data.cpp) + target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS}) +endif() diff --git a/example/ck_tile/20_grouped_convolution/conv_configs.hpp b/example/ck_tile/20_grouped_convolution/conv_configs.hpp index 238e3810f0..620b505820 100644 --- a/example/ck_tile/20_grouped_convolution/conv_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/conv_configs.hpp @@ -254,27 +254,6 @@ struct ConvTypeConfig using OutDataType = ck_tile::bf16_t; }; -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - template struct PipelineTypeTraits; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp index af9820df2d..d26aaa98e3 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp @@ -50,9 +50,17 @@ int run_grouped_conv_fwd_example(int argc, char* argv[]) int main(int argc, char* argv[]) { + try + { #if CK_TILE_USE_WMMA - return !run_grouped_conv_fwd_example(argc, argv); + return !run_grouped_conv_fwd_example(argc, argv); #else - return !run_grouped_conv_fwd_example(argc, argv); + return !run_grouped_conv_fwd_example(argc, argv); #endif + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; + } } diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index f168d36cac..d154d8710b 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -101,7 +101,6 @@ struct GroupedConvolutionForwardInvoker const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; using TransformType = ck_tile::TransformConvFwdToGemm(const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_, + const auto enable_split_image_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; constexpr auto scheduler = ConvConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; + constexpr bool EnableSplitImage = enable_split_image_.value; using GroupedConvTraitsType = std::conditional_t(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; }; // ===================================================================== @@ -369,28 +368,33 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== if(use_split_image) { - // Use split-image kernel (Kernel) const auto RunSplitImage = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) - Run.template operator()(has_hot_loop_, tail_number_, MemoryOpSet{}); + return Run( + has_hot_loop_, tail_number_, MemoryOpSet{}, ck_tile::bool_constant{}); else - Run.template operator()(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + return Run(has_hot_loop_, + tail_number_, + MemoryOpAtomicAdd{}, + ck_tile::bool_constant{}); }; - BaseGemmPipeline::TailHandler(RunSplitImage, has_hot_loop, tail_num); + return BaseGemmPipeline::TailHandler(RunSplitImage, has_hot_loop, tail_num); } else { - // Use regular kernel (Kernel) const auto RunRegular = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) - Run.template operator()(has_hot_loop_, tail_number_, MemoryOpSet{}); + return Run(has_hot_loop_, + tail_number_, + MemoryOpSet{}, + ck_tile::bool_constant{}); else - Run.template operator()( - has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + return Run(has_hot_loop_, + tail_number_, + MemoryOpAtomicAdd{}, + ck_tile::bool_constant{}); }; - BaseGemmPipeline::TailHandler(RunRegular, has_hot_loop, tail_num); + return BaseGemmPipeline::TailHandler(RunRegular, has_hot_loop, tail_num); } - - return ave_time; } }; diff --git a/example/ck_tile/21_elementwise/CMakeLists.txt b/example/ck_tile/21_elementwise/CMakeLists.txt index dc5242f4a1..6874b5ecb6 100644 --- a/example/ck_tile/21_elementwise/CMakeLists.txt +++ b/example/ck_tile/21_elementwise/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # Elementwise example targets 2D inputs set(TARGET_NAME_2D_INPUT tile_example_elementwise) add_executable(${TARGET_NAME_2D_INPUT} elementwise_example.cpp) diff --git a/example/ck_tile/22_gemm_multi_abd/CMakeLists.txt b/example/ck_tile/22_gemm_multi_abd/CMakeLists.txt index f382e0cf45..54c6507950 100644 --- a/example/ck_tile/22_gemm_multi_abd/CMakeLists.txt +++ b/example/ck_tile/22_gemm_multi_abd/CMakeLists.txt @@ -1 +1,4 @@ -add_executable(tile_example_gemm_multi_abd_fp16 EXCLUDE_FROM_ALL gemm_multi_abd_fp16.cpp) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_executable(tile_example_gemm_multi_abd_fp16 gemm_multi_abd_fp16.cpp) diff --git a/example/ck_tile/35_batched_transpose/CMakeLists.txt b/example/ck_tile/35_batched_transpose/CMakeLists.txt index a08fcebb74..ea751e10ce 100644 --- a/example/ck_tile/35_batched_transpose/CMakeLists.txt +++ b/example/ck_tile/35_batched_transpose/CMakeLists.txt @@ -1,9 +1,13 @@ -set(TARGET_NAME tile_example_batched_transpose) -add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL batched_transpose_example.cpp batched_transpose_api.cpp) -target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT -# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) -# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) -target_compile_options(tile_example_batched_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS}) +if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a") + set(TARGET_NAME tile_example_batched_transpose) + add_executable(${TARGET_NAME} batched_transpose_example.cpp batched_transpose_api.cpp) + target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) + # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations + list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + # list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + target_compile_options(tile_example_batched_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS}) +endif() diff --git a/example/ck_tile/36_pooling/CMakeLists.txt b/example/ck_tile/36_pooling/CMakeLists.txt index 425a8c83ba..08d68a2488 100644 --- a/example/ck_tile/36_pooling/CMakeLists.txt +++ b/example/ck_tile/36_pooling/CMakeLists.txt @@ -1,8 +1,10 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(EXAMPLE_POOL_3D "tile_example_pool3d") message(DEBUG "adding example ${EXAMPLE_POOL_3D}") -add_executable(${EXAMPLE_POOL_3D} EXCLUDE_FROM_ALL pool3d.cpp) +add_executable(${EXAMPLE_POOL_3D} pool3d.cpp) target_include_directories(${EXAMPLE_POOL_3D} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_compile_options(${EXAMPLE_POOL_3D} PRIVATE ${EXAMPLE_POOL_COMPILE_OPTIONS}) - diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 40a4166126..d6b63dc47b 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) @@ -7,7 +10,7 @@ list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") set(EXE_NAME tile_example_gemm_quant) - add_executable(${EXE_NAME} EXCLUDE_FROM_ALL + add_executable(${EXE_NAME} gemm_quant.cpp gemm_aquant_quantgrouped.cpp gemm_aquant_quantgrouped_preshufflequant.cpp diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index b81c5de7ab..3a30c2bad3 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -74,7 +74,7 @@ User need to select correct mapping of config for each quant mode: |:--------|:-----:|:-----:|-------| | For selecting AQuant | aquant | gemm_aquant_quantgrouped.cpp| GemmConfigQuantDecode | | For selecting AQuant with Preshuffle quant | aquant | gemm_aquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode | -| For selecting BQuant | bquant | gemm_bquant_quantgrouped_.cpp| GemmConfigQuantDecode (or) GemmConfigBQuantPrefill | +| For selecting BQuant | bquant | gemm_bquant_quantgrouped_.cpp| GemmConfigQuantDecode (or) GemmConfigQuantPrefill | | For selecting BQuant with Preshuffle quant | bquant | gemm_bquant_quantgrouped_preshufflequant.cpp| GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill | | For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp| GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill | For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp |GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp index 0f75976602..ad1a4e0d10 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -6,6 +6,10 @@ template using GemmConfig = GemmConfigQuantDecode; +// GemmConfigQuantPrefill is also supported for aquant grouped quantization +// template +// using GemmConfig = GemmConfigQuantPrefill; + void aquant_quantgrouped_instance_factory( std::unordered_map>& lut) { diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp index 2dbae9e42c..61fd65960f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigBQuantPrefill; +using GemmConfig = GemmConfigQuantPrefill; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp index 40cf88624b..1d471068eb 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigBQuantPrefill; +using GemmConfig = GemmConfigQuantPrefill; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp index 5c21d5aa16..280029033b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigBQuantPrefill; +using GemmConfig = GemmConfigQuantPrefill; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp index 80b9a2765e..a277c864bb 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigBQuantPrefill; +using GemmConfig = GemmConfigQuantPrefill; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 95b0a73ede..116661c157 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -221,7 +221,7 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill }; template -struct GemmConfigBQuantPrefill : public GemmConfigBase +struct GemmConfigQuantPrefill : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -237,13 +237,13 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase }; template -struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigBQuantPrefill +struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill { static constexpr bool PreshuffleQuant = true; }; template -struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill +struct GemmConfigBQuantPrefill_Wmma : public GemmConfigQuantPrefill { static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; @@ -280,60 +280,3 @@ struct GemmQuantTypeConfig using AccDataType = float; using CDataType = CDataType_; }; - -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 2162141156..396a54c7c2 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -11,15 +11,19 @@ #include #include "ck_tile/core/config.hpp" +#include "ck_tile/ops/common/utils.hpp" #include "ck_tile/host.hpp" #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" +#include "ck_tile/ops/gemm_quant.hpp" #include "gemm_utils.hpp" template ; @@ -65,12 +69,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str using BaseGemmPipeline = std::conditional_t< GemmConfig::PreshuffleB == true, ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - std::conditional_t< - QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::PreshuffleQuant == true, - ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3>>>; + ck_tile::BaseGemmPipelineAgBgCrCompV3>; const ck_tile::index_t K_split = (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; @@ -129,9 +128,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::GemmPipelineAgBgCrCompV3, std::conditional_t< QuantMode == ck_tile::QuantType::AQuantGrouped, - std::conditional_t, - ck_tile::AQuantGemmPipelineAgBgCrMem>, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, std::conditional_t, ck_tile::BQuantGemmPipelineAgBgCrCompV3>>>; @@ -287,7 +284,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, float ave_time = gemm_calc_quant( arg_parser, Row{}, Row{}, Col{}, Col{}, Row{}); } + + if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped && !GemmConfig::PreshuffleQuant) + { + if(a_layout == "R" && b_layout == "R") + { + return run_gemm_example_with_layouts( + arg_parser, Row{}, Row{}, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_example_with_layouts( + arg_parser, Col{}, Row{}, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts( + arg_parser, Col{}, Col{}, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } + else { throw std::runtime_error("Unsupported memory layout for the input matrices!"); diff --git a/example/ck_tile/40_streamk_gemm/CMakeLists.txt b/example/ck_tile/40_streamk_gemm/CMakeLists.txt index 3b285a54b5..56af4ee9d3 100644 --- a/example/ck_tile/40_streamk_gemm/CMakeLists.txt +++ b/example/ck_tile/40_streamk_gemm/CMakeLists.txt @@ -1,5 +1,8 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9") - add_executable(tile_example_streamk_gemm_basic EXCLUDE_FROM_ALL streamk_gemm_basic.cpp) + add_executable(tile_example_streamk_gemm_basic streamk_gemm_basic.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp index 37aeec868a..dad31ec637 100644 --- a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp +++ b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp @@ -54,39 +54,6 @@ struct StreamKGemmTypeConfig using CDataType = CDataType_; }; -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; diff --git a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc index 041acff509..d18ac2e68a 100644 --- a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc +++ b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc @@ -2,6 +2,8 @@ // SPDX-License-Identifier: MIT #pragma once +#include "ck_tile/ops/common/utils.hpp" + template static constexpr inline auto is_row_major(Layout) { @@ -79,12 +81,11 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, K, stride_A, stride_B, - stride_C, - reduction_strategy}; + stride_C}; std::tuple ave_time_and_batch; - if(args.reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic) + if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic) { ave_time_and_batch = gemm gemm(const ck_tile::StreamKHostArgs& args, } auto reset_data_buffers = [&]() { - if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) + if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) { // Clear the output C tensor results after each repetition of the kernel hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); } - else if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) + else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction) { // Reset sk flags to zero before each repetition of the kernel workspace_data.SetZero(); diff --git a/example/ck_tile/41_batched_contraction/CMakeLists.txt b/example/ck_tile/41_batched_contraction/CMakeLists.txt index 10b2e48cbf..43fa1821d8 100644 --- a/example/ck_tile/41_batched_contraction/CMakeLists.txt +++ b/example/ck_tile/41_batched_contraction/CMakeLists.txt @@ -1,4 +1,7 @@ -add_executable(tile_example_batched_contraction EXCLUDE_FROM_ALL batched_contraction.cpp) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_executable(tile_example_batched_contraction batched_contraction.cpp) set(EXAMPLE_CONTRACTION_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_CONTRACTION_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/example/ck_tile/41_batched_contraction/batched_contraction.cpp b/example/ck_tile/41_batched_contraction/batched_contraction.cpp index d83cc70c62..6536894394 100644 --- a/example/ck_tile/41_batched_contraction/batched_contraction.cpp +++ b/example/ck_tile/41_batched_contraction/batched_contraction.cpp @@ -219,9 +219,7 @@ float batched_contraction(const ck_tile::BatchedContractionHostArgs Batch dimensions (default: \"1,2\")\n"; + std::cout << " -m_dims= M (row) dimensions (default: \"4,256\")\n"; + std::cout << " -n_dims= N (column) dimensions (default: \"16,128\")\n"; + std::cout << " -k_dims= K (contract) dims (default: \"64\")\n"; + std::cout << " -num_d= Number of D tensors (default: 2, range: 0-4)\n\n"; + + std::cout << "Custom Stride Arguments (for testing non-contiguous tensors):\n"; + std::cout << " -strides_a= A tensor strides (comma-separated, empty = auto)\n"; + std::cout << " -strides_b= B tensor strides (comma-separated, empty = auto)\n"; + std::cout << " -strides_e= E tensor strides (comma-separated, empty = auto)\n"; + std::cout << " -strides_ds= D tensors strides (semicolon-separated, empty = same as E)\n"; + std::cout << " Example: -strides_a=\"32768,128,1\" -strides_ds=\"512,2,1;1024,4,1\"\n\n"; + + std::cout << "Layout Arguments:\n"; + std::cout + << " -a_layout= A tensor layout (R=Row-major, C=Column-major, default: \"R\")\n"; + std::cout << " -b_layout= B tensor layout (default: \"C\")\n"; + std::cout << " -e_layout= E tensor layout (default: \"R\")\n\n"; + + std::cout << "Examples:\n"; + std::cout << " Single batch (12 batches of 256×128):\n"; + std::cout << " " << program_name + << " -g_dims=\"12\" -m_dims=\"256\" -n_dims=\"128\" -k_dims=\"64\"\n\n"; + + std::cout << " 2D batch grid (2×3=6 batches):\n"; + std::cout << " " << program_name + << " -g_dims=\"2,3\" -m_dims=\"128\" -n_dims=\"128\" -k_dims=\"64\"\n\n"; + + std::cout << " Multi-dimensional (flattened to M=128, N=128, K=128):\n"; + std::cout << " " << program_name + << " -g_dims=\"4\" -m_dims=\"8,16\" -n_dims=\"32,4\" -k_dims=\"16,8\"\n\n"; + + std::cout << "Other Options:\n"; + std::cout << " -v=<0|1> Validation (0=off, 1=on, default: 1)\n"; + std::cout << " -split_k= Split-K value (default: 1)\n"; + std::cout << " -warmup= Warmup iterations (default: 5)\n"; + std::cout << " -repeat= Benchmark iterations (default: 10)\n"; + std::cout << " -log=<0|1> Logging level (default: 1)\n"; + std::cout << " -help Show this help\n\n"; +} + auto create_args(int argc, char* argv[]) { + // Check for --help flag + for(int i = 1; i < argc; ++i) + { + std::string arg = argv[i]; + if(arg == "--help" || arg == "-h" || arg == "-help") + { + print_help(argv[0]); + std::exit(0); + } + } + ck_tile::ArgParser arg_parser; arg_parser.insert("m_dims", "4,256", "M dimensions separated by comma (e.g., '16,32' for 2D M)") .insert("n_dims", "16,128", "N dimensions separated by comma (e.g., '32,32' for 2D N)") .insert("k_dims", "64", "K dimensions separated by comma (e.g., '64,32' for 2D K)") .insert( "g_dims", "1,2", "G dimensions separated by comma (e.g., '4,2' for 2D, '2,3,4' for 3D)") - .insert("stride_a", "0", "Custom A tensor leading dimension stride (0 = auto)") - .insert("stride_b", "0", "Custom B tensor leading dimension stride (0 = auto)") - .insert("stride_e", "0", "Custom E tensor leading dimension stride (0 = auto)") + .insert("num_d", "2", "Number of D (auxiliary input) tensors") + .insert("strides_a", "", "A tensor strides (comma-separated, empty = auto/contiguous)") + .insert("strides_b", "", "B tensor strides (comma-separated, empty = auto/contiguous)") + .insert("strides_e", "", "E tensor strides (comma-separated, empty = auto/contiguous)") + .insert("strides_ds", + "", + "D tensors strides (semicolon-separated for multiple, empty = same as E)") .insert("a_layout", "R", "A tensor data layout - Row by default") .insert("b_layout", "C", "B tensor data layout - Col by default") .insert("e_layout", "R", "E tensor data layout - Row by default") diff --git a/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc b/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc index 9ebacdedd3..214b14633d 100644 --- a/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc +++ b/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc @@ -45,10 +45,10 @@ float invoke_batched_contraction_kernel( const void* b_full_dims_dev_buf, const std::array& ds_dev_buf, void* e_full_dims_dev_buf, - const std::vector& G_dims, - const std::vector& M_dims, - const std::vector& N_dims, - const std::vector& K_dims, + ck_tile::index_t num_g_dims, + ck_tile::index_t num_m_dims, + ck_tile::index_t num_n_dims, + ck_tile::index_t num_k_dims, const std::vector& A_dims, // [G0,G1,..,M0,M1,..,K0,K1,..] const std::vector& B_dims, // [G0,G1,..,N0,N1,..,K0,K1,..] const std::array, DsDataType::size()>& @@ -79,9 +79,8 @@ float invoke_batched_contraction_kernel( E_strides // E_strides ); - std::cout << "Calling batched_contraction with dimensions: G=" << G_dims.size() - << ", M=" << M_dims.size() << ", N=" << N_dims.size() << ", K=" << K_dims.size() - << std::endl; + std::cout << "Calling batched_contraction with dimensions: G=" << num_g_dims + << ", M=" << num_m_dims << ", N=" << num_n_dims << ", K=" << num_k_dims << std::endl; float ave_time = batched_contraction( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, - G_dims.size(), // num_g_dims - M_dims.size(), // num_m_dims - N_dims.size(), // num_n_dims - K_dims.size() // num_k_dims - ); + num_g_dims, + num_m_dims, + num_n_dims, + num_k_dims); return ave_time; } -template +// C++17-compatible helper function to create array of HostTensors +namespace { +template +std::array, NumDTensor> +make_ds_host_tensors_impl(const std::array& descs, + std::index_sequence) +{ + return {ck_tile::HostTensor(descs[Is])...}; +} + +template +std::array, NumDTensor> +make_ds_host_tensors(const std::array& descs) +{ + return make_ds_host_tensors_impl(descs, + std::make_index_sequence{}); +} +} // anonymous namespace + +template int run_batched_contraction_example_with_layouts( int argc, char* argv[], @@ -122,8 +143,6 @@ int run_batched_contraction_example_with_layouts( std::vector N_dims = parse_dimensions(arg_parser.get_str("n_dims")); std::vector K_dims = parse_dimensions(arg_parser.get_str("k_dims")); - constexpr ck_tile::index_t NumDTensor = 2; - ck_tile::index_t G_total = calculate_total_elements(G_dims); ck_tile::index_t M_total = calculate_total_elements(M_dims); ck_tile::index_t N_total = calculate_total_elements(N_dims); @@ -148,13 +167,105 @@ int run_batched_contraction_example_with_layouts( return converted; }; - ck_tile::HostTensorDescriptor a_desc(A_dims); - ck_tile::HostTensorDescriptor b_desc(B_dims); - ck_tile::HostTensorDescriptor e_desc(E_dims); - std::array ds_descs; - for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + // Get custom stride arguments + std::string strides_a_str = arg_parser.get_str("strides_a"); + std::string strides_b_str = arg_parser.get_str("strides_b"); + std::string strides_e_str = arg_parser.get_str("strides_e"); + std::string strides_ds_str = arg_parser.get_str("strides_ds"); + + // Create A descriptor with custom or default strides + ck_tile::HostTensorDescriptor a_desc; + if(!strides_a_str.empty()) { - ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides()); + std::vector custom_a_strides = parse_dimensions(strides_a_str); + if(custom_a_strides.size() != A_dims.size()) + { + throw std::runtime_error("strides_a size must match A_dims size"); + } + std::vector a_strides_size_t(custom_a_strides.begin(), custom_a_strides.end()); + a_desc = ck_tile::HostTensorDescriptor(A_dims, a_strides_size_t); + std::cout << "Using custom strides for A (non-contiguous)" << std::endl; + } + else + { + a_desc = ck_tile::HostTensorDescriptor(A_dims); + } + + // Create B descriptor with custom or default strides + ck_tile::HostTensorDescriptor b_desc; + if(!strides_b_str.empty()) + { + std::vector custom_b_strides = parse_dimensions(strides_b_str); + if(custom_b_strides.size() != B_dims.size()) + { + throw std::runtime_error("strides_b size must match B_dims size"); + } + std::vector b_strides_size_t(custom_b_strides.begin(), custom_b_strides.end()); + b_desc = ck_tile::HostTensorDescriptor(B_dims, b_strides_size_t); + std::cout << "Using custom strides for B (non-contiguous)" << std::endl; + } + else + { + b_desc = ck_tile::HostTensorDescriptor(B_dims); + } + + // Create E descriptor with custom or default strides + ck_tile::HostTensorDescriptor e_desc; + if(!strides_e_str.empty()) + { + std::vector custom_e_strides = parse_dimensions(strides_e_str); + if(custom_e_strides.size() != E_dims.size()) + { + throw std::runtime_error("strides_e size must match E_dims size"); + } + std::vector e_strides_size_t(custom_e_strides.begin(), custom_e_strides.end()); + e_desc = ck_tile::HostTensorDescriptor(E_dims, e_strides_size_t); + std::cout << "Using custom strides for E (non-contiguous)" << std::endl; + } + else + { + e_desc = ck_tile::HostTensorDescriptor(E_dims); + } + // Create D descriptors with custom or default strides (default = same as E) + std::array ds_descs; + if(!strides_ds_str.empty()) + { + // Parse semicolon-separated stride vectors for multiple D tensors + std::vector> all_ds_strides; + std::stringstream ss(strides_ds_str); + std::string d_stride_str; + + while(std::getline(ss, d_stride_str, ';')) + { + all_ds_strides.push_back(parse_dimensions(d_stride_str)); + } + + if(all_ds_strides.size() != NumDTensor) + { + throw std::runtime_error("Number of D stride vectors must match num_d=" + + std::to_string(NumDTensor)); + } + + std::cout << "Using custom strides for D tensors (non-contiguous)" << std::endl; + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + if(all_ds_strides[d].size() != E_dims.size()) + { + throw std::runtime_error("D tensor " + std::to_string(d) + + " stride size must match E_dims size"); + } + std::vector d_strides_size_t(all_ds_strides[d].begin(), + all_ds_strides[d].end()); + ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], d_strides_size_t); + } + } + else + { + // Default: use same strides as E + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides()); + } } std::vector A_strides = convert_strides(a_desc.get_strides()); @@ -201,11 +312,8 @@ int run_batched_contraction_example_with_layouts( ck_tile::HostTensor<::BDataType> b_full_dims_host(b_desc); ck_tile::HostTensor<::EDataType> e_full_dims_host(e_desc); - std::vector> ds_full_dims_host; - for(int d = 0; d < NumDTensor; ++d) - { - ds_full_dims_host.emplace_back(ck_tile::HostTensor<::DDataType>(ds_descs[d])); - } + // Construct array of HostTensors - C++17 compatible + auto ds_full_dims_host = make_ds_host_tensors<::DDataType, NumDTensor>(ds_descs); ck_tile::FillUniformDistribution<::ADataType>{-5.f, 5.f, std::nullopt}(a_full_dims_host); ck_tile::FillUniformDistribution<::BDataType>{-5.f, 5.f, std::nullopt}(b_full_dims_host); @@ -260,10 +368,10 @@ int run_batched_contraction_example_with_layouts( b_full_dims_dev_buf.GetDeviceBuffer(), ds_ptr_buf, e_full_dims_dev_buf.GetDeviceBuffer(), - G_dims, - M_dims, - N_dims, - K_dims, + G_dims.size(), + M_dims.size(), + N_dims.size(), + K_dims.size(), A_dims, B_dims, Ds_dims, @@ -316,20 +424,25 @@ int run_batched_contraction_example_with_layouts( auto start_time = std::chrono::high_resolution_clock::now(); - calculate_reference_flat_indexing(a_full_dims_host, - b_full_dims_host, - ds_full_dims_host, - e_full_dims_host_ref, - G_total, - M_total, - N_total, - K_total, - CDEElementWise{}); + ck_tile::compute_reference_batched_contraction(a_full_dims_host, + b_full_dims_host, + ds_full_dims_host, + e_full_dims_host_ref, + G_total, + M_total, + N_total, + K_total, + CDEElementWise{}, + G_dims, + M_dims, + N_dims, + K_dims); auto end_time = std::chrono::high_resolution_clock::now(); auto duration = @@ -387,15 +500,45 @@ int run_batched_contraction_example(int argc, char* argv[]) if(!result) return -1; + // Get NumDTensor to dispatch at runtime + const int num_d = arg_parser.get_int("num_d"); + using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); + // Runtime dispatch based on num_d value if(a_layout == "R" && b_layout == "C") { - return run_batched_contraction_example_with_layouts(argc, argv, Row{}, Col{}, Row{}, Row{}); + // Dispatch to appropriate template instantiation based on runtime num_d + switch(num_d) + { + case 0: + std::cout << "Running with 0 D tensors" << std::endl; + return run_batched_contraction_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}, Row{}); + case 1: + std::cout << "Running with 1 D tensor" << std::endl; + return run_batched_contraction_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}, Row{}); + case 2: + std::cout << "Running with 2 D tensors" << std::endl; + return run_batched_contraction_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}, Row{}); + case 3: + std::cout << "Running with 3 D tensors" << std::endl; + return run_batched_contraction_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}, Row{}); + case 4: + std::cout << "Running with 4 D tensors" << std::endl; + return run_batched_contraction_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}, Row{}); + default: + throw std::runtime_error("num_d must be between 0 and 4, got: " + + std::to_string(num_d)); + } } else { diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index bf11045a48..9646e93b4e 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + include_directories(AFTER ${CMAKE_CURRENT_LIST_DIR} ) diff --git a/experimental/builder/CMakeLists.txt b/experimental/builder/CMakeLists.txt index 103acbad55..95b41da40b 100644 --- a/experimental/builder/CMakeLists.txt +++ b/experimental/builder/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(BUILD_TESTING) add_subdirectory(test) endif() diff --git a/experimental/builder/README.md b/experimental/builder/README.md index 141a34b9f9..aa7c7d969d 100644 --- a/experimental/builder/README.md +++ b/experimental/builder/README.md @@ -2,7 +2,7 @@ This directory contains the experimental builder feature for composable_kernel. -* Status: In development (October - November 2025) +* Status: In development (October - December 2025) ## Overview @@ -14,6 +14,10 @@ This project is a prototype for a more general builder pattern for all of compos - `include/ck_tile/builder/` Core builder headers and public API. +- `include/ck_tile/builder/reflect` + Reflection mechanism. +- `include/ck_tile/builder/factory` + Compile-time dispatch from builder descriptors to our exisitng specialized convolution kernel implementations. - `test/` Unit tests and example usage of the builder pattern. - `CMakeLists.txt` @@ -28,34 +32,49 @@ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_BUILD_TYPE=Release \ - -D GPU_TARGETS="gfx942;gfx950" \ + -D GPU_TARGETS="gfx942" \ -D CK_EXPERIMENTAL_BUILDER=ON \ -D CMAKE_CXX_STANDARD=20 \ -G Ninja \ .. ``` -## Building and testing +## Building and Testing -During development, all CK Builder tests can be built with command +The builder test suite is organized into two main categories: + +### Smoke Tests (Fast Unit Tests) +Quick unit tests that verify the builder's internal logic without compiling GPU kernels. These complete in under 1 second total and are suitable for frequent execution during development. ```sh -ninja test_ckb_all +ninja smoke-builder ``` -To execute all tests, run +### Regression Tests (Integration Tests) +Integration tests that compile actual GPU kernels to verify that the builder generates valid, compilable code. These are more expensive than smoke tests (can take minutes to compile) but cover more fuctionality. +) ```sh -ls bin/test_ckb_* | xargs -n1 sh -c +ninja regression-builder ``` -Some tests involve building old CK convolution factories, which will take a long time. -Hence, one might want to build only single test targets. For example +### Running All Tests +To build and run the complete test suite: + +```sh +ninja check-builder +``` + +### Building Individual Tests +To build and run a specific test: ```sh ninja test_ckb_conv_builder && bin/test_ckb_conv_builder ``` -When adding new tests, please follow the convention where the CMake build target starts with a prefix `test_ckb`. -This allows us to filter out the CK Builder tests from the set full CK repository tests. -Also, the `test_ckb_all` target that builds all CK Builder tests relies on having the `test_ckb` prefix on the CMake build targets. +### Test Organization +- **Smoke tests**: Fast feedback during active development +- **Regression tests**: Thorough validation before submitting changes +- **Factory tests**: Expensive tests that build all MIOpen kernels (included in regression tests) + +When adding new tests, please follow the convention where the CMake build target starts with a prefix `test_ckb`. This allows filtering of CK Builder tests from the full CK repository test suite. diff --git a/experimental/builder/include/ck_tile/builder/CMakeLists.txt b/experimental/builder/include/ck_tile/builder/CMakeLists.txt index 45723c3680..e39003dfcc 100644 --- a/experimental/builder/include/ck_tile/builder/CMakeLists.txt +++ b/experimental/builder/include/ck_tile/builder/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + #Empty placeholder until we add library code. diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 9b82909bbb..ecb1ff933e 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -84,7 +84,7 @@ concept LdsTransferDescriptor = requires(T t) { // LDS). template concept EpilogueDescriptor = requires(T t) { - { t.m_per_wave_per_shuffle } -> std::convertible_to; + { t.m_xdl_per_wave_per_shuffle } -> std::convertible_to; { t.n_per_wave_per_shuffle } -> std::convertible_to; { t.scalar_per_vector } -> std::convertible_to; }; @@ -256,41 +256,4 @@ concept SpecifiesDlEpilogue = requires { { T::transfer.c.epilogue } -> DlEpilogueDescriptor; }; -/******************************************** */ -/* Concepts for the different device ops */ -/******************************************** */ - -template -concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && - SpecifiesGemmSpecialization && SpecifiesBlockGemm; - -template -concept DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && - SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; - -template -concept DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && - SpecifiesSourceAccessOrder && SpecifiesFwdConcSpecialization && - SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; - -template -concept DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK = - ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConcSpecialization && - SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && - SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; - -template -concept DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor = - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle && - SpecifiesLargeTensorSupport; - } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp index 9f20e9d37f..093916dac3 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_limits.hpp @@ -18,8 +18,8 @@ concept InputVectorTransferLimits = requires { // Limits for output vector transfer. template concept OutputVectorTransferLimits = requires { - requires Value.scalar_per_vector > 0 && Value.m_per_wave_per_shuffle > 0 && - Value.n_per_wave_per_shuffle > 0; + requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 && + Value.n_xdl_per_wave_per_shuffle > 0; }; // Limits for access order. Must be a permutation of {0, 1, 2}. diff --git a/experimental/builder/include/ck_tile/builder/conv_builder.hpp b/experimental/builder/include/ck_tile/builder/conv_builder.hpp index bf63bc83f6..85efc28eb1 100644 --- a/experimental/builder/include/ck_tile/builder/conv_builder.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_builder.hpp @@ -6,7 +6,7 @@ #include #include -#include "ck_tile/builder/conv_factory.hpp" +#include "ck_tile/builder/factory/conv_dispatcher.hpp" #include "ck_tile/builder/versions.hpp" namespace ck_tile::builder { @@ -15,7 +15,7 @@ namespace ck_tile::builder { * @brief Top-level builder for creating convolution kernel instances. * * This struct serves as the main entry point for generating a convolution kernel. - * It uses a factory pattern based on the provided signature, algorithm, and version + * It uses a dispatcher function based on the provided signature, algorithm, and version * to construct the appropriate kernel instance. * * @tparam SIGNATURE The convolution signature, which describes the mathematical functionality of @@ -30,9 +30,8 @@ template ; - // Output: The kernel class. - using Instance = Factory::Instance; + // Output: The kernel class instance created via the dispatcher. + using Instance = decltype(factory::make_conv_instance()); }; } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp deleted file mode 100644 index 6f8e50db15..0000000000 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ /dev/null @@ -1,1050 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -// A factory for instantiating CK convolution kernels. -// -// This file translates a semantic description of a convolution operation -// (`ConvSignatureDescriptor` and `ConvAlgorithmDescriptor`) into specific, -// low-level template arguments required by the underlying CK device-level -// kernel implementations. This abstraction enables more complex build -// time logic and simplifies the kernel specification. -// -// Key Components: -// -// Template Metaprogram: -// - ConvFactory: The main factory, with specializations for different -// convolution directions (currently only forward). -// -// Template Metaprogram Helpers: -// - ConvTensorLayouts: Maps layout enums to CK layout types for different -// spatial dimensions (2D/3D) and directions. -// - ConvTensorTypes: Maps data type enums (FP16, BF16, FP32) to C++ types used by CK. -// - ConvPassThroughOps: Hard-coded pass-through element-wise operations. -// - ConvSpec: Encapsulates convolution and GEMM specialization enums. -// -// `constexpr` Helper Functions: -// - SetThreadBlockInfo: Determines thread block dimensions and tile sizes. -// - SetConvTuningInfo: Sets XDL and AK1/BK1 tuning parameters. -// - SetFwdConvBlockTransfer: Configures A/B tensor block transfer parameters. -// - SetCBlockTransfer: Configures C tensor block transfer parameters. -// - SetBlockGemmPipelineVersion: Maps pipeline version enum to CK types. -// -// The primary entry point is the `ConvFactory` struct, which is currently -// specialized for forward convolutions and produces instances of -// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3. - -#pragma once - -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" -// WORKAROUND: Macro namespace collision in upstream CK device operation headers. -// device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (line 41) and -// device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp (line 51) both define -// GridwiseGemmTemplateParameters macro without #undef, causing redefinition errors. -// Use pragma push/pop to isolate the Large_Tensor header's macro scope. -#pragma push_macro("GridwiseGemmTemplateParameters") -#ifdef GridwiseGemmTemplateParameters -#undef GridwiseGemmTemplateParameters -#endif -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" -#pragma pop_macro("GridwiseGemmTemplateParameters") -#include "ck_tile/builder/conv_signature_concepts.hpp" -#include "ck_tile/builder/conv_algorithm_concepts.hpp" -#include "ck_tile/builder/conv_algorithm_limits.hpp" -#include "ck_tile/builder/builder_utils.hpp" -#include "ck_tile/builder/types.hpp" -#include "ck_tile/builder/versions.hpp" - -#include "ck_tile/builder/conv_signature_utils.hpp" - -namespace ck_tile::builder::factory_internal { - -// Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types. -template - requires(ConvSpatialDim && ValidConvLayoutForSpatialDim) -struct ConvTensorLayouts -{ - // This will trigger if a specialization for the given layout is not found. - // We should always catch this in an earlier validation check. - using Layout = decltype(LayoutValue); - static_assert(sizeof(Layout) == 0, - "Internal error. Unsupported layout for convolution factory."); -}; - -// 1D Forward Convolution Layout Specializations -template <> -struct ConvTensorLayouts -{ - using ALayout = ck::tensor_layout::convolution::NWGC; - using BLayout = ck::tensor_layout::convolution::GKXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NWGK; -}; - -template <> -struct ConvTensorLayouts -{ - using ALayout = ck::tensor_layout::convolution::NGCW; - using BLayout = ck::tensor_layout::convolution::GKXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NGKW; -}; - -template <> -struct ConvTensorLayouts -{ - using ALayout = ck::tensor_layout::convolution::GNWC; - using BLayout = ck::tensor_layout::convolution::GKXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::GNWK; -}; - -template <> -struct ConvTensorLayouts -{ - using ALayout = ck::tensor_layout::convolution::NGCW; - using BLayout = ck::tensor_layout::convolution::GKCX; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NGKW; -}; - -template <> -struct ConvTensorLayouts -{ - using ALayout = ck::tensor_layout::convolution::NGCHW; - using BLayout = ck::tensor_layout::convolution::GKYXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NGKHW; -}; - -template <> -struct ConvTensorLayouts -{ - using ALayout = ck::tensor_layout::convolution::NHWGC; - using BLayout = ck::tensor_layout::convolution::GKYXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NHWGK; -}; - -template <> -struct ConvTensorLayouts -{ - using ALayout = ck::tensor_layout::convolution::GNHWC; - using BLayout = ck::tensor_layout::convolution::GKYXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::GNHWK; -}; - -template <> -struct ConvTensorLayouts -{ - using ALayout = ck::tensor_layout::convolution::NGCHW; - using BLayout = ck::tensor_layout::convolution::GKCYX; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NGKHW; -}; - -template <> -struct ConvTensorLayouts -{ - using ALayout = ck::tensor_layout::convolution::NGCDHW; - using BLayout = ck::tensor_layout::convolution::GKCZYX; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NGKDHW; -}; - -template <> -struct ConvTensorLayouts -{ - using ALayout = ck::tensor_layout::convolution::NDHWGC; - using BLayout = ck::tensor_layout::convolution::GKZYXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::NDHWGK; -}; - -template <> -struct ConvTensorLayouts -{ - using ALayout = ck::tensor_layout::convolution::GNDHWC; - using BLayout = ck::tensor_layout::convolution::GKZYXC; - using DsLayout = ck::Tuple<>; - using ELayout = ck::tensor_layout::convolution::GNDHWK; -}; - -template -consteval auto GetTensorLayout() -{ - - if constexpr(SPATIAL_DIM == 1) - { - return factory_internal::ConvTensorLayouts{}; - } - else if constexpr(SPATIAL_DIM == 2) - { - return factory_internal::ConvTensorLayouts{}; - } - else if constexpr(SPATIAL_DIM == 3) - { - return factory_internal::ConvTensorLayouts{}; - } - else - { - static_assert(false, "Unsupported spatial dimension for convolution layout."); - } -} - -// Type mappings from builder convolution data type to CK tensor types. -template -struct ConvTensorTypes -{ - // This will trigger if a specialization for the given DataType is not found. - // We should always catch this in an earlier validation check. - static_assert(sizeof(UnsupportedEnumValue) == 0, - "Internal error. Unsupported data type for convolution factory."); -}; - -template <> -struct ConvTensorTypes -{ - using ADataType = ck::half_t; - using AComputeType = ck::half_t; - using BDataType = ck::half_t; - using BComputeType = ck::half_t; - using CShuffleDataType = ck::half_t; - using DsDataTypes = ck::Tuple<>; - using AccDataType = float; - using EDataType = ck::half_t; -}; - -template <> -struct ConvTensorTypes -{ - using ADataType = ck::bhalf_t; - using AComputeType = ck::bhalf_t; - using BDataType = ck::bhalf_t; - using BComputeType = ck::bhalf_t; - using CShuffleDataType = ck::bhalf_t; - using DsDataTypes = ck::Tuple<>; - using AccDataType = float; - using EDataType = ck::bhalf_t; -}; - -template <> -struct ConvTensorTypes -{ - using ADataType = float; - using AComputeType = float; - using BDataType = float; - using BComputeType = float; - using CShuffleDataType = float; - using DsDataTypes = ck::Tuple<>; - using AccDataType = float; - using EDataType = float; -}; - -template <> -struct ConvTensorTypes -{ - using ADataType = int8_t; - using AComputeType = int8_t; - using BDataType = int8_t; - using BComputeType = int8_t; - using CShuffleDataType = int8_t; - using DsDataTypes = ck::Tuple<>; - using AccDataType = int32_t; - using EDataType = int8_t; -}; - -template <> -struct ConvTensorTypes -{ - using ADataType = ck::f8_t; - using AComputeType = ck::f8_t; - using BDataType = ck::f8_t; - using BComputeType = ck::f8_t; - using CShuffleDataType = ck::f8_t; - using DsDataTypes = ck::Tuple<>; - using AccDataType = float; - using EDataType = ck::f8_t; -}; - -template -struct ElementwiseOps -{ - // This will trigger if a specialization for the given DataType is not found. - // We should always catch this in an earlier validation check. - static_assert(sizeof(UnsupportedEnumValue) == 0, - "Internal error. Unsupported elementwise operation for convolution factory."); -}; - -template <> -struct ElementwiseOps -{ - using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough; - using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough; - using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough; -}; - -template <> -struct ElementwiseOps -{ - using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough; - using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough; - using CDEElementwiseOp = ck::tensor_operation::element_wise::Scale; -}; - -// The algorithm specializations for the convolution and GEMM. -template - requires( - std::is_same_v) -struct ConvSpec -{ - CONV_ENUM conv_spec; - ck::tensor_operation::device::GemmSpecialization gemm_spec; -}; - -// Deduction guide for ConvSpec to simplify brace initialization. -template -ConvSpec(CONV_ENUM, GEMM_ENUM) -> ConvSpec; - -struct BlockGemmSpec -{ - ck::BlockGemmPipelineVersion pipeline_version; - ck::BlockGemmPipelineScheduler scheduler; -}; - -template -consteval BlockGemmSpec SetBlockGemm() -{ - constexpr auto& BG = ALGORITHM.block_gemm; - - ck::BlockGemmPipelineScheduler scheduler; - ck::BlockGemmPipelineVersion version; - - switch(BG.scheduler) - { - case PipelineScheduler::INTRAWAVE: scheduler = ck::BlockGemmPipelineScheduler::Intrawave; break; - case PipelineScheduler::INTERWAVE: scheduler = ck::BlockGemmPipelineScheduler::Interwave; break; - case PipelineScheduler::DEFAULT: throw "Block GEMM scheduler must be Intrawave or Interwave."; - default: throw "Unknown PipelineScheduler"; - } - - switch(BG.pipeline_version) - { - case PipelineVersion::V1: version = ck::BlockGemmPipelineVersion::v1; break; - case PipelineVersion::V2: version = ck::BlockGemmPipelineVersion::v2; break; - case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break; - case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break; - case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break; - case PipelineVersion::WEIGHT_ONLY: - throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM."; - default: throw "Unknown PipelineVersion"; - } - - return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler}; -} - -// Block info for a convolution. -struct MNK -{ - size_t m{}; - size_t n{}; - size_t k{}; -}; -struct ConvBlock -{ - size_t block_size = 0; - MNK per_block = {}; -}; - -template -constexpr ConvBlock SetThreadBlockInfo() -{ - constexpr auto& TB = ALGORITHM.thread_block; - return ConvBlock{.block_size = TB.block_size, - .per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k}}; -} - -// Block transfer parameters for A or B tensor. -struct BlockTransfer -{ - ck::Array thread_cluster_dims = {0, 0, 0}; // k0, m, k1 - ck::Array thread_cluster_order = {0, 0, 0}; - ck::Array src_access_order = {0, 0, 0}; - size_t src_vector_dim = 0; - size_t src_scalar_per_vector = 0; - size_t lds_dst_scalar_per_vector = 0; - bool is_direct_load = false; - bool lds_padding = false; -}; - -template -constexpr BlockTransfer SetFwdConvBlockTransfer() -{ - constexpr auto& TCL = TRANSFER.block_transfer; - constexpr auto& TCO = TRANSFER.block_transfer_access_order; - constexpr auto& SAO = TRANSFER.src_access_order; - constexpr auto& LDS = TRANSFER.lds_transfer; - - BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1}, - .thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]}, - .src_access_order = {SAO.order[0], SAO.order[1], SAO.order[2]}, - .src_vector_dim = LDS.src_vector_dim, - .src_scalar_per_vector = LDS.src_scalar_per_vector, - .lds_dst_scalar_per_vector = LDS.lds_dst_scalar_per_vector, - .is_direct_load = LDS.is_direct_load, - .lds_padding = LDS.lds_padding}; - return block_transfer; -} - -// Block transfer parameters for C tensor. -struct CBlockTransfer -{ - size_t m_per_wave_per_shuffle = 0; - size_t n_per_wave_per_shuffle = 0; - ck::Array thread_cluster_dims = {0, 0, 0, 0}; - size_t scalar_per_vector = 0; -}; - -template -constexpr CBlockTransfer SetCBlockTransfer() -{ - constexpr auto& TCL = ALGORITHM.transfer.c.thread_cluster_dims; - constexpr auto& EPC = ALGORITHM.transfer.c.epilogue; - CBlockTransfer block_transfer{.m_per_wave_per_shuffle = EPC.m_per_wave_per_shuffle, - .n_per_wave_per_shuffle = EPC.n_per_wave_per_shuffle, - .thread_cluster_dims = - { - TCL.m_block, - TCL.m_wave_per_xdl, - TCL.n_block, - TCL.n_wave_per_xdl, - }, - .scalar_per_vector = EPC.scalar_per_vector}; - return block_transfer; -} - -template -consteval ck::LoopScheduler SetLoopScheduler() -{ - constexpr auto loop_scheduler = ALGORITHM.loop_scheduler; - using ck_loop_sched = ck::LoopScheduler; - switch(loop_scheduler) - { - case PipelineScheduler::DEFAULT: return ck_loop_sched::Default; - case PipelineScheduler::INTERWAVE: return ck_loop_sched::Interwave; - case PipelineScheduler::INTRAWAVE: throw "LoopScheduler must be either DEFAULT or INTERWAVE."; - default: throw "Unknown PipelineScheduler"; - } -} - -template -consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() -{ - constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version; - using ck_pipeline = ck::PipelineVersion; - switch(pipeline_version) - { - case PipelineVersion::V1: return ck_pipeline::v1; - case PipelineVersion::V2: return ck_pipeline::v2; - case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K."; - case PipelineVersion::V4: return ck_pipeline::v4; - case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM."; - case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only; - default: throw "Unknown GridwiseGemmPipelineVersion"; - } -} - -template -consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization() -{ - constexpr auto gemm_spec = ALGORITHM.gemm_specialization; - using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization; - - switch(gemm_spec) - { - case GemmSpecialization::Default: return ck_gemm_spec::Default; - case GemmSpecialization::MPadding: return ck_gemm_spec::MPadding; - case GemmSpecialization::NPadding: return ck_gemm_spec::NPadding; - case GemmSpecialization::KPadding: return ck_gemm_spec::KPadding; - case GemmSpecialization::MNPadding: return ck_gemm_spec::MNPadding; - case GemmSpecialization::MKPadding: return ck_gemm_spec::MKPadding; - case GemmSpecialization::NKPadding: return ck_gemm_spec::NKPadding; - case GemmSpecialization::MNKPadding: return ck_gemm_spec::MNKPadding; - case GemmSpecialization::OPadding: return ck_gemm_spec::OPadding; - case GemmSpecialization::MOPadding: return ck_gemm_spec::MOPadding; - case GemmSpecialization::NOPadding: return ck_gemm_spec::NOPadding; - case GemmSpecialization::KOPadding: return ck_gemm_spec::KOPadding; - case GemmSpecialization::MNOPadding: return ck_gemm_spec::MNOPadding; - case GemmSpecialization::MKOPadding: return ck_gemm_spec::MKOPadding; - case GemmSpecialization::NKOPadding: return ck_gemm_spec::NKOPadding; - case GemmSpecialization::MNKOPadding: return ck_gemm_spec::MNKOPadding; - default: throw "Unknown GemmSpecialization"; - } -} - -template -consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() -{ - constexpr auto version = ALGORITHM.pipeline_version; - using ck_pipeline = ck::BlockGemmPipelineVersion; - switch(version) - { - case PipelineVersion::V1: return ck_pipeline::v1; - case PipelineVersion::V2: return ck_pipeline::v2; - case PipelineVersion::V3: return ck_pipeline::v3; - case PipelineVersion::V4: return ck_pipeline::v4; - case PipelineVersion::V5: return ck_pipeline::v5; - default: throw "Unknown block GEMM PipelineVersion"; - } -} - -template -consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization() -{ - constexpr auto specialization = ALGORITHM.fwd_specialization; - using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization; - switch(specialization) - { - case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default; - case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; - case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; - case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3; - default: throw "Unknown ConvFwdSpecialization"; - } -} - -} // namespace ck_tile::builder::factory_internal - -namespace ck_tile::builder { - -// Primary template for the convolution factory. -template -struct ConvFactory -{ - // This will trigger if a specialization for the given convolution direction is not found. - // We should always catch this in an earlier validation check. - static_assert(false, "Unsupported device operation."); -}; - -// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance -// of a grouped forward convolution kernel. -template - requires ConvDirectionIsForward && - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 -struct ConvFactory -{ - static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - using Layouts = decltype(factory_internal::GetTensorLayout()); - using Types = factory_internal::ConvTensorTypes; - using Ops = factory_internal::ElementwiseOps()>; - using AlgorithmType = decltype(ALGORITHM); - - static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load == - ALGORITHM.transfer.b.lds_transfer.is_direct_load, - "A and B block transfers must both be direct load or not."); - - static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer.is_direct_load; - static constexpr auto FWD_CONV_SPECIALIZATION = - factory_internal::SetFwdConvSpecialization(); - static constexpr auto GEMM_SPECIALIZATION = - factory_internal::SetGemmSpecialization(); - static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, - .gemm_spec = GEMM_SPECIALIZATION}; - - static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto A_BLOCK_TRANSFER = - factory_internal::SetFwdConvBlockTransfer(); - static constexpr auto B_BLOCK_TRANSFER = - factory_internal::SetFwdConvBlockTransfer(); - static constexpr auto C_BLOCK_TRANSFER = - factory_internal::SetCBlockTransfer(); - static constexpr auto BLOCK_GEMM = factory_internal::SetBlockGemm(); - - // Check limits for the algorithm parameters. - // TODO: Add more limits checks as needed. - static_assert(InputVectorTransferLimits); - static_assert(InputVectorTransferLimits); - static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - - // The forward convolution kernel class instance. - using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< - SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, - typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Types::ADataType, - typename Types::BDataType, - typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, - SPECIALIZATION.conv_spec, - SPECIALIZATION.gemm_spec, - BLOCK.block_size, - BLOCK.per_block.m, - BLOCK.per_block.n, - BLOCK.per_block.k, - GRIDWISE_GEMM.ak1, - GRIDWISE_GEMM.bk1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, - to_sequence_v, - to_sequence_v, - to_sequence_v, - A_BLOCK_TRANSFER.src_vector_dim, - A_BLOCK_TRANSFER.src_scalar_per_vector, - A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, - static_cast(A_BLOCK_TRANSFER.lds_padding), - to_sequence_v, - to_sequence_v, - to_sequence_v, - B_BLOCK_TRANSFER.src_vector_dim, - B_BLOCK_TRANSFER.src_scalar_per_vector, - B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, - static_cast(B_BLOCK_TRANSFER.lds_padding), - C_BLOCK_TRANSFER.m_per_wave_per_shuffle, - C_BLOCK_TRANSFER.n_per_wave_per_shuffle, - to_sequence_v, - C_BLOCK_TRANSFER.scalar_per_vector, - BLOCK_GEMM.scheduler, - BLOCK_GEMM.pipeline_version, - typename Types::AComputeType, - typename Types::BComputeType, - IS_DIRECT_LOAD>; -}; - -// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance -// of a grouped forward convolution kernel. -template - requires ConvDirectionIsForward && - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle> -struct ConvFactory -{ - static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - using Layouts = decltype(factory_internal::GetTensorLayout()); - using Types = factory_internal::ConvTensorTypes; - using Ops = factory_internal::ElementwiseOps()>; - using AlgorithmType = decltype(ALGORITHM); - - static constexpr auto FWD_CONV_SPECIALIZATION = - factory_internal::SetFwdConvSpecialization(); - static constexpr auto GEMM_SPECIALIZATION = - factory_internal::SetGemmSpecialization(); - static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, - .gemm_spec = GEMM_SPECIALIZATION}; - - static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler(); - static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto A_BLOCK_TRANSFER = - factory_internal::SetFwdConvBlockTransfer(); - static constexpr auto B_BLOCK_TRANSFER = - factory_internal::SetFwdConvBlockTransfer(); - static constexpr auto C_BLOCK_TRANSFER = - factory_internal::SetCBlockTransfer(); - - // Check limits for the algorithm parameters. - // TODO: Add more limits checks as needed. - static_assert(InputVectorTransferLimits); - static_assert(InputVectorTransferLimits); - static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - - // The forward convolution kernel class instance. - using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< - SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, - typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Types::ADataType, - typename Types::BDataType, - typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, - SPECIALIZATION.conv_spec, - SPECIALIZATION.gemm_spec, - ALGORITHM.num_gemm_k_prefetch_stages, - BLOCK.block_size, - BLOCK.per_block.m, - BLOCK.per_block.n, - BLOCK.per_block.k, - GRIDWISE_GEMM.ak1, - GRIDWISE_GEMM.bk1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, - to_sequence_v, - to_sequence_v, - to_sequence_v, - A_BLOCK_TRANSFER.src_vector_dim, - A_BLOCK_TRANSFER.src_scalar_per_vector, - A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, - static_cast(A_BLOCK_TRANSFER.lds_padding), - to_sequence_v, - to_sequence_v, - to_sequence_v, - B_BLOCK_TRANSFER.src_vector_dim, - B_BLOCK_TRANSFER.src_scalar_per_vector, - B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, - static_cast(B_BLOCK_TRANSFER.lds_padding), - C_BLOCK_TRANSFER.m_per_wave_per_shuffle, - C_BLOCK_TRANSFER.n_per_wave_per_shuffle, - to_sequence_v, - C_BLOCK_TRANSFER.scalar_per_vector, - typename Types::AComputeType, - typename Types::BComputeType, - LOOP_SCHEDULER, - ALGORITHM.num_groups_to_merge>; -}; - -// Factory specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle instance -// of a grouped forward convolution kernel. -template - requires ConvDirectionIsForward && - DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle> -struct ConvFactory -{ - static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - using Layouts = decltype(factory_internal::GetTensorLayout()); - using Types = factory_internal::ConvTensorTypes; - using Ops = factory_internal::ElementwiseOps()>; - using AlgorithmType = decltype(ALGORITHM); - - static constexpr auto FWD_CONV_SPECIALIZATION = - factory_internal::SetFwdConvSpecialization(); - static constexpr auto GEMM_SPECIALIZATION = - factory_internal::SetGemmSpecialization(); - static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, - .gemm_spec = GEMM_SPECIALIZATION}; - - static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler(); - static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = - factory_internal::SetGridwiseGemmPipelineVersion(); - static constexpr auto A_BLOCK_TRANSFER = - factory_internal::SetFwdConvBlockTransfer(); - static constexpr auto B_BLOCK_TRANSFER = - factory_internal::SetFwdConvBlockTransfer(); - static constexpr auto C_BLOCK_TRANSFER = - factory_internal::SetCBlockTransfer(); - - // Check limits for the algorithm parameters. - // TODO: Add more limits checks as needed. - static_assert(InputVectorTransferLimits); - static_assert(InputVectorTransferLimits); - static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - - // The forward convolution kernel class instance. - using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< - SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, - typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Types::ADataType, - typename Types::BDataType, - typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, - SPECIALIZATION.conv_spec, - SPECIALIZATION.gemm_spec, - ALGORITHM.num_gemm_k_prefetch_stages, - BLOCK.block_size, - BLOCK.per_block.m, - BLOCK.per_block.n, - BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, - to_sequence_v, - to_sequence_v, - to_sequence_v, - A_BLOCK_TRANSFER.src_vector_dim, - A_BLOCK_TRANSFER.src_scalar_per_vector, - A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, - static_cast(A_BLOCK_TRANSFER.lds_padding), - to_sequence_v, - to_sequence_v, - to_sequence_v, - B_BLOCK_TRANSFER.src_vector_dim, - B_BLOCK_TRANSFER.src_scalar_per_vector, - B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, - static_cast(B_BLOCK_TRANSFER.lds_padding), - C_BLOCK_TRANSFER.m_per_wave_per_shuffle, - C_BLOCK_TRANSFER.n_per_wave_per_shuffle, - to_sequence_v, - C_BLOCK_TRANSFER.scalar_per_vector, - LOOP_SCHEDULER, - GRIDWISE_GEMM_PIPELINE_VERSION>; -}; - -// Factory specialization for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance -// of a grouped forward convolution kernel using Direct Load (DL) approach. -template - requires ConvDirectionIsForward && DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< - std::remove_const_t> -struct ConvFactory -{ - static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - using Layouts = decltype(factory_internal::GetTensorLayout()); - using Types = factory_internal::ConvTensorTypes; - using Ops = factory_internal::ElementwiseOps()>; - using AlgorithmType = decltype(ALGORITHM); - - static constexpr auto FWD_CONV_SPECIALIZATION = - factory_internal::SetFwdConvSpecialization(); - static constexpr auto GEMM_SPECIALIZATION = - factory_internal::SetGemmSpecialization(); - - static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); - - // DL-specific parameters from algorithm descriptor - static constexpr auto DL_THREAD_CFG = ALGORITHM.thread_config; - static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block; - static constexpr ck::index_t K1 = DL_THREAD_CFG.k1; - static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread; - static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread; - static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread; - - // Thread cluster from descriptor - static constexpr auto DL_CLUSTER = ALGORITHM.thread_cluster; - using M1N1ThreadClusterM1Xs = to_sequence_v; - using M1N1ThreadClusterN1Xs = to_sequence_v; - - // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format - static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a.block_transfer; - using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = - to_sequence_v; - using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = - to_sequence_v; - using ABlockTransferThreadClusterArrangeOrder = - to_sequence_v; - using ABlockTransferSrcAccessOrder = to_sequence_v; - using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = - to_sequence_v; - using ABlockTransferSrcVectorTensorContiguousDimOrder = - to_sequence_v; - using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = - to_sequence_v; - - // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format - static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b.block_transfer; - using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = - to_sequence_v; - using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = - to_sequence_v; - using BBlockTransferThreadClusterArrangeOrder = - to_sequence_v; - using BBlockTransferSrcAccessOrder = to_sequence_v; - using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = - to_sequence_v; - using BBlockTransferSrcVectorTensorContiguousDimOrder = - to_sequence_v; - using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = - to_sequence_v; - - // C Thread Transfer from descriptor - static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c.epilogue; - using CThreadTransferSrcDstAccessOrder = to_sequence_v; - static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; - static constexpr ck::index_t CThreadTransferDstScalarPerVector = - DL_C_TRANSFER.dst_scalar_per_vector; - - // The DL forward convolution kernel class instance - using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< - SPATIAL_DIM, - typename Types::ADataType, - typename Types::BDataType, - typename Types::DsDataTypes, - typename Types::EDataType, - typename Types::AccDataType, - typename Layouts::ALayout, - typename Layouts::BLayout, - typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, - FWD_CONV_SPECIALIZATION, - GEMM_SPECIALIZATION, - BLOCK.block_size, - BLOCK.per_block.m, - BLOCK.per_block.n, - K0PerBlock, - K1, - M1PerThread, - N1PerThread, - KPerThread, - M1N1ThreadClusterM1Xs, - M1N1ThreadClusterN1Xs, - ABlockTransferThreadSliceLengths_K0_M0_M1_K1, - ABlockTransferThreadClusterLengths_K0_M0_M1_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, - ABlockTransferSrcVectorTensorContiguousDimOrder, - ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, - BBlockTransferThreadSliceLengths_K0_N0_N1_K1, - BBlockTransferThreadClusterLengths_K0_N0_N1_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, - BBlockTransferSrcVectorTensorContiguousDimOrder, - BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector>; -}; - -// Factory specialization for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance -// of a grouped forward convolution kernel with large tensor support (N-splitting). -template - requires ConvDirectionIsForward && - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< - std::remove_const_t> -struct ConvFactory -{ - static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; - using Layouts = decltype(factory_internal::GetTensorLayout()); - using Types = factory_internal::ConvTensorTypes; - using Ops = factory_internal::ElementwiseOps()>; - using AlgorithmType = decltype(ALGORITHM); - - static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm; - - static constexpr auto FWD_CONV_SPECIALIZATION = - factory_internal::SetFwdConvSpecialization(); - static constexpr auto GEMM_SPECIALIZATION = - factory_internal::SetGemmSpecialization(); - static constexpr factory_internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, - .gemm_spec = GEMM_SPECIALIZATION}; - - static constexpr auto LOOP_SCHEDULER = factory_internal::SetLoopScheduler(); - static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm; - static constexpr auto A_BLOCK_TRANSFER = - factory_internal::SetFwdConvBlockTransfer(); - static constexpr auto B_BLOCK_TRANSFER = - factory_internal::SetFwdConvBlockTransfer(); - static constexpr auto C_BLOCK_TRANSFER = - factory_internal::SetCBlockTransfer(); - - // Check limits for the algorithm parameters. - static_assert(InputVectorTransferLimits); - static_assert(InputVectorTransferLimits); - static_assert(OutputVectorTransferLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - static_assert(AccessOrderLimits); - - // The forward convolution kernel class instance with large tensor support. - using Instance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< - SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, - typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Types::ADataType, - typename Types::BDataType, - typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataType, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, - SPECIALIZATION.conv_spec, - SPECIALIZATION.gemm_spec, - BASE_ALGORITHM.num_gemm_k_prefetch_stages, - BLOCK.block_size, - BLOCK.per_block.m, - BLOCK.per_block.n, - BLOCK.per_block.k, - GRIDWISE_GEMM.ak1, - GRIDWISE_GEMM.bk1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, - to_sequence_v, - to_sequence_v, - to_sequence_v, - A_BLOCK_TRANSFER.src_vector_dim, - A_BLOCK_TRANSFER.src_scalar_per_vector, - A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, - static_cast(A_BLOCK_TRANSFER.lds_padding), - to_sequence_v, - to_sequence_v, - to_sequence_v, - B_BLOCK_TRANSFER.src_vector_dim, - B_BLOCK_TRANSFER.src_scalar_per_vector, - B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, - static_cast(B_BLOCK_TRANSFER.lds_padding), - C_BLOCK_TRANSFER.m_per_wave_per_shuffle, - C_BLOCK_TRANSFER.n_per_wave_per_shuffle, - to_sequence_v, - C_BLOCK_TRANSFER.scalar_per_vector, - typename Types::AComputeType, - typename Types::BComputeType, - LOOP_SCHEDULER>; -}; - -} // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/factory/README.md b/experimental/builder/include/ck_tile/builder/factory/README.md new file mode 100644 index 0000000000..d1794349ab --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/README.md @@ -0,0 +1,31 @@ +# Convolution Builder Factory Directory + +This directory implements compile-time dispatch from high-level signature algorithm descriptors to our exisitng specialized convolution kernel implementations. + +See the [main builder documentation](../README.md) for an overview. + +## Design Overview + +The factory system operates in two phases: + +1. **Algorithm Classification**: The function `make_conv_instance` in `conv_dispatcher.hpp` inspects the signature and algorithm descriptors to determine which kernel variant they satisfy (XDL V3, XDL, WMMA, DL, or Large Tensor) + +2. **Factory Instantiation**: Each factory (`conv_fwd_*_factory.hpp`) transforms builder descriptors into CK device operation template parameters and instantiates the corresponding kernel device operation. + +## Key Files + +- **`conv_dispatcher.hpp`**: Entry point with `make_conv_instance()` function. Contains dispatch logic and algorithm classification predicates. **Start here** to understand the overall flow. + +- **`conv_fwd_*_factory.hpp`**: Individual factories for each kernel variant. Each extracts configuration from descriptors, validates parameters, and instantiates the underlying CK device operation. + +- **`helpers/`**: Transformation utilities that map builder types to CK device operation parameters (layouts, data types, elementwise ops, block configurations, etc.) + +## Usage + +```cpp +#include "ck_tile/builder/factory/conv_dispatcher.hpp" + +using Factory = decltype(make_conv_instance()); +``` + +The dispatcher automatically selects the appropriate factory following explicit logic. diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp new file mode 100644 index 0000000000..51945544b2 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -0,0 +1,196 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Compile-time dispatcher for convolution kernel instantiation. +// +// This header provides a centralized factory dispatch mechanism that routes algorithm +// specifications to appropriate convolution kernel implementations at compile-time. +// +// ## Design Overview +// +// The dispatcher operates in two phases: +// 1. **Algorithm Identification**: Five `consteval` predicate functions (`IsXdlV3Algorithm`, +// `IsXdlAlgorithm`, `IsWmmaAlgorithm`, `IsDlAlgorithm`, `IsLargeTensorAlgorithm`) inspect +// the algorithm descriptor's structure to determine which kernel variant it satisfies. +// Each predicate checks a specific set of concept constraints that define a kernel variant. +// +// 2. **Factory Routing**: The main `make_conv_instance()` function uses `if constexpr` +// to dispatch to the appropriate factory class based on both the convolution direction +// and the identified algorithm type. All routing decisions occur at compile-time, +// ensuring zero runtime overhead. +// +// ## Supported Kernel Variants +// +// - **XDL V3**: Newer XDL-based pipeline using block GEMM structure. Requires fewer parameters +// than standard XDL (e.g., uses `SpecifiesBlockGemm` instead of scheduling/prefetch configs). +// +// - **XDL**: Standard XDL-based kernel using AMD XDLops hardware instructions for matrix +// multiply. Requires full scheduling configuration including prefetch stages and loop scheduler. +// +// - **WMMA**: Wavefront Matrix-Matrix Accumulate variant optimized for WMMA-capable hardware. +// Requires similar configuration to XDL. +// +// - **DL**: Specialized vectorized dot-product kernel optimized for specific data layouts +// (NHWC/KYXC/NHWK). The "DL" label just indicates this does not use XDLops instructions. +// +// - **Large Tensor**: XDL-based kernel with extended tensor support. Wraps a base XDL algorithm +// and adds large tensor capabilities. +// +// ## Current Limitations +// +// Currently only forward convolution is supported. Backward data and backward weight convolution +// directions will fail at compile-time with informative static_assert messages. +// +// ## Usage Example +// +// ``` +// auto kernel = make_conv_instance(); +// ``` + +#pragma once + +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/types.hpp" + +// Include all factory implementations +#include "ck_tile/builder/factory/conv_fwd_v3_factory.hpp" +#include "ck_tile/builder/factory/conv_fwd_xdl_factory.hpp" +#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp" +#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp" +#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp" + +namespace ck_tile::builder::factory { + +// This dispatch logic is rigid and confusing for users. Further, hides most of +// the great error messages from our concepts. +// +// Requirements for a good design: +// 1. Fall through is bad: inputs should get directly to an implementation +// if we are going to have good compiler errors. +// 2. Logic should be easy for library users to understand. +// 3. Logic should be easy to test, maintain, and extend. +// +// We should probably add explicit tags to the algorithm descriptors, at least +// for the initial implemenation. +// +// To avoid changing behavior too much during refactoring, we leave the explicit +// dispatch logic here for now, just changing it from SFINAE to consteval + if constexpr. +// There may be some subtle behavior changes, but build failure messages will be more +// clear. +// +// TODO: Make this dispatch logic much more robust and clear for users. + +// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) +template +consteval bool IsXdlV3Algorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesBlockGemm; +} + +// Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply) +template +consteval bool IsXdlAlgorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && + SpecifiesLoopScheduler; +} + +// WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions) +template +consteval bool IsWmmaAlgorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && + SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && + SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; +} + +// Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts +template +consteval bool IsDlAlgorithm() +{ + return ConvAlgorithmDescriptor && SpecifiesThreadBlock && + SpecifiesFwdConcSpecialization && SpecifiesGemmSpecialization && + SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && + SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; +} + +// XDL-based kernel with large tensor support +template +consteval bool IsLargeTensorAlgorithm() +{ + return IsXdlAlgorithm() && SpecifiesLargeTensorSupport; +} + +template +constexpr auto make_conv_instance() +{ + if constexpr(ConvDirectionIsForward) + { + using AlgoType = std::remove_const_t; + + if constexpr(IsXdlV3Algorithm()) + { + return typename ConvFwdXdlV3Factory::Instance{}; + } + else if constexpr(IsXdlAlgorithm()) + { + return typename ConvFwdXdlFactory::Instance{}; + } + else if constexpr(IsWmmaAlgorithm()) + { + return typename ConvFwdWmmaFactory::Instance{}; + } + else if constexpr(IsDlAlgorithm()) + { + return typename ConvFwdDlFactory::Instance{}; + } + else if constexpr(IsLargeTensorAlgorithm()) + { + return typename ConvFwdLargeTensorFactory::Instance{}; + } + else + { + static_assert( + false, + "No suitable forward convolution kernel factory found for the provided ALGORITHM. " + "The ALGORITHM must satisfy requirements for one of: XDL V3, XDL, WMMA, DL (NHWC " + "layout), or Large Tensor variant."); + } + } + else if constexpr(ConvDirectionIsBackwardData) + { + static_assert( + false, + "Backward data convolution is not yet supported. " + "Only forward convolution (ConvDirection::FORWARD) is currently implemented."); + } + else if constexpr(ConvDirectionIsBackwardWeight) + { + static_assert( + false, + "Backward weight convolution is not yet supported. " + "Only forward convolution (ConvDirection::FORWARD) is currently implemented."); + } + else + { + static_assert(false, + "Invalid or unsupported convolution direction. " + "The SIGNATURE must specify a valid ConvDirection: FORWARD, BACKWARD_DATA, " + "or BACKWARD_WEIGHT."); + } +} + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp new file mode 100644 index 0000000000..dee918cc1f --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_dl_factory.hpp @@ -0,0 +1,138 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/conv_signature_utils.hpp" +#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK instance +// of a grouped forward convolution kernel. +template + requires ConvDirectionIsForward +struct ConvFwdDlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(internal::GetTensorLayout()); + using Types = internal::ConvTensorTypes; + using Ops = internal::ElementwiseOps()>; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + + // DL-specific parameters from algorithm descriptor + static constexpr auto DL_THREAD_CFG = ALGORITHM.thread_config; + static constexpr ck::index_t K0PerBlock = DL_THREAD_CFG.k0_per_block; + static constexpr ck::index_t K1 = DL_THREAD_CFG.k1; + static constexpr ck::index_t M1PerThread = DL_THREAD_CFG.m1_per_thread; + static constexpr ck::index_t N1PerThread = DL_THREAD_CFG.n1_per_thread; + static constexpr ck::index_t KPerThread = DL_THREAD_CFG.k_per_thread; + + // Thread cluster from descriptor + static constexpr auto DL_CLUSTER = ALGORITHM.thread_cluster; + using M1N1ThreadClusterM1Xs = to_sequence_v; + using M1N1ThreadClusterN1Xs = to_sequence_v; + + // A Block Transfer from descriptor - K0_M0_M1_K1 tensor format + static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a.block_transfer; + using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferThreadClusterArrangeOrder = + to_sequence_v; + using ABlockTransferSrcAccessOrder = to_sequence_v; + using ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = + to_sequence_v; + using ABlockTransferSrcVectorTensorContiguousDimOrder = + to_sequence_v; + using ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = + to_sequence_v; + + // B Block Transfer from descriptor - K0_N0_N1_K1 tensor format + static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b.block_transfer; + using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferThreadClusterArrangeOrder = + to_sequence_v; + using BBlockTransferSrcAccessOrder = to_sequence_v; + using BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = + to_sequence_v; + using BBlockTransferSrcVectorTensorContiguousDimOrder = + to_sequence_v; + using BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = + to_sequence_v; + + // C Thread Transfer from descriptor + static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c.epilogue; + using CThreadTransferSrcDstAccessOrder = to_sequence_v; + static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim; + static constexpr ck::index_t CThreadTransferDstScalarPerVector = + DL_C_TRANSFER.dst_scalar_per_vector; + + // The DL forward convolution kernel class instance + using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + SPATIAL_DIM, + typename Types::ADataType, + typename Types::BDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Types::AccDataType, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + FWD_CONV_SPECIALIZATION, + GEMM_SPECIALIZATION, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + K0PerBlock, + K1, + M1PerThread, + N1PerThread, + KPerThread, + M1N1ThreadClusterM1Xs, + M1N1ThreadClusterN1Xs, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp new file mode 100644 index 0000000000..383ecbf8c9 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -0,0 +1,117 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/conv_signature_utils.hpp" +#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor instance +// of a grouped forward convolution kernel. +template + requires ConvDirectionIsForward +struct ConvFwdLargeTensorFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(internal::GetTensorLayout()); + using Types = internal::ConvTensorTypes; + using Ops = internal::ElementwiseOps()>; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm; + + static constexpr auto FWD_CONV_SPECIALIZATION = + internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); + static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, + .gemm_spec = GEMM_SPECIALIZATION}; + + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetFwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetFwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = + internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + + // The forward convolution kernel class instance with large tensor support. + using Instance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + SPATIAL_DIM, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + typename Types::CShuffleDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + SPECIALIZATION.conv_spec, + SPECIALIZATION.gemm_spec, + BASE_ALGORITHM.num_gemm_k_prefetch_stages, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.ak1, + GRIDWISE_GEMM.bk1, + GRIDWISE_GEMM.m_per_xdl, + GRIDWISE_GEMM.n_per_xdl, + GRIDWISE_GEMM.m_xdl_per_wave, + GRIDWISE_GEMM.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + static_cast(A_BLOCK_TRANSFER.lds_padding), + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + static_cast(B_BLOCK_TRANSFER.lds_padding), + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::AComputeType, + typename Types::BComputeType, + LOOP_SCHEDULER>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp new file mode 100644 index 0000000000..90d4abe3e7 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -0,0 +1,119 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/conv_signature_utils.hpp" +#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance +// of a grouped forward convolution kernel. +template + requires ConvDirectionIsForward +struct ConvFwdXdlV3Factory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(internal::GetTensorLayout()); + using Types = internal::ConvTensorTypes; + using Ops = internal::ElementwiseOps()>; + using AlgorithmType = decltype(ALGORITHM); + + static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load == + ALGORITHM.transfer.b.lds_transfer.is_direct_load, + "A and B block transfers must both be direct load or not."); + + static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer.is_direct_load; + static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); + static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, + .gemm_spec = GEMM_SPECIALIZATION}; + + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetFwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetFwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + static constexpr auto BLOCK_GEMM = internal::SetBlockGemm(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + SPATIAL_DIM, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + typename Types::CShuffleDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + SPECIALIZATION.conv_spec, + SPECIALIZATION.gemm_spec, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.ak1, + GRIDWISE_GEMM.bk1, + GRIDWISE_GEMM.m_per_xdl, + GRIDWISE_GEMM.n_per_xdl, + GRIDWISE_GEMM.m_xdl_per_wave, + GRIDWISE_GEMM.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + static_cast(A_BLOCK_TRANSFER.lds_padding), + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + static_cast(B_BLOCK_TRANSFER.lds_padding), + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + BLOCK_GEMM.scheduler, + BLOCK_GEMM.pipeline_version, + typename Types::AComputeType, + typename Types::BComputeType, + IS_DIRECT_LOAD>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp new file mode 100644 index 0000000000..e35b3f3d46 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -0,0 +1,113 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/conv_signature_utils.hpp" +#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle instance +// of a grouped forward convolution kernel. +template + requires ConvDirectionIsForward +struct ConvFwdWmmaFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(internal::GetTensorLayout()); + using Types = internal::ConvTensorTypes; + using Ops = internal::ElementwiseOps()>; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); + static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, + .gemm_spec = GEMM_SPECIALIZATION}; + + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = + internal::SetGridwiseGemmPipelineVersion(); + static constexpr auto A_BLOCK_TRANSFER = + internal::SetFwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetFwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< + SPATIAL_DIM, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + typename Types::CShuffleDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + SPECIALIZATION.conv_spec, + SPECIALIZATION.gemm_spec, + ALGORITHM.num_gemm_k_prefetch_stages, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.k1, + GRIDWISE_GEMM.m_per_wmma, + GRIDWISE_GEMM.n_per_wmma, + GRIDWISE_GEMM.m_wmma_per_wave, + GRIDWISE_GEMM.n_wmma_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + static_cast(A_BLOCK_TRANSFER.lds_padding), + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + static_cast(B_BLOCK_TRANSFER.lds_padding), + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + LOOP_SCHEDULER, + GRIDWISE_GEMM_PIPELINE_VERSION>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp new file mode 100644 index 0000000000..fc5b32f799 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -0,0 +1,114 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_algorithm_limits.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/conv_signature_utils.hpp" +#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" +#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" +#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" +#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp" +#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" + +namespace ck_tile::builder::factory { + +// Factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance +// of a grouped forward convolution kernel. +template + requires ConvDirectionIsForward +struct ConvFwdXdlFactory +{ + static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim; + using Layouts = decltype(internal::GetTensorLayout()); + using Types = internal::ConvTensorTypes; + using Ops = internal::ElementwiseOps()>; + using AlgorithmType = decltype(ALGORITHM); + + static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); + static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); + static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, + .gemm_spec = GEMM_SPECIALIZATION}; + + static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); + static constexpr auto BLOCK = internal::SetThreadBlockInfo(); + static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto A_BLOCK_TRANSFER = + internal::SetFwdConvBlockTransfer(); + static constexpr auto B_BLOCK_TRANSFER = + internal::SetFwdConvBlockTransfer(); + static constexpr auto C_BLOCK_TRANSFER = internal::SetCBlockTransfer(); + + // Check limits for the algorithm parameters. + // TODO: Add more limits checks as needed. + static_assert(InputVectorTransferLimits); + static_assert(InputVectorTransferLimits); + static_assert(OutputVectorTransferLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + static_assert(AccessOrderLimits); + + // The forward convolution kernel class instance. + using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + SPATIAL_DIM, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + typename Types::CShuffleDataType, + typename Types::DsDataTypes, + typename Types::EDataType, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + SPECIALIZATION.conv_spec, + SPECIALIZATION.gemm_spec, + ALGORITHM.num_gemm_k_prefetch_stages, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + BLOCK.per_block.k, + GRIDWISE_GEMM.ak1, + GRIDWISE_GEMM.bk1, + GRIDWISE_GEMM.m_per_xdl, + GRIDWISE_GEMM.n_per_xdl, + GRIDWISE_GEMM.m_xdl_per_wave, + GRIDWISE_GEMM.n_xdl_per_wave, + to_sequence_v, + to_sequence_v, + to_sequence_v, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scalar_per_vector, + A_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + static_cast(A_BLOCK_TRANSFER.lds_padding), + to_sequence_v, + to_sequence_v, + to_sequence_v, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scalar_per_vector, + B_BLOCK_TRANSFER.lds_dst_scalar_per_vector, + static_cast(B_BLOCK_TRANSFER.lds_padding), + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + to_sequence_v, + C_BLOCK_TRANSFER.scalar_per_vector, + typename Types::AComputeType, + typename Types::BComputeType, + LOOP_SCHEDULER, + ALGORITHM.num_groups_to_merge>; +}; + +} // namespace ck_tile::builder::factory diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_block_transfer.hpp new file mode 100644 index 0000000000..5da1e4eadb --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_block_transfer.hpp @@ -0,0 +1,73 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/utility/array.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +// Block transfer parameters for A or B tensor. +struct BlockTransfer +{ + ck::Array thread_cluster_dims = {0, 0, 0}; // k0, m, k1 + ck::Array thread_cluster_order = {0, 0, 0}; + ck::Array src_access_order = {0, 0, 0}; + size_t src_vector_dim = 0; + size_t src_scalar_per_vector = 0; + size_t lds_dst_scalar_per_vector = 0; + bool is_direct_load = false; + bool lds_padding = false; +}; + +template +constexpr BlockTransfer SetFwdConvBlockTransfer() +{ + auto& block_xfer = TRANSFER.block_transfer; + auto& block_order = TRANSFER.block_transfer_access_order; + auto& src_order = TRANSFER.src_access_order; + auto& lds_cfg = TRANSFER.lds_transfer; + + return BlockTransfer{ + .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, + .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, + .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, + .src_vector_dim = lds_cfg.src_vector_dim, + .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, + .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, + .is_direct_load = lds_cfg.is_direct_load, + .lds_padding = lds_cfg.lds_padding, + }; +} + +// Block transfer parameters for C tensor. +struct CBlockTransfer +{ + size_t m_xdl_per_wave_per_shuffle = 0; + size_t n_xdl_per_wave_per_shuffle = 0; + ck::Array thread_cluster_dims = {0, 0, 0, 0}; + size_t scalar_per_vector = 0; +}; + +template +constexpr CBlockTransfer SetCBlockTransfer() +{ + auto& thread_cluster_dims = ALGORITHM.transfer.c.thread_cluster_dims; + auto& epilogue_config = ALGORITHM.transfer.c.epilogue; + return CBlockTransfer{ + .m_xdl_per_wave_per_shuffle = epilogue_config.m_xdl_per_wave_per_shuffle, + .n_xdl_per_wave_per_shuffle = epilogue_config.n_per_wave_per_shuffle, + .thread_cluster_dims = + { + thread_cluster_dims.m_block, + thread_cluster_dims.m_wave_per_xdl, + thread_cluster_dims.n_block, + thread_cluster_dims.n_wave_per_xdl, + }, + .scalar_per_vector = epilogue_config.scalar_per_vector, + }; +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp new file mode 100644 index 0000000000..4a13f4e508 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_elementwise_op.hpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck_tile/builder/builder_utils.hpp" +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder::factory::internal { + +template +struct ElementwiseOps +{ + // This will trigger if a specialization for the given DataType is not found. + // We should always catch this in an earlier validation check. + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Internal error. Unsupported elementwise operation for convolution factory."); +}; + +template <> +struct ElementwiseOps +{ + using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough; + using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough; + using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough; +}; + +template <> +struct ElementwiseOps +{ + using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough; + using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough; + using CDEElementwiseOp = ck::tensor_operation::element_wise::Scale; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp new file mode 100644 index 0000000000..b3effa782e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_layout.hpp @@ -0,0 +1,146 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/tuple.hpp" +#include "ck_tile/builder/conv_signature_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +// Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types. +template + requires(ConvSpatialDim && ValidConvLayoutForSpatialDim) +struct ConvTensorLayouts +{ + // This will trigger if a specialization for the given layout is not found. + // We should always catch this in an earlier validation check. + using Layout = decltype(LayoutValue); + static_assert(sizeof(Layout) == 0, + "Internal error. Unsupported layout for convolution factory."); +}; + +// 1D Forward Convolution Layout Specializations +template <> +struct ConvTensorLayouts +{ + using ALayout = ck::tensor_layout::convolution::NWGC; + using BLayout = ck::tensor_layout::convolution::GKXC; + using DsLayout = ck::Tuple<>; + using ELayout = ck::tensor_layout::convolution::NWGK; +}; + +template <> +struct ConvTensorLayouts +{ + using ALayout = ck::tensor_layout::convolution::NGCW; + using BLayout = ck::tensor_layout::convolution::GKXC; + using DsLayout = ck::Tuple<>; + using ELayout = ck::tensor_layout::convolution::NGKW; +}; + +template <> +struct ConvTensorLayouts +{ + using ALayout = ck::tensor_layout::convolution::GNWC; + using BLayout = ck::tensor_layout::convolution::GKXC; + using DsLayout = ck::Tuple<>; + using ELayout = ck::tensor_layout::convolution::GNWK; +}; + +template <> +struct ConvTensorLayouts +{ + using ALayout = ck::tensor_layout::convolution::NGCW; + using BLayout = ck::tensor_layout::convolution::GKCX; + using DsLayout = ck::Tuple<>; + using ELayout = ck::tensor_layout::convolution::NGKW; +}; + +template <> +struct ConvTensorLayouts +{ + using ALayout = ck::tensor_layout::convolution::NGCHW; + using BLayout = ck::tensor_layout::convolution::GKYXC; + using DsLayout = ck::Tuple<>; + using ELayout = ck::tensor_layout::convolution::NGKHW; +}; + +template <> +struct ConvTensorLayouts +{ + using ALayout = ck::tensor_layout::convolution::NHWGC; + using BLayout = ck::tensor_layout::convolution::GKYXC; + using DsLayout = ck::Tuple<>; + using ELayout = ck::tensor_layout::convolution::NHWGK; +}; + +template <> +struct ConvTensorLayouts +{ + using ALayout = ck::tensor_layout::convolution::GNHWC; + using BLayout = ck::tensor_layout::convolution::GKYXC; + using DsLayout = ck::Tuple<>; + using ELayout = ck::tensor_layout::convolution::GNHWK; +}; + +template <> +struct ConvTensorLayouts +{ + using ALayout = ck::tensor_layout::convolution::NGCHW; + using BLayout = ck::tensor_layout::convolution::GKCYX; + using DsLayout = ck::Tuple<>; + using ELayout = ck::tensor_layout::convolution::NGKHW; +}; + +template <> +struct ConvTensorLayouts +{ + using ALayout = ck::tensor_layout::convolution::NGCDHW; + using BLayout = ck::tensor_layout::convolution::GKCZYX; + using DsLayout = ck::Tuple<>; + using ELayout = ck::tensor_layout::convolution::NGKDHW; +}; + +template <> +struct ConvTensorLayouts +{ + using ALayout = ck::tensor_layout::convolution::NDHWGC; + using BLayout = ck::tensor_layout::convolution::GKZYXC; + using DsLayout = ck::Tuple<>; + using ELayout = ck::tensor_layout::convolution::NDHWGK; +}; + +template <> +struct ConvTensorLayouts +{ + using ALayout = ck::tensor_layout::convolution::GNDHWC; + using BLayout = ck::tensor_layout::convolution::GKZYXC; + using DsLayout = ck::Tuple<>; + using ELayout = ck::tensor_layout::convolution::GNDHWK; +}; + +template +consteval auto GetTensorLayout() +{ + + if constexpr(SPATIAL_DIM == 1) + { + return internal::ConvTensorLayouts{}; + } + else if constexpr(SPATIAL_DIM == 2) + { + return internal::ConvTensorLayouts{}; + } + else if constexpr(SPATIAL_DIM == 3) + { + return internal::ConvTensorLayouts{}; + } + else + { + static_assert(false, "Unsupported spatial dimension for convolution layout."); + } +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp new file mode 100644 index 0000000000..d8a8eb5da0 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tensor_type.hpp @@ -0,0 +1,87 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck_tile/builder/types.hpp" +#include "ck_tile/builder/builder_utils.hpp" + +namespace ck_tile::builder::factory::internal { + +// Type mappings from builder convolution data type to CK tensor types. +template +struct ConvTensorTypes +{ + // This will trigger if a specialization for the given DataType is not found. + // We should always catch this in an earlier validation check. + static_assert(sizeof(UnsupportedEnumValue) == 0, + "Internal error. Unsupported data type for convolution factory."); +}; + +template <> +struct ConvTensorTypes +{ + using ADataType = ck::half_t; + using AComputeType = ck::half_t; + using BDataType = ck::half_t; + using BComputeType = ck::half_t; + using CShuffleDataType = ck::half_t; + using DsDataTypes = ck::Tuple<>; + using AccDataType = float; + using EDataType = ck::half_t; +}; + +template <> +struct ConvTensorTypes +{ + using ADataType = ck::bhalf_t; + using AComputeType = ck::bhalf_t; + using BDataType = ck::bhalf_t; + using BComputeType = ck::bhalf_t; + using CShuffleDataType = ck::bhalf_t; + using DsDataTypes = ck::Tuple<>; + using AccDataType = float; + using EDataType = ck::bhalf_t; +}; + +template <> +struct ConvTensorTypes +{ + using ADataType = float; + using AComputeType = float; + using BDataType = float; + using BComputeType = float; + using CShuffleDataType = float; + using DsDataTypes = ck::Tuple<>; + using AccDataType = float; + using EDataType = float; +}; + +template <> +struct ConvTensorTypes +{ + using ADataType = int8_t; + using AComputeType = int8_t; + using BDataType = int8_t; + using BComputeType = int8_t; + using CShuffleDataType = int8_t; + using DsDataTypes = ck::Tuple<>; + using AccDataType = int32_t; + using EDataType = int8_t; +}; + +template <> +struct ConvTensorTypes +{ + using ADataType = ck::f8_t; + using AComputeType = ck::f8_t; + using BDataType = ck::f8_t; + using BComputeType = ck::f8_t; + using CShuffleDataType = ck::f8_t; + using DsDataTypes = ck::Tuple<>; + using AccDataType = float; + using EDataType = ck::f8_t; +}; + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_thread_block.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_thread_block.hpp new file mode 100644 index 0000000000..7627165181 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_thread_block.hpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/conv_algorithm_concepts.hpp" + +namespace ck_tile::builder::factory::internal { + +/// @brief Data tile dimensions processed by a workgroup. +/// @details This struct defines the M, N, and K dimensions of the data tile +/// that a single workgroup (thread block) is responsible for processing in the +/// underlying GEMM computation. +struct DataTileInfo +{ + int m; ///< M dimension of the tile processed by the workgroup (MPerBlock). + int n; ///< N dimension of the tile processed by the workgroup (NPerBlock). + int k; ///< K dimension of the tile processed by the workgroup (KPerBlock). +}; + +struct ConvBlock +{ + size_t block_size = 0; + DataTileInfo per_block = {}; +}; + +template +constexpr ConvBlock SetThreadBlockInfo() +{ + constexpr auto& TB = ALGORITHM.thread_block; + return ConvBlock{ + .block_size = TB.block_size, + .per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k}, + }; +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tuning_params.hpp new file mode 100644 index 0000000000..3ec0a94960 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/conv_tuning_params.hpp @@ -0,0 +1,160 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck_tile/builder/conv_algorithm_concepts.hpp" +#include "ck_tile/builder/types.hpp" + +namespace ck_tile::builder::factory::internal { + +// The algorithm specializations for the convolution and GEMM. +template + requires( + std::is_same_v) +struct ConvSpec +{ + CONV_ENUM conv_spec; + ck::tensor_operation::device::GemmSpecialization gemm_spec; +}; + +// Deduction guide for ConvSpec to simplify brace initialization. +template +ConvSpec(CONV_ENUM, GEMM_ENUM) -> ConvSpec; + +struct BlockGemmSpec +{ + ck::BlockGemmPipelineVersion pipeline_version; + ck::BlockGemmPipelineScheduler scheduler; +}; + +template +consteval BlockGemmSpec SetBlockGemm() +{ + constexpr auto& BG = ALGORITHM.block_gemm; + + ck::BlockGemmPipelineScheduler scheduler; + ck::BlockGemmPipelineVersion version; + + switch(BG.scheduler) + { + case PipelineScheduler::INTRAWAVE: scheduler = ck::BlockGemmPipelineScheduler::Intrawave; break; + case PipelineScheduler::INTERWAVE: scheduler = ck::BlockGemmPipelineScheduler::Interwave; break; + case PipelineScheduler::DEFAULT: throw "Block GEMM scheduler must be Intrawave or Interwave."; + default: throw "Unknown PipelineScheduler"; + } + + switch(BG.pipeline_version) + { + case PipelineVersion::V1: version = ck::BlockGemmPipelineVersion::v1; break; + case PipelineVersion::V2: version = ck::BlockGemmPipelineVersion::v2; break; + case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break; + case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break; + case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break; + case PipelineVersion::WEIGHT_ONLY: + throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM."; + default: throw "Unknown PipelineVersion"; + } + + return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler}; +} + +template +consteval ck::LoopScheduler SetLoopScheduler() +{ + constexpr auto loop_scheduler = ALGORITHM.loop_scheduler; + using ck_loop_sched = ck::LoopScheduler; + switch(loop_scheduler) + { + case PipelineScheduler::DEFAULT: return ck_loop_sched::Default; + case PipelineScheduler::INTERWAVE: return ck_loop_sched::Interwave; + case PipelineScheduler::INTRAWAVE: throw "LoopScheduler must be either DEFAULT or INTERWAVE."; + default: throw "Unknown PipelineScheduler"; + } +} + +template +consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() +{ + constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version; + using ck_pipeline = ck::PipelineVersion; + switch(pipeline_version) + { + case PipelineVersion::V1: return ck_pipeline::v1; + case PipelineVersion::V2: return ck_pipeline::v2; + case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K."; + case PipelineVersion::V4: return ck_pipeline::v4; + case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM."; + case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only; + default: throw "Unknown GridwiseGemmPipelineVersion"; + } +} + +template +consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization() +{ + constexpr auto gemm_spec = ALGORITHM.gemm_specialization; + using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization; + + switch(gemm_spec) + { + case GemmSpecialization::Default: return ck_gemm_spec::Default; + case GemmSpecialization::MPadding: return ck_gemm_spec::MPadding; + case GemmSpecialization::NPadding: return ck_gemm_spec::NPadding; + case GemmSpecialization::KPadding: return ck_gemm_spec::KPadding; + case GemmSpecialization::MNPadding: return ck_gemm_spec::MNPadding; + case GemmSpecialization::MKPadding: return ck_gemm_spec::MKPadding; + case GemmSpecialization::NKPadding: return ck_gemm_spec::NKPadding; + case GemmSpecialization::MNKPadding: return ck_gemm_spec::MNKPadding; + case GemmSpecialization::OPadding: return ck_gemm_spec::OPadding; + case GemmSpecialization::MOPadding: return ck_gemm_spec::MOPadding; + case GemmSpecialization::NOPadding: return ck_gemm_spec::NOPadding; + case GemmSpecialization::KOPadding: return ck_gemm_spec::KOPadding; + case GemmSpecialization::MNOPadding: return ck_gemm_spec::MNOPadding; + case GemmSpecialization::MKOPadding: return ck_gemm_spec::MKOPadding; + case GemmSpecialization::NKOPadding: return ck_gemm_spec::NKOPadding; + case GemmSpecialization::MNKOPadding: return ck_gemm_spec::MNKOPadding; + default: throw "Unknown GemmSpecialization"; + } +} + +template +consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() +{ + constexpr auto version = ALGORITHM.pipeline_version; + using ck_pipeline = ck::BlockGemmPipelineVersion; + switch(version) + { + case PipelineVersion::V1: return ck_pipeline::v1; + case PipelineVersion::V2: return ck_pipeline::v2; + case PipelineVersion::V3: return ck_pipeline::v3; + case PipelineVersion::V4: return ck_pipeline::v4; + case PipelineVersion::V5: return ck_pipeline::v5; + case PipelineVersion::WEIGHT_ONLY: + throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version."; + default: throw "Unknown block GEMM PipelineVersion"; + } +} + +template +consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization() +{ + constexpr auto specialization = ALGORITHM.fwd_specialization; + using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization; + switch(specialization) + { + case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default; + case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; + case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3; + default: throw "Unknown ConvFwdSpecialization"; + } +} + +} // namespace ck_tile::builder::factory::internal diff --git a/experimental/builder/include/ck_tile/builder/reflect/README.md b/experimental/builder/include/ck_tile/builder/reflect/README.md new file mode 100644 index 0000000000..6ef9ebe87a --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/README.md @@ -0,0 +1,38 @@ +# Convolution Reflection Directory + +This directory contains tools for "reflecting" on convolution kernel instances. It allows developers to inspect the compile-time configuration of a kernel and generate detailed, human-readable descriptions. + +See the [main builder documentation](../README.md) for an overview. + +## Design Overview + +The reflection system works by extracting properties from a convolution kernel *type* and formatting them into a string. This is useful for debugging, performance tuning, and generating documentation. + +1. **Trait Extraction**: The `ConvTraits` template (in `conv_traits.hpp`) is specialized for each kernel instance. It extracts low-level details like tile sizes, data layouts, and pipeline versions from the kernel's type definition. + +2. **Description Generation**: The `Describe()` function (in `conv_description.hpp`) uses `ConvTraits` to populate a `ConvDescription` struct. + +3. **Formatting**: The `ConvDescription` struct contains methods like `brief()` and `detailed()` that format the extracted properties into well-structured strings for display. + +## Key Files + +- **`conv_description.hpp`**: The main entry point. Contains the `ConvDescription` struct and the `Describe()` factory function. +- **`conv_traits.hpp`**: Home of the `ConvTraits` template, which is the core of the property extraction mechanism. +- **`tree_formatter.hpp`**: A simple utility for generating the indented, tree-like format used in the `detailed()` description. + +## Usage + +To get a description of a convolution kernel instance, use the `Describe` function and call one of its formatting methods: + +```cpp +#include "ck_tile/builder/reflect/conv_description.hpp" + +// Assume MyConvFwdInstance is a type alias for a specific kernel instance +using MyConvFwdInstance = /* ... some kernel type ... */; + +// Describe the instance +const auto description = ck_tile::reflect::conv::Describe(); + +// Print the detailed description +std::cout << description.detailed() << std::endl; +``` diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index 08e506b614..375e465721 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -1,6 +1,22 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +/** + * @file + * @brief Provides utilities to reflect on convolution kernel instances and generate + * human-readable descriptions of their configuration. + * + * This file contains the necessary components to transform a convolution kernel's + * compile-time properties into a structured, descriptive format. This is primarily + * used for debugging, logging, and generating documentation. + * + * Key components: + * - ck_tile::reflect::conv::ConvDescription: A struct that holds the extracted + * properties and provides methods to format them into strings. + * - ck_tile::reflect::conv::Describe(): A factory function that creates a + * ConvDescription from a given kernel instance type. + */ + #pragma once #include @@ -13,11 +29,13 @@ #include #include -/// @file conv_description.hpp -/// @brief Provides human-readable descriptions of ConvBuilder configurations +/// @brief Provides human-readable descriptions of convolution kernel instances namespace ck_tile::reflect::conv { +/// @brief Signature information for a convolution operation +/// Contains high-level properties that define the convolution's interface, +/// including dimensionality, data layout, data types, and elementwise operations. struct ConvSignatureInfo { int spatial_dim; @@ -30,7 +48,9 @@ struct ConvSignatureInfo builder::ElementwiseOperation output_element_op; }; -// Algorithm information - groups all algorithm-related configuration +/// @brief Algorithm configuration for a convolution kernel +/// Contains low-level implementation details including thread block configuration, +/// tile dimensions, memory access patterns, and pipeline settings. struct GemmAlgorithmInfo { int thread_block_size; @@ -48,13 +68,16 @@ struct GemmAlgorithmInfo builder::GemmPadding padding; }; -// Provides human-readable descriptions of ConvBuilder configurations. +/// @brief Provides human-readable descriptions of convolution kernel instances +/// Generates formatted text descriptions at various levels of detail for +/// understanding and documenting convolution kernel configurations. struct ConvDescription { ConvSignatureInfo signature; GemmAlgorithmInfo algorithm; - // Brief one-line summary + /// @brief Generate a brief one-line summary of the convolution + /// @return A concise description (e.g., "2D Forward convolution") std::string brief() const { std::ostringstream oss; @@ -62,7 +85,8 @@ struct ConvDescription return oss.str(); } - // Detailed hierarchical description + /// @brief Generate a detailed hierarchical description of the convolution + /// @return A multi-line tree-formatted description covering signature and algorithm details std::string detailed() const { TreeFormatter f; @@ -74,7 +98,7 @@ struct ConvDescription f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op); f.writeLast(2, "Output elementwise operation: ", signature.output_element_op); - f.writeLine(1, "Algorithm"); + f.writeLast(1, "Algorithm"); // Compute Block section f.writeLine(2, "Thread block size: ", algorithm.thread_block_size); f.writeLine(2, @@ -99,7 +123,7 @@ struct ConvDescription algorithm.warp_gemm.n_iter); // Memory Access section - f.writeLine(2, "Memory access:"); + f.writeLast(2, "Memory access:"); f.writeLine(3, "A Tile transfer: "); f.writeLine(4, @@ -195,12 +219,12 @@ struct ConvDescription f.writeLast(4, "Vector access (GMEM write) instruction size: ", algorithm.c_tile_transfer.scalar_per_vector); - f.writeLast(2); - f.writeLast(1); return f.getString(); } - // Educational explanation of optimization choices + /// @brief Generate an educational explanation of optimization choices + /// @return Educational content explaining why certain algorithm choices were made + /// @note Currently unimplemented - reserved for future enhancement std::string explain() const { std::ostringstream oss; @@ -208,7 +232,9 @@ struct ConvDescription return oss.str(); } - // Performance characteristics and use case guidance + /// @brief Generate performance characteristics and use case guidance + /// @return Guidance on when this configuration is optimal and expected performance + /// @note Currently unimplemented - reserved for future enhancement std::string suggest() const { std::ostringstream oss; @@ -217,18 +243,13 @@ struct ConvDescription } }; -// Helper concept to detect if a type has InstanceTraits specialization +/// @brief Helper concept to detect if a type has InstanceTraits specialization template concept HasInstanceTraits = requires { typename InstanceTraits; }; -// Helper concept to detect ConvBuilder types -template -concept IsConvBuilder = requires { - typename T::Factory; - typename T::Instance; -}; - -// Primary factory function: Create ConvDescription from Instance type directly +/// @brief Factory function to create ConvDescription from a convolution instance type +/// @tparam Instance The convolution instance type (must have InstanceTraits specialization) +/// @return A ConvDescription object populated with the instance's configuration details template requires HasInstanceTraits ConvDescription Describe() @@ -255,14 +276,4 @@ ConvDescription Describe() .padding = Traits::gemm_padding}}; } -// Backward compatibility: Create ConvDescription from Builder type -template - requires IsConvBuilder && (!HasInstanceTraits) -ConvDescription Describe() -{ - // Delegate to Instance-based version - using Instance = typename Builder::Instance; - return Describe(); -} - } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 7c1c293a23..a59c6c3045 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -5,9 +5,9 @@ #include #include -#include #include #include +#include #include #include #include @@ -680,15 +680,14 @@ struct ConvTraits /// @brief Specialization of `ConvTraits` for a `ConvBuilder` type. /// @details This specialization provides backward compatibility for reflecting /// on kernels defined via the `ConvBuilder` interface. It works by first -/// creating the `Instance` via the builder's factory, and then delegating +/// creating the `Instance` via the builder, and then delegating /// all trait extraction to the `ConvTraits` specialization. template struct ConvTraits> { - using Factory = builder::ConvFactory; - using Instance = typename Factory::Instance; + using Instance = typename builder::ConvBuilder::Instance; // Delegate to Instance-based ConvTraits using InstanceConvTraits = ConvTraits; diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 787a5883b8..a32c68e219 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -1,6 +1,48 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +################################################################################ +# CK Builder Test Suite +################################################################################ +# +# This file defines the test suite for the Composable Kernel (CK) Builder, +# which is responsible for generating optimized GPU kernels for convolution +# operations. +# +# TESTING PHILOSOPHY: +# ------------------- +# Tests are organized into two main categories: +# +# 1. SMOKE TESTS (fast, < 1 second total) +# - Unit tests that verify the builder's internal logic +# - Do NOT compile GPU kernels (fast compilation) +# - Run these frequently during development for quick feedback +# - Target: `ninja smoke-builder` +# +# 2. REGRESSION TESTS (slower, may take minutes) +# - Integration tests that compile and verify actual GPU kernels +# - Ensure the builder generates valid, compilable code +# - Include expensive "factory tests" that build all MIOpen kernels +# - Run these before submitting changes +# - Target: `ninja regression-builder` +# +# QUICK START: +# ------------ +# - During development: ninja smoke-builder +# - Before submitting: ninja regression-builder +# - Run everything: ninja check-builder +# - Build specific test: ninja test_ckb_conv_builder && bin/test_ckb_conv_builder +# +################################################################################ + include(gtest) +################################################################################ +# Helper Functions +################################################################################ + # Helper function to create a gtest executable with common properties +# All builder tests share the same compilation settings and dependencies function(add_ck_builder_test test_name) add_executable(${test_name} ${ARGN} testing_utils.cpp) target_compile_features(${test_name} PRIVATE cxx_std_20) @@ -16,17 +58,56 @@ function(add_ck_builder_test test_name) target_link_libraries(${test_name} PRIVATE GTest::gtest_main GTest::gmock) endfunction() -# The test_ckb_conv_builder target has all the unit tests (each test should run < 10 ms) +# Factory tests attempt to build all the kernels needed by MIOpen. +# These are only for regression testing and development; the builds are too +# expensive for regular use in CI. +function(add_ck_factory_test test_name) + add_ck_builder_test(${test_name} ${ARGN}) + target_link_libraries(${test_name} PRIVATE composablekernels::device_conv_operations) +endfunction() + +################################################################################ +# SMOKE TESTS - Fast Unit Tests (No Kernel Compilation) +################################################################################ +# These tests verify the builder's internal logic without compiling GPU kernels. +# They should complete in under 10ms each and are suitable for frequent execution +# during development. add_ck_builder_test(test_ckb_conv_builder + test_bwd_weight_instance_traits.cpp test_conv_builder.cpp test_fwd_instance_traits.cpp - test_bwd_weight_instance_traits.cpp test_bwd_data_instance_traits.cpp - test_instance_traits_util.cpp) + test_instance_traits_util.cpp -add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp) + unit_conv_elementwise_op.cpp + unit_conv_tensor_layout.cpp + unit_conv_tensor_type.cpp + unit_conv_thread_block.cpp + unit_conv_tuning_params.cpp) + + # Tests the inline diff utility used for comparing strings in tests assertions + add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp) -# Testing the virtual GetInstanceString methods requires kernel compilation. + # Tests convolution trait selection and configuration + add_ck_builder_test(test_ckb_conv_traits + conv/test_conv_traits.cpp) + + # Tests convolution problem description and parameter handling + add_ck_builder_test(test_ckb_conv_description + test_conv_description.cpp) + +################################################################################ +# REGRESSION TESTS - Integration Tests (With Kernel Compilation) +################################################################################ +# These tests compile actual GPU kernels to verify the builder generates valid, +# compilable code. They are more expensive but catch real-world issues. + + +# Verifies that GetInstanceString() methods produce valid kernel code. +# Tests various convolution types: +# - Group convolution (v3, standard, large tensor, WMMA, DL variants) +# - Backward weight group convolution (XDL) +# Requires kernel compilation to validate the generated strings. add_ck_builder_test(test_ckb_get_instance_string test_get_instance_string_fwd_grp_conv_v3.cpp test_get_instance_string_fwd_grp_conv.cpp @@ -35,8 +116,8 @@ add_ck_builder_test(test_ckb_get_instance_string test_get_instance_string_fwd_grp_conv_dl.cpp test_get_instance_string_bwd_weight_grp_conv_xdl.cpp) -# Testing the fwd convolution builder requires kernel compilation. -# To enable parallel compilation, the individual tests are split into separate files. +# Tests the forward convolution builder across multiple data types and dimensions. +# Individual tests are split into separate files to enable parallel compilation. add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_1d_fp16.cpp conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -52,15 +133,21 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_3d_fp32.cpp ) -# Factory tests attempt to build all the kernels need by MIOpen. -# This is only for regression testing and development, the builds are too expensive for regular use in CI. -function(add_ck_factory_test test_name) - add_ck_builder_test(${test_name} ${ARGN}) - target_link_libraries(${test_name} PRIVATE composablekernels::device_conv_operations) -endfunction() -# TODO: add these tests back in once we have CI working across all GPU architectures. +################################################################################ +# FACTORY TESTS - Expensive Regression Tests (Full MIOpen Kernel Set) +################################################################################ +# These tests attempt to build ALL kernels needed by MIOpen for various +# convolution operations. They are extremely expensive (minutes to compile) +# and are intended for deep regression testing and development only. +# NOT suitable for regular CI runs. +# +# Many tests are commented out pending CI support across all GPU architectures. + +# Tests the testing utilities themselves add_ck_factory_test(test_ckb_testing_utils test_testing_utils.cpp) + +# TODO: Re-enable these tests once we have CI working across all GPU architectures. # add_ck_factory_test(test_ckb_factory_grouped_convolution_forward test_ck_factory_grouped_convolution_forward.cpp) # add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_clamp test_ck_factory_grouped_convolution_forward_clamp.cpp) add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_convscale test_ck_factory_grouped_convolution_forward_convscale.cpp) @@ -72,22 +159,30 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_ab tes add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp) add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp) -add_ck_builder_test(test_ckb_conv_traits - conv/test_conv_traits.cpp) +################################################################################ +# CTest Integration - Register Tests and Assign Labels +################################################################################ +# Tests are registered with CTest and labeled for selective execution: +# - BUILDER_SMOKE: Fast unit tests for frequent development cycles +# - BUILDER_REGRESSION: Slower integration tests for pre-submission validation -add_ck_builder_test(test_ckb_conv_description - test_conv_description.cpp) - -# Register tests with CTest and assign labels include(CTest) -# Smoke test: fast-compiling unit test -add_test(NAME test_ckb_conv_builder COMMAND test_ckb_conv_builder) -set_tests_properties(test_ckb_conv_builder PROPERTIES LABELS "BUILDER_SMOKE") - -# Regression tests: all other tests that require kernel compilation -set(CKB_REGRESSION_TESTS +# Register all smoke tests (fast unit tests, no kernel compilation) +set(CKB_SMOKE_TESTS + test_ckb_conv_builder test_ckb_inline_diff + test_ckb_conv_traits + test_ckb_conv_description +) + +foreach(test_target ${CKB_SMOKE_TESTS}) + add_test(NAME ${test_target} COMMAND ${test_target}) + set_tests_properties(${test_target} PROPERTIES LABELS "BUILDER_SMOKE") +endforeach() + +# Register all regression tests (integration tests with kernel compilation) +set(CKB_REGRESSION_TESTS test_ckb_get_instance_string test_ckb_build_fwd_instances test_ckb_testing_utils @@ -95,8 +190,6 @@ set(CKB_REGRESSION_TESTS test_ckb_factory_grouped_convolution_forward_scaleadd_ab test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ckb_factory_grouped_convolution_forward_dynamic_op - test_ckb_conv_traits - test_ckb_conv_description ) foreach(test_target ${CKB_REGRESSION_TESTS}) @@ -104,18 +197,31 @@ foreach(test_target ${CKB_REGRESSION_TESTS}) set_tests_properties(${test_target} PROPERTIES LABELS "BUILDER_REGRESSION") endforeach() -# Helper target to build all regression tests +################################################################################ +# Custom Build Targets - Convenient Test Execution +################################################################################ +# These targets provide convenient ways to build and run different test suites: +# - smoke-builder: Quick sanity check during development +# - regression-builder: Thorough validation before submitting changes +# - check-builder: Complete test suite execution + +# Helper target to build all smoke tests (without running them) +add_custom_target(build-smoke-builder DEPENDS ${CKB_SMOKE_TESTS}) + +# Helper target to build all regression tests (without running them) add_custom_target(build-regression-builder DEPENDS ${CKB_REGRESSION_TESTS}) -# Target to run only smoke tests (builds only test_ckb_conv_builder) +# Target to run only smoke tests (builds and runs all smoke test executables) +# Use this for quick feedback during active development add_custom_target(smoke-builder COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "BUILDER_SMOKE" - DEPENDS test_ckb_conv_builder + DEPENDS build-smoke-builder USES_TERMINAL COMMENT "Running experimental builder smoke tests..." ) -# Target to run only regression tests (builds all regression test executables) +# Target to run only regression tests (builds and runs all regression test executables) +# Use this before submitting changes to catch integration issues add_custom_target(regression-builder COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "BUILDER_REGRESSION" DEPENDS build-regression-builder @@ -123,15 +229,20 @@ add_custom_target(regression-builder COMMENT "Running experimental builder regression tests..." ) -# Target to run all builder tests (builds all test executables) +# Target to run all builder tests (builds and runs all test executables) +# Use this for comprehensive validation add_custom_target(check-builder COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -R "^test_ckb" - DEPENDS test_ckb_conv_builder build-regression-builder + DEPENDS build-smoke-builder build-regression-builder USES_TERMINAL COMMENT "Running all experimental builder tests..." ) -# Print summary of test organization +################################################################################ +# Build Summary +################################################################################ + +# Print summary of test organization for developer reference message(STATUS "CK Builder test organization:") -message(STATUS " Smoke test: test_ckb_conv_builder") +message(STATUS " Smoke tests: ${CKB_SMOKE_TESTS}") message(STATUS " Regression tests: ${CKB_REGRESSION_TESTS}") diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 3331bf204f..d89d83357f 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -91,7 +91,7 @@ static_assert(LdsTransferDescriptor); struct Epilogue { - size_t m_per_wave_per_shuffle; + size_t m_xdl_per_wave_per_shuffle; size_t n_per_wave_per_shuffle; size_t scalar_per_vector; }; diff --git a/experimental/builder/test/test_ckb_conv_builder.cpp b/experimental/builder/test/test_ckb_conv_builder.cpp new file mode 100644 index 0000000000..e69de29bb2 diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index 0d48a91738..20df86efd3 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -67,11 +67,11 @@ struct DefaultAlgorithm ckb::test::TransferABC transfer{ .a = { - .block_transfer = {.k0 = 4, .m_n = 256, .k1 = 8}, + .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = true, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, .lds_padding = false}, .block_transfer_access_order = {.order = {0, 1, 2}}, .src_access_order = {.order = {0, 1, 2}}, @@ -79,11 +79,11 @@ struct DefaultAlgorithm }, .b = { - .block_transfer = {.k0 = 4, .m_n = 256, .k1 = 8}, + .block_transfer = {.k0 = 1, .m_n = 128, .k1 = 2}, .lds_transfer = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = true, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 2, + .is_direct_load = false, .lds_padding = false}, .block_transfer_access_order = {.order = {0, 1, 2}}, .src_access_order = {.order = {0, 1, 2}}, @@ -92,9 +92,9 @@ struct DefaultAlgorithm { .thread_cluster_dims = {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - .epilogue = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 2}, }, }; @@ -109,16 +109,16 @@ TEST(ConvDescriptionTest, DefaultInstanceHasBriefDescription) { static constexpr const ConvSignature SIGNATURE; static constexpr const DefaultAlgorithm ALGORITHM; - using Builder = ckb::ConvBuilder; - EXPECT_THAT(ckr::Describe().brief(), ckt::StringEqWithDiff("2D Forward convolution")); + using Instance = ckb::ConvBuilder::Instance; + EXPECT_THAT(ckr::Describe().brief(), ckt::StringEqWithDiff("2D Forward convolution")); } TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription) { static constexpr const ConvSignature SIGNATURE; static constexpr const DefaultAlgorithm ALGORITHM; - using Builder = ckb::ConvBuilder; - EXPECT_THAT(ckr::Describe().detailed(), + using Instance = ckb::ConvBuilder::Instance; + EXPECT_THAT(ckr::Describe().detailed(), ckt::StringEqWithDiff( // "2D Forward Convolution Kernel\n" "├─ Signature\n" @@ -127,41 +127,39 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription) "│ ├─ Input elementwise operation: PASS_THROUGH\n" "│ ├─ Weights elementwise operation: PASS_THROUGH\n" "│ └─ Output elementwise operation: PASS_THROUGH\n" - "├─ Algorithm\n" - "│ ├─ Thread block size: 256\n" - "│ ├─ Data tile size: 256×256×32\n" - "│ ├─ Gemm padding: DEFAULT\n" - "│ ├─ Convolution specialization: DEFAULT\n" - "│ ├─ Pipeline version: V4\n" - "│ ├─ Pipeline scheduler: INTRAWAVE\n" - "│ ├─ Warp Gemm parameters: \n" - "│ │ ├─ subtile size: 16×16\n" - "│ │ └─ Number of warp gemm iterations: 4×4\n" - "│ ├─ Memory access:\n" - "│ │ ├─ A Tile transfer: \n" - "│ │ │ ├─ Tile dimensions: 4×256×8×\n" - "│ │ │ ├─ The innermost K subdimension size: 8\n" - "│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" - "│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n" - "│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - "│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n" - "│ │ │ ├─ Vector access (LDS write) instruction size: 8\n" - "│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" - "│ │ ├─ B Tile transfer: \n" - "│ │ │ ├─ Tile dimensions: 4×256×8×\n" - "│ │ │ ├─ The innermost K subdimension size: 8\n" - "│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" - "│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n" - "│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - "│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n" - "│ │ │ ├─ Vector access (LDS write) instruction size: 8\n" - "│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" - "│ │ └─ C Tile transfer: \n" - "│ │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" - "│ │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n" - "│ │ └─ Vector access (GMEM write) instruction size: 8\n" - "│ └─ \n" - "└─ ")); + "└─ Algorithm\n" + " ├─ Thread block size: 256\n" + " ├─ Data tile size: 256×256×32\n" + " ├─ Gemm padding: DEFAULT\n" + " ├─ Convolution specialization: DEFAULT\n" + " ├─ Pipeline version: V4\n" + " ├─ Pipeline scheduler: INTRAWAVE\n" + " ├─ Warp Gemm parameters: \n" + " │ ├─ subtile size: 16×16\n" + " │ └─ Number of warp gemm iterations: 4×4\n" + " └─ Memory access:\n" + " ├─ A Tile transfer: \n" + " │ ├─ Tile dimensions: 4×256×8×\n" + " │ ├─ The innermost K subdimension size: 8\n" + " │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" + " │ ├─ The order of accessing data tile axes: 0×1×2\n" + " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + " │ ├─ Vector access (GMEM read) instruction size: 2\n" + " │ ├─ Vector access (LDS write) instruction size: 2\n" + " │ └─ LDS data layout padding (to prevent bank conflicts): 2\n" + " ├─ B Tile transfer: \n" + " │ ├─ Tile dimensions: 4×256×8×\n" + " │ ├─ The innermost K subdimension size: 8\n" + " │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" + " │ ├─ The order of accessing data tile axes: 0×1×2\n" + " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + " │ ├─ Vector access (GMEM read) instruction size: 2\n" + " │ ├─ Vector access (LDS write) instruction size: 2\n" + " │ └─ LDS data layout padding (to prevent bank conflicts): 2\n" + " └─ C Tile transfer: \n" + " ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" + " ├─ Spatial thread distribution used to store data: 1×32×1×8\n" + " └─ Vector access (GMEM write) instruction size: 2")); } // NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory diff --git a/experimental/builder/test/unit_conv_elementwise_op.cpp b/experimental/builder/test/unit_conv_elementwise_op.cpp new file mode 100644 index 0000000000..66593bf802 --- /dev/null +++ b/experimental/builder/test/unit_conv_elementwise_op.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp" + +namespace { + +using ::ck_tile::builder::factory::internal::ElementwiseOps; +using enum ::ck_tile::builder::ElementwiseOperation; + +TEST(ConvElementwiseOp, AssignsOpsForPassThrough) +{ + using Ops = ElementwiseOps; + + EXPECT_TRUE( + (std::is_same_v)); + EXPECT_TRUE( + (std::is_same_v)); + EXPECT_TRUE( + (std::is_same_v)); +} + +TEST(ConvElementwiseOp, AssignsOpsForScale) +{ + using Ops = ElementwiseOps; + + EXPECT_TRUE( + (std::is_same_v)); + EXPECT_TRUE( + (std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +} // namespace diff --git a/experimental/builder/test/unit_conv_tensor_layout.cpp b/experimental/builder/test/unit_conv_tensor_layout.cpp new file mode 100644 index 0000000000..6cdcc429dd --- /dev/null +++ b/experimental/builder/test/unit_conv_tensor_layout.cpp @@ -0,0 +1,119 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +// Include the helper file we're testing +#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp" + +namespace { + +namespace ckb = ::ck_tile::builder; +using ::ck_tile::builder::factory::internal::ConvTensorLayouts; +using ::ck_tile::builder::factory::internal::GetTensorLayout; +using enum ::ck_tile::builder::ConvDirection; + +TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK) +{ + using TensorLayouts = ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW) +{ + using TensorLayouts = ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK) +{ + using TensorLayouts = ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW) +{ + using TensorLayouts = ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW) +{ + using TensorLayouts = ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK) +{ + using TensorLayouts = ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK) +{ + using TensorLayouts = ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW) +{ + using TensorLayouts = ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW) +{ + using TensorLayouts = + ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK) +{ + using TensorLayouts = + ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK) +{ + using TensorLayouts = + ConvTensorLayouts; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +} // namespace diff --git a/experimental/builder/test/unit_conv_tensor_type.cpp b/experimental/builder/test/unit_conv_tensor_type.cpp new file mode 100644 index 0000000000..5aa82774da --- /dev/null +++ b/experimental/builder/test/unit_conv_tensor_type.cpp @@ -0,0 +1,79 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp" + +namespace { + +namespace ckb = ck_tile::builder; +using ck_tile::builder::factory::internal::ConvTensorTypes; + +TEST(ConvTensorType, AssignsTypesForFP16) +{ + using Types = ConvTensorTypes; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorType, AssignsTypesForBF16) +{ + using Types = ConvTensorTypes; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorType, AssignsTypesForFP32) +{ + using Types = ConvTensorTypes; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorType, AssignsTypesForI8) +{ + using Types = ConvTensorTypes; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +TEST(ConvTensorType, AssignsTypesForFP8) +{ + using Types = ConvTensorTypes; + + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); + EXPECT_TRUE((std::is_same_v)); +} + +} // namespace diff --git a/experimental/builder/test/unit_conv_thread_block.cpp b/experimental/builder/test/unit_conv_thread_block.cpp new file mode 100644 index 0000000000..f829708696 --- /dev/null +++ b/experimental/builder/test/unit_conv_thread_block.cpp @@ -0,0 +1,35 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp" + +namespace { + +using ::ck_tile::builder::factory::internal::ConvBlock; +using ::ck_tile::builder::factory::internal::SetThreadBlockInfo; + +TEST(ConvThreadBlock, AssignsThreadBlockAndTileSize) +{ + constexpr struct Algorithm + { + struct ThreadBlock + { + int block_size = 256; + struct TileSize + { + int m = 128; + int n = 128; + int k = 16; + } tile_size; + } thread_block; + } kAlgorithm; + constexpr ConvBlock block_info = SetThreadBlockInfo(); + + EXPECT_EQ(block_info.block_size, 256); + EXPECT_EQ(block_info.per_block.m, 128); + EXPECT_EQ(block_info.per_block.n, 128); + EXPECT_EQ(block_info.per_block.k, 16); +} + +} // namespace diff --git a/experimental/builder/test/unit_conv_tuning_params.cpp b/experimental/builder/test/unit_conv_tuning_params.cpp new file mode 100644 index 0000000000..82117c53d8 --- /dev/null +++ b/experimental/builder/test/unit_conv_tuning_params.cpp @@ -0,0 +1,90 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp" + +namespace { + +namespace ckb = ::ck_tile::builder; +using namespace ck_tile::builder; +using namespace ck_tile::builder::factory::internal; + +TEST(ConvTuningParams, AssignsBlockGemmParams) +{ + constexpr struct Algorithm + { + struct BlockGemm + { + ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V3; + ckb::PipelineScheduler scheduler = ckb::PipelineScheduler::INTRAWAVE; + } block_gemm; + } kAlgorithm; + constexpr auto block_gemm = SetBlockGemm(); + + EXPECT_EQ(block_gemm.pipeline_version, ck::BlockGemmPipelineVersion::v3); + EXPECT_EQ(block_gemm.scheduler, ck::BlockGemmPipelineScheduler::Intrawave); +} + +TEST(ConvTuningParams, AssignsLoopSchedulerParam) +{ + constexpr struct Algorithm + { + ckb::PipelineScheduler loop_scheduler = ckb::PipelineScheduler::INTERWAVE; + } kAlgorithm; + constexpr auto loop_scheduler = SetLoopScheduler(); + + EXPECT_EQ(loop_scheduler, ck::LoopScheduler::Interwave); +} + +TEST(ConvTuningParams, AssignsGridwiseGemmPipelineVersion) +{ + constexpr struct Algorithm + { + struct GridwiseGemm + { + ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V4; + } gridwise_gemm; + } kAlgorithm; + constexpr auto pipeline_version = SetGridwiseGemmPipelineVersion(); + + EXPECT_EQ(pipeline_version, ck::PipelineVersion::v4); +} + +TEST(ConvTuningParams, AssignsGemmSpecialization) +{ + constexpr struct Algorithm + { + ckb::GemmSpecialization gemm_specialization = ckb::GemmSpecialization::MNKPadding; + } kAlgorithm; + constexpr auto gemm_spec = SetGemmSpecialization(); + + EXPECT_EQ(gemm_spec, ck::tensor_operation::device::GemmSpecialization::MNKPadding); +} + +TEST(ConvTuningParams, AssignsBlockGemmPipelineVersion) +{ + constexpr struct Algorithm + { + ckb::PipelineVersion pipeline_version = ckb::PipelineVersion::V2; + } kAlgorithm; + constexpr auto pipeline_version = SetBlockGemmPipelineVersion(); + + EXPECT_EQ(pipeline_version, ck::BlockGemmPipelineVersion::v2); +} + +TEST(ConvTuningParams, AssignsFwdConvSpecialization) +{ + constexpr struct Algorithm + { + ckb::ConvFwdSpecialization fwd_specialization = + ckb::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0; + } kAlgorithm; + constexpr auto conv_spec = SetFwdConvSpecialization(); + + EXPECT_EQ(conv_spec, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0); +} + +} // namespace diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 7384603854..5436755608 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -66,9 +66,9 @@ constexpr TransferABC FwdTransfer_4x64x1{ { .thread_cluster_dims = {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - .epilogue = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, }, }; @@ -99,9 +99,9 @@ constexpr TransferABC FwdTransfer_4x64x1_fp8{ { .thread_cluster_dims = {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, - .epilogue = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, }, }; @@ -132,9 +132,9 @@ constexpr TransferABC FwdTransfer_4x16x1{ { .thread_cluster_dims = {.m_block = 1, .m_wave_per_xdl = 16, .n_block = 1, .n_wave_per_xdl = 4}, - .epilogue = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, }, }; @@ -166,9 +166,9 @@ constexpr TransferABC FwdTransfer_4x32x1{ { .thread_cluster_dims = {.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 4}, - .epilogue = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, + .epilogue = {.m_xdl_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, }, }; diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 9debdc12b2..879fb31ca5 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -69,7 +69,7 @@ #endif #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \ - defined(__gfx1152__) || defined(__gfx11_generic__) + defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) #define __gfx11__ #endif #if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 6db5c0a071..8739f65740 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -62,7 +62,7 @@ inline bool is_gfx11_supported() return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" || ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" || - ck::get_device_name() == "gfx1152"; + ck::get_device_name() == "gfx1152" || ck::get_device_name() == "gfx1153"; } inline bool is_xdl_supported() diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index a42f7170aa..ec623db6f7 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -199,7 +199,7 @@ struct BaseArgument BaseArgument(const BaseArgument&) = default; BaseArgument& operator=(const BaseArgument&) = default; - virtual ~BaseArgument() {} + virtual __host__ __device__ ~BaseArgument() {} void* p_workspace_ = nullptr; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 29ccd7289f..5f60d8787d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -446,7 +446,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmADataType = ck::conditional_t, ADataType>; using GemmBDataType = ck::conditional_t, BDataType>; -#define GridwiseGemmMultiABDTemplateParameters \ +#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS \ GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ @@ -462,7 +462,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched -#define GridwiseGemmTemplateParameters \ +#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS \ GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ @@ -480,8 +480,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle template using GridwiseGemmBase = ck::conditional_t< isMultiA || isMultiB, - GridwiseGemmMultipleABD_xdl_cshuffle, - GridwiseGemmMultipleD_xdl_cshuffle>; + GridwiseGemmMultipleABD_xdl_cshuffle, + GridwiseGemmMultipleD_xdl_cshuffle>; +#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_TEMPLATE_PARAMETERS +#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_TEMPLATE_PARAMETERS using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index b291b20bcd..d33e807828 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -439,7 +439,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // GridwiseGemm -#define GridwiseGemmMultiDTemplateParams \ +#define CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS \ ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, \ @@ -454,7 +454,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType -#define GridwiseGemmCTransposeTemplateParameters \ +#define CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS \ ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave_, MXdlPerWave, \ @@ -470,10 +470,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType template - using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< + CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS>; template - using GridwiseGemmCTransposeBase = - GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemmCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle< + CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS>; +#undef CK_GRIDWISE_GEMM_BWD_DATA_MULTIPLE_D_TEMPLATE_PARAMETERS +#undef CK_GRIDWISE_GEMM_BWD_DATA_CTRANSPOSE_TEMPLATE_PARAMETERS using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 698af8846d..a9b0975050 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -485,7 +485,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmADataType = std::conditional_t, ADataType>; using GemmBDataType = std::conditional_t, BDataType>; -#define GridwiseGemmMultiABDTemplateParameters \ +#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \ GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ @@ -502,7 +502,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ BComputeDataType -#define GridwiseGemmTemplateParameters \ +#define CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \ GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ @@ -518,7 +518,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ BComputeDataType, DoElementwiseBeforeCShuffle -#define GridwiseGemmCTransposeTemplateParameters \ +#define CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS \ GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \ NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \ @@ -536,14 +536,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // Use appropriate gridwise gemm template - using GridwiseGemmMultipleABDBase = - GridwiseGemmMultipleABD_xdl_cshuffle; + using GridwiseGemmMultipleABDBase = GridwiseGemmMultipleABD_xdl_cshuffle< + CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>; template - using GridwiseGemmMultipleDBase = - GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemmMultipleDBase = GridwiseGemmMultipleD_xdl_cshuffle< + CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>; template - using GridwiseGemmMultipleDCTransposeBase = - GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemmMultipleDCTransposeBase = GridwiseGemmMultipleD_xdl_cshuffle< + CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS>; +#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_ABD_XDL_CSHUFFLE_TEMPLATE_PARAMETERS +#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_XDL_CSHUFFLE_TEMPLATE_PARAMETERS +#undef CK_GRIDWISE_GEMM_FWD_CTRANSPOSE_XDL_CSHUFFLE_TEMPLATE_PARAMETERS using GridwiseGemm64 = std::conditional_t - using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< + CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS>; +#undef CK_GRIDWISE_GEMM_FWD_MULTIPLE_D_LARGE_TENSOR_TEMPLATE_PARAMETERS + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp new file mode 100644 index 0000000000..2f0c047167 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -0,0 +1,827 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/env.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_gemm_wmma_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; + + const index_t block_id = get_block_1d_id(); + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + // Binary search lookup to find which group this block is part of + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].block_start_ && + block_id < gemm_desc_ptr[group_id].block_end_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + // NOTE: Local copy of the arg struct since SplitKBatchOffset verifies and modifies K index + // and thus needs a non-const reference. It's also not feasible to store this in global + // memory as different threads would be writing different K values to the same arg struct + auto karg = gemm_desc_ptr[group_id].karg_; + +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + const auto& block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_; + + // Tile index first dimension is the K batch + auto tile_index = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + auto splitk_batch_offset = + typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]); + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run(static_cast(p_shared), + splitk_batch_offset, + karg, + block_2_ctile_map, + epilogue_args); +#if defined(__gfx11__) + } +#endif +#else + ignore = gemm_descs_const; + ignore = group_count; +#endif // end of if(defined(__gfx11__) || defined(__gfx12__)) +} + +template +struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static_assert(KPerBlock % AK1 == 0); + static constexpr index_t K0PerBlock = KPerBlock / AK1; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // PermuteA not supported by DeviceBatchedGemm base class. + false>; // PermuteB not supported by DeviceBatchedGemm base class. + + using CGridDesc_M_N = + remove_cvref_t( + 1, 1, 1, 1, 1))>; + using Block2ETileMapKSplit = + BlockToCTileMap_KSplit_M00_N0_M01Adapt; + // Block2CTileMap configuration parameter. + static constexpr index_t B2E_M01 = 8; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; + using KernelArgument = typename GridwiseGemm::Argument; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + template + struct GemmTransKernelArgBase + { + KernelArgument_ karg_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t block_start_, block_end_; + + GemmTransKernelArgBase() = default; + GemmTransKernelArgBase(KernelArgument_&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) + : karg_{karg}, + block_2_ctile_map_{b2c_map}, + block_start_{block_start}, + block_end_{block_end} + { + } + }; + using GemmTransKernelArg = GemmTransKernelArgBase; + + static constexpr index_t DefaultKBatch = 1; + + static constexpr bool CalculateHasMainKBlockLoop(const KernelArgument& karg) + { + index_t k_grain = karg.KBatch * KPerBlock; + index_t K_split = (karg.K + k_grain - 1) / karg.KBatch; + return GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + } + + // Argument + // TODO: Add A/B/CDE element op? + struct Argument : public BaseArgument + { + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector& p_Es, + std::vector& gemm_descs) + : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch) + { + // TODO: use occupancy api to calculate appropriate batch size. + } + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector& p_Es, + std::vector& gemm_descs, + index_t kbatch) + : K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr} + { + grid_size_ = 0; + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("wrong! group_count_ != p_As/b/c.size"); + } + + gemm_kernel_args_.reserve(group_count_); + + skipped_group_count_ = 0; + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + if(M == 0) + { + skipped_group_count_++; + continue; + } + + const index_t stride_a = gemm_descs[i].stride_A_; + const index_t stride_b = gemm_descs[i].stride_B_; + const index_t stride_c = gemm_descs[i].stride_C_; + + const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + + const auto c_grid_desc_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N( + M, m_padded, N, n_padded, stride_c); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + auto karg = KernelArgument(std::array{p_As[i]}, + std::array{p_Bs[i]}, + std::array{}, // p_ds_grid_ + type_convert(p_Es[i]), + M, + N, + K, + std::array{stride_a}, + std::array{stride_b}, + std::array{}, // StrideDs_ + stride_c, + K_BATCH, + PassThrough{}, + PassThrough{}, + PassThrough{}, + false); + + gemm_kernel_args_.emplace_back( + std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + } + } + + /** + * @brief Recalculate group grid size for all gemms and update B2C maps. + * + * @param[in] kbatch The new splitK parameter value. + */ + void UpdateKBatch(index_t kbatch) + { + K_BATCH = kbatch; + grid_size_ = 0; + + for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i) + { + auto& karg = gemm_kernel_args_[i].karg_; + + const index_t k_read = GridwiseGemm::CalculateKRead(karg.K, K_BATCH); + const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); + const index_t ak0_padded = GridwiseGemm::CalculateAK0Padded(karg.K, K_BATCH); + const index_t bk0_padded = GridwiseGemm::CalculateBK0Padded(karg.K, K_BATCH); + + const auto c_grid_desc_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideE); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + karg.KRead = k_read; + karg.KPadded = k_padded; + karg.AK0 = ak0_padded; + karg.BK0 = bk0_padded; + karg.KBatch = K_BATCH; + gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; + gemm_kernel_args_[i].block_start_ = block_start; + gemm_kernel_args_[i].block_end_ = block_end; + } + } + + // private: + index_t K_BATCH; + index_t group_count_; + index_t skipped_group_count_; + + std::vector gemm_kernel_args_; + void* gemm_kernel_host_args_; + index_t grid_size_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}, + hipStream_t cpy_stream = nullptr, + hipEvent_t cpy_event = nullptr) + { + using GemmTransKernelArg_ = GemmTransKernelArgBase; + static_assert(sizeof(GemmTransKernelArg_) == sizeof(GemmTransKernelArg)); + + bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.KBatch > 1; + bool all_have_main_k0_block_loop = + CalculateHasMainKBlockLoop(arg.gemm_kernel_args_[0].karg_); + + bool not_all_have_main_k0_block_loop_same = false; + bool not_all_have_kbatch_value_same = false; + + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& karg = reinterpret_cast( + arg.gemm_kernel_args_[i].karg_); + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + auto kbatch = karg.KBatch; + + if(!GridwiseGemm::CheckValidity(karg)) + { + std::ostringstream err; + err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + not_all_have_main_k0_block_loop_same |= + all_have_main_k0_block_loop xor CalculateHasMainKBlockLoop(karg); + not_all_have_kbatch_value_same |= all_have_kbatch_gt_one xor (kbatch > 1); + } + + if(not_all_have_main_k0_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + // throw std::runtime_error(err.str()); + } + + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! " << " in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + // If the user provides copy stream and copy event, we assume that they're also + // responsible for providing allocated host memory (eg. pinned) which + // would be used to copy kernel arguments to the device. + if(cpy_stream && cpy_event) + { + if(arg.gemm_kernel_host_args_ == nullptr) + { + std::ostringstream err; + err << "No memory has been allocated for gemm kernel host args " + << "when providing the copy stream and copy event! In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + hip_check_error(hipMemcpyAsync(arg.p_workspace_, + arg.gemm_kernel_host_args_, + arg.group_count_ * sizeof(GemmTransKernelArg_), + hipMemcpyHostToDevice, + cpy_stream)); + hip_check_error(hipEventRecord(cpy_event, cpy_stream)); + hip_check_error(hipEventSynchronize(cpy_event)); + } + else // In this case CK owns memory allocated on host. + { + + hip_check_error( + hipMemcpyAsync(arg.p_workspace_, + arg.gemm_kernel_args_.data(), + arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg_), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + } + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(all_have_kbatch_gt_one) + { + for(const auto& trans_arg : arg.gemm_kernel_args_) + { + const auto& karg = trans_arg.karg_; + hip_check_error(hipMemsetAsync(karg.p_e_grid, + 0, + karg.M * karg.N * sizeof(EDataType), + stream_config.stream_id_)); + } + } + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_kernel_args_.size()); + }; + + // NOTE: If at least one gemm problem has a main k0 block loop, we include it for all + if(all_have_main_k0_block_loop || not_all_have_main_k0_block_loop_same) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(all_have_kbatch_gt_one) + { + const auto kernel = + kernel_grouped_gemm_wmma_splitk; + + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_gemm_wmma_splitk; + + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(all_have_kbatch_gt_one) + { + const auto kernel = + kernel_grouped_gemm_wmma_splitk; + + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_gemm_wmma_splitk; + + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.K_BATCH > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((ck::type_convert(arg.gemm_kernel_args_.size()) + + arg.skipped_group_count_) != arg.group_count_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; + } + return false; + } + + bool supported = true; + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& a = arg.gemm_kernel_args_[i].karg_; + bool group_arg_valid = GridwiseGemm::CheckValidity(a); + + if(not group_arg_valid) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" << std::endl; + a.Print(); + } + } + supported = supported && group_arg_valid; + } + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>&, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation) + { + return Argument{p_As, p_Bs, p_Es, gemm_descs}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>&, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation) override + { + return std::make_unique(p_As, p_Bs, p_Es, gemm_descs); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedGemm_WmmaSplitK" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_CShuffleV3::Argument structure!"); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return GetWorkSpaceSize(p_arg); + } + + size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); } + + // TODO: deperecation notice. + static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } + + // polymorphic + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->UpdateKBatch(kbatch); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_CShuffleV3::Argument structure!"); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args); + } + + //---------------------------------------------------------------------------------------------- + /// @brief Sets the host kernel arguments pointer and copies that data on the host side. + /// This function can be utilised to use pinned memory for the host args and + /// achieve fully async data copy. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_host_kernel_args The pointer to the host memory where the kernel + /// arguments will be copied + /// + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const + { + Argument* pArg_ = dynamic_cast(p_arg); + if(!pArg_) + { + throw std::runtime_error("Failed to cast argument pointer!"); + } + + pArg_->gemm_kernel_host_args_ = p_host_kernel_args; + std::copy(pArg_->gemm_kernel_args_.begin(), + pArg_->gemm_kernel_args_.end(), + static_cast(pArg_->gemm_kernel_host_args_)); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp index 8eff9d2415..ddc941100d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp index d43dab2983..365e593dfe 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp index cff386c7a7..c413befd80 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/split_k_arg.hpp b/include/ck/tensor_operation/gpu/device/impl/split_k_arg.hpp index de683f3282..e812a9ce59 100644 --- a/include/ck/tensor_operation/gpu/device/impl/split_k_arg.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_arg.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp index 32179d179e..3a3bacd945 100644 --- a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include diff --git a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp index cc500bb9cb..5a797c25bf 100644 --- a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp index 0294153147..95e7bd367a 100644 --- a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp +++ b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp index 5351d4ef24..8a6168419e 100644 --- a/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp +++ b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp index 79deb81512..7018bbd251 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp b/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp index 713fc93ebb..78ab03280b 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/welford_helper.hpp b/include/ck/tensor_operation/gpu/device/welford_helper.hpp index d7772d8764..eb210b702a 100644 --- a/include/ck/tensor_operation/gpu/device/welford_helper.hpp +++ b/include/ck/tensor_operation/gpu/device/welford_helper.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 10a9a4dbae..5da2dbc567 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp index 083327b8f0..cd761abb3c 100644 --- a/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 1bb0b63792..2c17b82608 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/element/quantization_operation.hpp b/include/ck/tensor_operation/gpu/element/quantization_operation.hpp index a5cbfbb2fc..915edee1c6 100644 --- a/include/ck/tensor_operation/gpu/element/quantization_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/quantization_operation.hpp @@ -1,3 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + #pragma once #include "ck/utility/data_type.hpp" diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index d0802ff65d..6cd7b3d9f6 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp index 7c9febf4de..fd41067fd7 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp index 4e182ec29d..37cf21fb09 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp index a82a173500..e0d6c98aff 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp index 672be91a79..6e486f0738 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp index 2d5dc90bfb..6d1b454282 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index b6bc634d74..0b0c418a6e 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp index c2bd65f134..942d4351b3 100644 --- a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp index 85d13538cc..c78673b7ac 100644 --- a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp index ccd999b724..6b769cc331 100644 --- a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp index 30f81b7411..b8dd5905aa 100644 --- a/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp +++ b/include/ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma_base.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index 7a09b84a63..f47a84613e 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp index 69468c25be..87f3d50e10 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp index 2d9197a7f4..949edb35f6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp index fc4f27e33b..5ad0ef3117 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp index 774df1f993..f72aca8605 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp index 910c926c7e..fb20531133 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp index cc3306e1bd..637ba27cef 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp index 23f16d38e9..12c2c38b7e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index a36ccd43ca..64a8d5786a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp index 38ebdab65e..f58f67dc6b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index a15f11a93f..da731ead2f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index b8f5a545aa..a3c27c9555 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index 0e8d003071..fd0f77aad7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp index 502c449ef1..231acc7e4f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index e0cf12e429..bc2c197847 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp index ed1ffdd857..1a9bbcb603 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp index b6c83af13a..0f0825b4e4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp index 13e9f7bd5e..d547ab612f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp index 839a68a978..ca0372b521 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp index 072275c089..f2cd6fca5c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp @@ -1,500 +1,500 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/data_type.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -namespace ck { - -// X = Elementwise(input1, input2, input3, ...) -// Y = Normalization(X, beta, gamma) -template -struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk -{ - static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || - (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), - "Invalid thread slice sizes and/or vector sizes configuration, please check!"); - - static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || - (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), - "Invalid thread slice sizes and/or vector sizes configuration, please check!"); - - static constexpr index_t NumInput = InDataTypePointerTuple::Size(); - - static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); - - using ThreadClusterLengths_M_K = Sequence; - - using ThreadBufferDimAccessOrder = - typename conditional, Sequence<0, 1>>::type; - - using ThreadClusterArrangeOrder = - typename conditional, Sequence<0, 1>>::type; - - static constexpr auto thread_cluster_desc = - make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); - - using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}))); - using ThreadReduceDstDesc_M = - decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); - - using ThreadwiseWelford = - ThreadwiseWelford; - - using BlockwiseWelford = BlockwiseWelford; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - - static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; - static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; - static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; - - static constexpr auto XThreadBufferNumber = Number{}; - static constexpr auto GammaThreadBufferNumber = Number{}; - static constexpr auto BetaThreadBufferNumber = Number{}; - static constexpr auto YThreadBufferNumber = Number{}; - - __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k, - int thread_k_cluster_id) - { - int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; - int kPerThread = - kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize); - int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize; - - if(kPerBlockTail > 0) - { - static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { - int thread_max_len = - (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i; - int delta = thread_max_len - kPerBlockTail; - delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize); - kPerThread += XSrcVectorSize - delta; - }); - } - - return kPerThread; - } - - __device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple, - const GridDesc_M_K& x_grid_desc_m_k, - const GridDesc_M_K& gamma_grid_desc_m_k, - const GridDesc_M_K& beta_grid_desc_m_k, - const GridDesc_M_K& y_grid_desc_m_k, - index_t num_k_block_tile_iteration, - AccDataType epsilon, - const InDataTypePointerTuple p_in_global_tuple, - XDataType* const __restrict__ p_x_lds_, - const GammaDataType* const __restrict__ p_gamma_global, - const BetaDataType* const __restrict__ p_beta_global, - YDataType* const __restrict__ p_y_global, - const XElementwiseOperation x_elementwise_op, - const YElementwiseOperation y_elementwise_op) - { - if constexpr(SweepOnce) - { - num_k_block_tile_iteration = 1; - } - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_id = get_block_1d_id(); - const index_t grid_size = get_grid_size(); - - auto in_global_buf_tuple = generate_tuple( - [&](auto I) { - static_assert(in_grid_2d_desc_tuple[I].GetNumOfDimension() == - 2); // matrix dimension - - return make_dynamic_buffer( - p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize()); - }, - Number{}); - - auto y_global_val_buf = make_dynamic_buffer( - p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); - - auto x_lds_val_buf = make_dynamic_buffer( - p_x_lds_, x_grid_desc_m_k.GetElementSpaceSize() / grid_size); - - auto in_thread_buf_tuple = generate_tuple( - [&](auto) { - return generate_tuple( - [&](auto) { - return StaticBuffer{}; - }, - Number{}); - }, - Number{}); - - auto x_thread_buf = generate_tuple( - [&](auto) { - return StaticBuffer{}; - }, - Number{}); - - auto gamma_thread_buf = generate_tuple( - [&](auto) { - return StaticBuffer{}; - }, - Number{}); - - auto beta_thread_buf = generate_tuple( - [&](auto) { - return StaticBuffer{}; - }, - Number{}); - - auto y_thread_buf = generate_tuple( - [&](auto) { - return StaticBuffer{}; - }, - Number{}); - - StaticBuffer mean_thread_buf; - StaticBuffer var_thread_buf; - - const auto thread_cluster_idx = - thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); - - const auto thread_m_cluster_id = thread_cluster_idx[I0]; - const auto thread_k_cluster_id = thread_cluster_idx[I1]; - - using ThreadBufferLengths_M_K = Sequence; - - constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - - auto in_global_load_tuple = generate_tuple( - [&](auto I) { - using DataTypePointer = remove_cvref_t; - using DataType = remove_cv_t>; - - return ThreadwiseTensorSliceTransfer_v2{ - in_grid_2d_desc_tuple[I], - make_multi_index(block_global_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - thread_k_cluster_id * XSrcVectorSize)}; - }, - Number{}); - - auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( - x_grid_desc_m_k, - make_multi_index(thread_m_cluster_id * MThreadSliceSize, - thread_k_cluster_id * XSrcVectorSize)); - - auto threadwise_gamma_load = - ThreadwiseTensorSliceTransfer_v2( - gamma_grid_desc_m_k, - make_multi_index(block_global_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - thread_k_cluster_id * GammaSrcVectorSize)); - - auto threadwise_beta_load = - ThreadwiseTensorSliceTransfer_v2( - beta_grid_desc_m_k, - make_multi_index(block_global_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - thread_k_cluster_id * BetaSrcVectorSize)); - - using PassThrough = tensor_operation::element_wise::PassThrough; - PassThrough pass_through_op; - auto threadwise_x_store = - ThreadwiseTensorSliceTransfer_v1r3( - x_grid_desc_m_k, - make_multi_index(thread_m_cluster_id * MThreadSliceSize, - thread_k_cluster_id * XSrcVectorSize), - pass_through_op); - - auto threadwise_y_store = - ThreadwiseTensorSliceTransfer_v1r3( - y_grid_desc_m_k, - make_multi_index(block_global_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - thread_k_cluster_id * YDstVectorSize), - y_elementwise_op); - - // Copy x from Cache - // one pass: fwd, second pass: bwd - constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); - constexpr auto thread_copy_bwd_step_m_k = - make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); - - const auto gamma_global_val_buf = make_dynamic_buffer( - p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); - - const auto beta_global_val_buf = make_dynamic_buffer( - p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); - - auto threadwise_welford = ThreadwiseWelford(); - threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id); - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - mean_thread_buf(I) = type_convert(0.0f); - var_thread_buf(I) = type_convert(0.0f); - }); - - for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) - { - static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, NumInput, 1>{}([&](auto I) { // input load loop - in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I], - in_global_buf_tuple[I], - thread_buffer_desc_m_k, - make_tuple(I0, I0), - in_thread_buf_tuple(iK0)(I)); - - in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I], - thread_copy_fwd_step_m_k); - }); - - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { // input add loop - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - - // get reference to in data - const auto in_data_refs = generate_tie( - // return type should be lvalue - [&](auto I) -> const auto& { - return in_thread_buf_tuple(iK0)(I)(Number{}); - }, - Number{}); - - // get reference to dst data - auto out_data_refs = generate_tie( - // return type should be lvalue - [&](auto) -> auto& { return x_thread_buf(iK0)(Number{}); }, - I1); - - unpack2(x_elementwise_op, out_data_refs, in_data_refs); - }); - }); - threadwise_welford.Run(x_thread_buf[iK0], mean_thread_buf, var_thread_buf); - - if constexpr(!SweepOnce) - { - threadwise_x_store.Run(thread_buffer_desc_m_k, - make_tuple(I0, I0), - x_thread_buf(iK0), - x_grid_desc_m_k, - x_lds_val_buf); - threadwise_x_store.MoveDstSliceWindow(x_grid_desc_m_k, - thread_copy_fwd_step_m_k); - } - }); - } - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - if constexpr(I > 0) - block_sync_lds(); - - int count = threadwise_welford.cur_count_; - BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); - }); - - auto thread_copy_tail_m_k = - (num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k; - - if constexpr(!SweepOnce) - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail_m_k); - threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); - threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); - threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); - - for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) - { - if constexpr(!SweepOnce) - { - static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { - threadwise_x_load.Run(x_grid_desc_m_k, - x_lds_val_buf, - thread_buffer_desc_m_k, - make_tuple(I0, I0), - x_thread_buf(i)); - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); - }); - } - - static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) { - threadwise_gamma_load.Run(gamma_grid_desc_m_k, - gamma_global_val_buf, - thread_buffer_desc_m_k, - make_tuple(I0, I0), - gamma_thread_buf(i)); - threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, - thread_copy_fwd_step_m_k); - }); - - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); - static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - - // normalize - y_thread_buf(iK0)(Number{}) = - (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; - - // gamma - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) * - gamma_thread_buf(iK0)(Number{}); - }); - }); - }); - - static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) { - threadwise_beta_load.Run(beta_grid_desc_m_k, - beta_global_val_buf, - thread_buffer_desc_m_k, - make_tuple(I0, I0), - beta_thread_buf(i)); - threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, - thread_copy_fwd_step_m_k); - }); - - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { - static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); - - // beta - y_thread_buf(iK0)(Number{}) = - y_thread_buf(iK0)(Number{}) + - beta_thread_buf(iK0)(Number{}); - }); - }); - }); - - static_for<0, YThreadBufferNumber, 1>{}([&](auto i) { - threadwise_y_store.Run(thread_buffer_desc_m_k, - make_tuple(I0, I0), - y_thread_buf(i), - y_grid_desc_m_k, - y_global_val_buf); - threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k); - }); - - if constexpr(!SweepOnce) - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); - threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, - 2 * thread_copy_bwd_step_m_k); - threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, - 2 * thread_copy_bwd_step_m_k); - threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); - } - } -}; - -} // namespace ck +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// X = Elementwise(input1, input2, input3, ...) +// Y = Normalization(X, beta, gamma) +template +struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr index_t NumInput = InDataTypePointerTuple::Size(); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; + + static constexpr auto XThreadBufferNumber = Number{}; + static constexpr auto GammaThreadBufferNumber = Number{}; + static constexpr auto BetaThreadBufferNumber = Number{}; + static constexpr auto YThreadBufferNumber = Number{}; + + __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k, + int thread_k_cluster_id) + { + int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; + int kPerThread = + kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize); + int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize; + + if(kPerBlockTail > 0) + { + static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { + int thread_max_len = + (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i; + int delta = thread_max_len - kPerBlockTail; + delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize); + kPerThread += XSrcVectorSize - delta; + }); + } + + return kPerThread; + } + + __device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple, + const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& gamma_grid_desc_m_k, + const GridDesc_M_K& beta_grid_desc_m_k, + const GridDesc_M_K& y_grid_desc_m_k, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const InDataTypePointerTuple p_in_global_tuple, + XDataType* const __restrict__ p_x_lds_, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + const XElementwiseOperation x_elementwise_op, + const YElementwiseOperation y_elementwise_op) + { + if constexpr(SweepOnce) + { + num_k_block_tile_iteration = 1; + } + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t grid_size = get_grid_size(); + + auto in_global_buf_tuple = generate_tuple( + [&](auto I) { + static_assert(in_grid_2d_desc_tuple[I].GetNumOfDimension() == + 2); // matrix dimension + + return make_dynamic_buffer( + p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + auto x_lds_val_buf = make_dynamic_buffer( + p_x_lds_, x_grid_desc_m_k.GetElementSpaceSize() / grid_size); + + auto in_thread_buf_tuple = generate_tuple( + [&](auto) { + return generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + }, + Number{}); + + auto x_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto gamma_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto beta_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto y_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto in_global_load_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return ThreadwiseTensorSliceTransfer_v2{ + in_grid_2d_desc_tuple[I], + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * XSrcVectorSize)}; + }, + Number{}); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * XSrcVectorSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * GammaSrcVectorSize)); + + auto threadwise_beta_load = + ThreadwiseTensorSliceTransfer_v2( + beta_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * BetaSrcVectorSize)); + + using PassThrough = tensor_operation::element_wise::PassThrough; + PassThrough pass_through_op; + auto threadwise_x_store = + ThreadwiseTensorSliceTransfer_v1r3( + x_grid_desc_m_k, + make_multi_index(thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * XSrcVectorSize), + pass_through_op); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * YDstVectorSize), + y_elementwise_op); + + // Copy x from Cache + // one pass: fwd, second pass: bwd + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); + constexpr auto thread_copy_bwd_step_m_k = + make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); + + const auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto beta_global_val_buf = make_dynamic_buffer( + p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, NumInput, 1>{}([&](auto I) { // input load loop + in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I], + in_global_buf_tuple[I], + thread_buffer_desc_m_k, + make_tuple(I0, I0), + in_thread_buf_tuple(iK0)(I)); + + in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I], + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { // input add loop + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // get reference to in data + const auto in_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> const auto& { + return in_thread_buf_tuple(iK0)(I)(Number{}); + }, + Number{}); + + // get reference to dst data + auto out_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { return x_thread_buf(iK0)(Number{}); }, + I1); + + unpack2(x_elementwise_op, out_data_refs, in_data_refs); + }); + }); + threadwise_welford.Run(x_thread_buf[iK0], mean_thread_buf, var_thread_buf); + + if constexpr(!SweepOnce) + { + threadwise_x_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(iK0), + x_grid_desc_m_k, + x_lds_val_buf); + threadwise_x_store.MoveDstSliceWindow(x_grid_desc_m_k, + thread_copy_fwd_step_m_k); + } + }); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + }); + + auto thread_copy_tail_m_k = + (num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k; + + if constexpr(!SweepOnce) + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + if constexpr(!SweepOnce) + { + static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_lds_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + } + + static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) { + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); + static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + divisor; + + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, YThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + + if constexpr(!SweepOnce) + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index c8b154228f..42d973388b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index 0a4691b509..64f50d13df 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp index 5f6f2768eb..c486b12423 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 5e779b2881..a08c9cfa8b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp index ff534b0777..4a4cb153ae 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index ad28a12e57..c3858b967a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 5411137bf4..64f04a64c4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 7d68d64ed8..1a29907c25 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 1d9b7eb978..3523da6c46 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 1e72e78349..a9adde1da5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp index 872e1271e1..6307f649df 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index f8de0a48e5..8d45b8fd74 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index 0cdb7ce2ca..1262029f21 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp index 25e1cebdbe..1641d71c50 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp index ced62241cd..d318d91303 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp index 08d986d0da..b6ff93ca53 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index d3479fe0b4..6ce04e858b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index 8f7aac0171..12d07ca23a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp index 9d34f9e2a4..5af9d97c9f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp index de5a424198..60c4bb5fda 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 65f74de3cf..647fa107e3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index 5d8bbca79d..6629be2511 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -470,9 +470,9 @@ struct GridwiseGemm_wmma_cshuffle_v3 DsGridPointer p_ds_grid; EDataType* p_e_grid; - const AElementwiseOperation a_element_op; - const BElementwiseOperation b_element_op; - const CDEElementwiseOperation cde_element_op; + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CDEElementwiseOperation cde_element_op; // TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd bool is_reduce; @@ -555,13 +555,17 @@ struct GridwiseGemm_wmma_cshuffle_v3 template + typename Block2CTileMap, + typename EpilogueArgument, + int BlockMapMBlockIndex = 0, + int BlockMapNBlockIndex = 1> __device__ static void Run(AsGridPointer& p_as_grid, BsGridPointer& p_bs_grid, DsGridPointer& p_ds_grid, EDataType* p_e_grid, void* p_shared, const Problem& problem, + const Block2CTileMap& block_2_ctile_map, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, @@ -582,9 +586,6 @@ struct GridwiseGemm_wmma_cshuffle_v3 MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n, problem.MBlock, problem.NBlock); - // divide block work by [M, N] - const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -596,8 +597,10 @@ struct GridwiseGemm_wmma_cshuffle_v3 return; } - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + const index_t block_m_id = + __builtin_amdgcn_readfirstlane(block_work_idx[Number{}]); + const index_t block_n_id = + __builtin_amdgcn_readfirstlane(block_work_idx[Number{}]); // BScale struct (Empty) using BScale = typename BlockwiseGemmPipe::Empty; @@ -632,15 +635,51 @@ struct GridwiseGemm_wmma_cshuffle_v3 epilogue_args); } + template + __device__ static void Run(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, + DsGridPointer& p_ds_grid, + EDataType* p_e_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + EpilogueArgument& epilogue_args) + { + Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + problem, + DefaultBlock2CTileMap(problem), + a_element_op, + b_element_op, + cde_element_op, + epilogue_args); + } + // Wrapper function to have __global__ function in common // between gemm_universal, b_scale, ab_scale, etc. template + typename Block2CTileMap, + typename EpilogueArgument, + int BlockMapMBlockIndex = 0, + int BlockMapNBlockIndex = 1> __device__ static void Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg, + const Block2CTileMap& block_2_ctile_map, EpilogueArgument& epilogue_args) { // shift A matrices pointer for splitk @@ -659,17 +698,47 @@ struct GridwiseGemm_wmma_cshuffle_v3 splitk_batch_offset.b_k_split_offset[i]; }); - Run( - p_as_grid_splitk, - p_bs_grid_splitk, - karg.p_ds_grid, - karg.p_e_grid + splitk_batch_offset.c_reduce_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - epilogue_args); + Run(p_as_grid_splitk, + p_bs_grid_splitk, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg, + block_2_ctile_map, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); + } + + // Wrapper function to have __global__ function in common + // between gemm_universal, b_scale, ab_scale, etc. + template + __device__ static void Run(void* p_shared, + const SplitKBatchOffset& splitk_batch_offset, + Argument& karg, + EpilogueArgument& epilogue_args) + { + Run( + p_shared, splitk_batch_offset, karg, DefaultBlock2CTileMap(karg), epilogue_args); + } + + __device__ static auto DefaultBlock2CTileMap(const Problem& problem) + { + return Block2CTileMap{problem.M, problem.N, 4}; } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp index ca4646a1c1..2b2cf1ca34 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 020d0110cf..4de3a35b3e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -729,6 +729,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value too low for combination of AK1/BK1/KBatch. AK1: " + << AK1Number << ", BK1: " << BK1Number << ", KBatch: " << karg.KBatch + << ", K: " << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 60ad4651b6..9339916d6f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index 3e91f120b1..3b5610865c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 4cd1a587e9..716fe6f41d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp index ccba4d4a94..34f0a97586 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 11b75a6541..93f4059f0a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index 6ce2f63e3a..258ab40b9d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp index 36141bc96f..777a622e2c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index 35c8c6c3b4..4e46d52496 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index f775f99a65..e7e24b148a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index 8119cace3b..bbaafb6b56 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index f2f1530599..5fe77f2b71 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp index bf7ae1c6e8..b02d9c7f77 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index 1a356c372d..a41f096abb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp index 7c5bd606b2..cf8c718273 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index e4152e0427..2e562f6538 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index 67f18de12f..2903b219b4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 422b9afa61..6fd6529fbb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp index cf3040d1ae..b19c5d8fb6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp index c5f60a7413..05de26b2ae 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp index a040409a6d..d4ea7a2149 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index d2418c0913..bc5e31013e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index a9a463e2c1..3c1e1ca27b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index d8f22b682d..5869caf19b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 7d5a8da60f..8c9b7492cc 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp index 7c559d1f85..ec3e7abf21 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index 83f8773a08..f7570adfc1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index b9c0d671db..97f6dc3084 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index 5f5e24fb9f..0a565bf17e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp index 9066decc0a..bfe13987ac 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp index 6854e64092..34c65b4626 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp index 1c471fb873..2075f7e8d5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp index fa9b5fb2ce..7d0147ff4c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp index c41eef8c45..701925dc8f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp index 41352fabeb..98a8266b88 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp index 0ad36b418a..1264b7bf04 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp index 5f56ac6fc4..96e13ac55c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp index 287b4e5421..5036f4ae7f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp index 7c3e372765..a97114d48d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp index 9761cc6a68..bc47f31a7f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp index e399499cc8..e9e801cec5 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp index 21248e3a0a..8230735745 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp index 1bee1b93f9..6c42dc33f4 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp index 8157b4fbc3..b980dd2e12 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp index 80e9a84f96..f76a6c600f 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp index 9e380f9638..9deb5a5f48 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp index 15b412fba4..d57871f331 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp index c6eecc067d..f4a2bc399d 100644 --- a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp index 44730d551c..06aca9c922 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp index e97aa433a6..2896375636 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP #define CK_THREADWISE_GEMM_DLOPS_V3_HPP diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp index 6774a35bcb..fddaf2b648 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 5da9722a4b..afd1e67bd0 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp index 168f028e2a..5035fe23d0 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp @@ -1,7 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once - namespace ck { // Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 4a6ed62c0e..610d03ca10 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp index 8af6a2148b..6eb4b21e21 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp index 8574fd055c..2077eeebd7 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp index 9383e3f829..56ae553f2f 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp index 6a6c1f2ac5..74a964ddd8 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index 4e9c188115..bce2d453dc 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp index 644877d393..2e255e2500 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp index 88ed217547..43d4148dab 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp index cf2c7a2aee..f036bc4312 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp index b5847e51b4..7d53c1ac0d 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp index db7dee2199..f7949da594 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp index 4b277e4383..87cecc7574 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp index 7e9870bf91..058cf58ec9 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 65e63993a6..fe975f4e36 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp index eb6715e8eb..3b510742a6 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp index 409bb9f674..dd81b074a5 100644 --- a/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp index a436afd395..0a7b8d91d6 100644 --- a/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 55ede990af..1b60e6cb53 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 0817cf9856..67712be483 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp index ea27a40ce3..ae21d5fa87 100644 --- a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp index 56181d38c8..9fb81012ea 100644 --- a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp index b989d63e0e..03dc0efeb5 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp index efc7f20cdc..266ffb5fae 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp @@ -1,6 +1,5 @@ - +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index e410f06190..96482b1412 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index d8d33a9d5e..ab2821d989 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -1,6 +1,5 @@ - +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp index 0f28fe8169..945fbbc8a9 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/amd_address_space.hpp b/include/ck/utility/amd_address_space.hpp index d54f70e750..00687da226 100644 --- a/include/ck/utility/amd_address_space.hpp +++ b/include/ck/utility/amd_address_space.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 783fc661ce..f9404e00b7 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" diff --git a/include/ck/utility/amd_buffer_addressing_builtins.hpp b/include/ck/utility/amd_buffer_addressing_builtins.hpp index f642e06050..cddb8b7e5c 100644 --- a/include/ck/utility/amd_buffer_addressing_builtins.hpp +++ b/include/ck/utility/amd_buffer_addressing_builtins.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index c5525d5ff8..05d688b59d 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/amd_gemm_dpp.hpp b/include/ck/utility/amd_gemm_dpp.hpp index a28292dade..428a9fe915 100644 --- a/include/ck/utility/amd_gemm_dpp.hpp +++ b/include/ck/utility/amd_gemm_dpp.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 79efd77edb..6c080b4654 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_AMD_INLINE_ASM_HPP #define CK_AMD_INLINE_ASM_HPP diff --git a/include/ck/utility/amd_lds.hpp b/include/ck/utility/amd_lds.hpp index c218fded96..fa78cfb95a 100644 --- a/include/ck/utility/amd_lds.hpp +++ b/include/ck/utility/amd_lds.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/amd_smfmac.hpp b/include/ck/utility/amd_smfmac.hpp index 8b6b094ff2..4a09163038 100644 --- a/include/ck/utility/amd_smfmac.hpp +++ b/include/ck/utility/amd_smfmac.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #pragma once diff --git a/include/ck/utility/amd_transpose_load.hpp b/include/ck/utility/amd_transpose_load.hpp index 6ef17b18da..17de26a401 100644 --- a/include/ck/utility/amd_transpose_load.hpp +++ b/include/ck/utility/amd_transpose_load.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp index 3604712837..44259f0601 100644 --- a/include/ck/utility/amd_wave_read_first_lane.hpp +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index 09a462d016..35389bda37 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_AMD_WMMA_HPP #define CK_AMD_WMMA_HPP diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index b9d171dbea..f8b1736801 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/dtype_fp64.hpp" diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index 2afad00d49..2b249884b6 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_ARRAY_HPP #define CK_ARRAY_HPP diff --git a/include/ck/utility/array_multi_index.hpp b/include/ck/utility/array_multi_index.hpp index c0c1ea65fc..32925a6946 100644 --- a/include/ck/utility/array_multi_index.hpp +++ b/include/ck/utility/array_multi_index.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_ARRAY_MULTI_INDEX_HPP #define CK_ARRAY_MULTI_INDEX_HPP diff --git a/include/ck/utility/blkgemmpipe_scheduler.hpp b/include/ck/utility/blkgemmpipe_scheduler.hpp index 63466a36f2..3b285123ba 100644 --- a/include/ck/utility/blkgemmpipe_scheduler.hpp +++ b/include/ck/utility/blkgemmpipe_scheduler.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/c_style_pointer_cast.hpp b/include/ck/utility/c_style_pointer_cast.hpp index 610e393a77..b7426123cb 100644 --- a/include/ck/utility/c_style_pointer_cast.hpp +++ b/include/ck/utility/c_style_pointer_cast.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_C_STYLE_POINTER_CAST_HPP #define CK_C_STYLE_POINTER_CAST_HPP diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index 69420a6465..78c3b78de1 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/container_element_picker.hpp b/include/ck/utility/container_element_picker.hpp index 838147e420..9de2466e71 100644 --- a/include/ck/utility/container_element_picker.hpp +++ b/include/ck/utility/container_element_picker.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_CONTAINER_ELEMENT_PICKER_HPP #define CK_CONTAINER_ELEMENT_PICKER_HPP diff --git a/include/ck/utility/container_helper.hpp b/include/ck/utility/container_helper.hpp index d6524283db..8f2fe45796 100644 --- a/include/ck/utility/container_helper.hpp +++ b/include/ck/utility/container_helper.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_CONTAINER_HELPER_HPP #define CK_CONTAINER_HELPER_HPP diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 574269b94a..8e6f875c39 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include diff --git a/include/ck/utility/debug.hpp b/include/ck/utility/debug.hpp index 1b86b33777..c96026b217 100644 --- a/include/ck/utility/debug.hpp +++ b/include/ck/utility/debug.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP diff --git a/include/ck/utility/dtype_fp64.hpp b/include/ck/utility/dtype_fp64.hpp index 3c63d083ad..e854cff260 100644 --- a/include/ck/utility/dtype_fp64.hpp +++ b/include/ck/utility/dtype_fp64.hpp @@ -1,5 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + namespace ck { // fp64 using double2_t = typename vector_type::type; diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 084240f84b..ebdbbb107d 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -1,5 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// // // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include "ck/utility/data_type.hpp" diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 66166e11e3..4e477eed26 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/e8m0.hpp b/include/ck/utility/e8m0.hpp index ac2a114593..391727371a 100644 --- a/include/ck/utility/e8m0.hpp +++ b/include/ck/utility/e8m0.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/enable_if.hpp b/include/ck/utility/enable_if.hpp index 9d5403ceb2..8d1a5ebe3b 100644 --- a/include/ck/utility/enable_if.hpp +++ b/include/ck/utility/enable_if.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp index 2f5b804d16..0cb0b4caf8 100644 --- a/include/ck/utility/env.hpp +++ b/include/ck/utility/env.hpp @@ -1,7 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once - #ifndef CK_CODE_GEN_RTC #include diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index 94c2f84c8c..1aa9e182b0 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/filter_tuple.hpp b/include/ck/utility/filter_tuple.hpp index c2e378b879..2d9c79dbe6 100644 --- a/include/ck/utility/filter_tuple.hpp +++ b/include/ck/utility/filter_tuple.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/flush_icache.hpp b/include/ck/utility/flush_icache.hpp index 7378ba5c26..ca62e7a175 100644 --- a/include/ck/utility/flush_icache.hpp +++ b/include/ck/utility/flush_icache.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/functional.hpp b/include/ck/utility/functional.hpp index cd48ed1747..b7f4243a96 100644 --- a/include/ck/utility/functional.hpp +++ b/include/ck/utility/functional.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/functional2.hpp b/include/ck/utility/functional2.hpp index ef8b5a435c..888fedc0a9 100644 --- a/include/ck/utility/functional2.hpp +++ b/include/ck/utility/functional2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/functional3.hpp b/include/ck/utility/functional3.hpp index 97605a7ade..ffec810760 100644 --- a/include/ck/utility/functional3.hpp +++ b/include/ck/utility/functional3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/functional4.hpp b/include/ck/utility/functional4.hpp index 8e86a296dc..16673743b3 100644 --- a/include/ck/utility/functional4.hpp +++ b/include/ck/utility/functional4.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_FUNCTIONAL4_HPP #define CK_FUNCTIONAL4_HPP diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp index 011491ffc6..210b354504 100644 --- a/include/ck/utility/generic_memory_space_atomic.hpp +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp index c96a6c3aef..6995d01238 100644 --- a/include/ck/utility/get_id.hpp +++ b/include/ck/utility/get_id.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/get_shift.hpp b/include/ck/utility/get_shift.hpp index 0a93081cfd..87361c1b48 100644 --- a/include/ck/utility/get_shift.hpp +++ b/include/ck/utility/get_shift.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/ignore.hpp b/include/ck/utility/ignore.hpp index f70a182fd4..b93b04af37 100644 --- a/include/ck/utility/ignore.hpp +++ b/include/ck/utility/ignore.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index 65efaf388a..22789c4b71 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" diff --git a/include/ck/utility/inner_product_dpp8.hpp b/include/ck/utility/inner_product_dpp8.hpp index f079e2ca64..7fc06ead37 100644 --- a/include/ck/utility/inner_product_dpp8.hpp +++ b/include/ck/utility/inner_product_dpp8.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/integral_constant.hpp b/include/ck/utility/integral_constant.hpp index a7fa64d710..83015f8880 100644 --- a/include/ck/utility/integral_constant.hpp +++ b/include/ck/utility/integral_constant.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/is_detected.hpp b/include/ck/utility/is_detected.hpp index 8cb37b68b2..41e43678e9 100644 --- a/include/ck/utility/is_detected.hpp +++ b/include/ck/utility/is_detected.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/is_known_at_compile_time.hpp b/include/ck/utility/is_known_at_compile_time.hpp index 0916e4604e..e5c48008c7 100644 --- a/include/ck/utility/is_known_at_compile_time.hpp +++ b/include/ck/utility/is_known_at_compile_time.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/loop_scheduler.hpp b/include/ck/utility/loop_scheduler.hpp index cbbce85007..f186d0fea9 100644 --- a/include/ck/utility/loop_scheduler.hpp +++ b/include/ck/utility/loop_scheduler.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index 7227cee754..8fc32246aa 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/math.hpp b/include/ck/utility/math.hpp index 7efbb3e63a..b2ebf4b371 100644 --- a/include/ck/utility/math.hpp +++ b/include/ck/utility/math.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index e235f51c93..f11c98974a 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/multi_index.hpp b/include/ck/utility/multi_index.hpp index 9f7ba8bff6..98bae78fb1 100644 --- a/include/ck/utility/multi_index.hpp +++ b/include/ck/utility/multi_index.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/mxf4_utils.hpp b/include/ck/utility/mxf4_utils.hpp index 53edb6e182..da1f3f5df7 100644 --- a/include/ck/utility/mxf4_utils.hpp +++ b/include/ck/utility/mxf4_utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_CODE_GEN_RTC #pragma once diff --git a/include/ck/utility/mxf6_utils.hpp b/include/ck/utility/mxf6_utils.hpp index a840c520a9..2ae42de63b 100644 --- a/include/ck/utility/mxf6_utils.hpp +++ b/include/ck/utility/mxf6_utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_CODE_GEN_RTC #pragma once diff --git a/include/ck/utility/mxf8_utils.hpp b/include/ck/utility/mxf8_utils.hpp index 565e1b27dc..81cf64fd60 100644 --- a/include/ck/utility/mxf8_utils.hpp +++ b/include/ck/utility/mxf8_utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/numeric_limits.hpp" #include "ck/utility/mxfp_utils.hpp" diff --git a/include/ck/utility/mxfp_utils.hpp b/include/ck/utility/mxfp_utils.hpp index ebed85f5fd..533712d6c5 100644 --- a/include/ck/utility/mxfp_utils.hpp +++ b/include/ck/utility/mxfp_utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/number.hpp b/include/ck/utility/number.hpp index d29afd31a7..a8e2dcbddb 100644 --- a/include/ck/utility/number.hpp +++ b/include/ck/utility/number.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_NUMBER_HPP #define CK_NUMBER_HPP diff --git a/include/ck/utility/numeric_limits.hpp b/include/ck/utility/numeric_limits.hpp index b8d6280acc..42a0ff33aa 100644 --- a/include/ck/utility/numeric_limits.hpp +++ b/include/ck/utility/numeric_limits.hpp @@ -1,5 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include "ck/utility/data_type.hpp" diff --git a/include/ck/utility/numeric_utils.hpp b/include/ck/utility/numeric_utils.hpp index ab84bd765f..2bbf27b1c9 100644 --- a/include/ck/utility/numeric_utils.hpp +++ b/include/ck/utility/numeric_utils.hpp @@ -1,5 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include "ck/utility/data_type.hpp" diff --git a/include/ck/utility/random_gen.hpp b/include/ck/utility/random_gen.hpp index dd2662b6d9..fb9aee6e14 100644 --- a/include/ck/utility/random_gen.hpp +++ b/include/ck/utility/random_gen.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include diff --git a/include/ck/utility/reduction_common.hpp b/include/ck/utility/reduction_common.hpp index 75fdd85825..849888a678 100644 --- a/include/ck/utility/reduction_common.hpp +++ b/include/ck/utility/reduction_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/reduction_enums.hpp b/include/ck/utility/reduction_enums.hpp index 23b7149f8e..623dac32ea 100644 --- a/include/ck/utility/reduction_enums.hpp +++ b/include/ck/utility/reduction_enums.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/reduction_functions_accumulate.hpp b/include/ck/utility/reduction_functions_accumulate.hpp index b9765ff0d2..f0912c65ee 100644 --- a/include/ck/utility/reduction_functions_accumulate.hpp +++ b/include/ck/utility/reduction_functions_accumulate.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/reduction_operator.hpp b/include/ck/utility/reduction_operator.hpp index c472c3f409..d281b53dc7 100644 --- a/include/ck/utility/reduction_operator.hpp +++ b/include/ck/utility/reduction_operator.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/scaled_type_convert.hpp b/include/ck/utility/scaled_type_convert.hpp index 7de84d974c..6ed04b8c17 100644 --- a/include/ck/utility/scaled_type_convert.hpp +++ b/include/ck/utility/scaled_type_convert.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 75f0c92c58..9f97d44a4a 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/sequence_helper.hpp b/include/ck/utility/sequence_helper.hpp index 8c493a2822..35a6a48632 100644 --- a/include/ck/utility/sequence_helper.hpp +++ b/include/ck/utility/sequence_helper.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/span.hpp b/include/ck/utility/span.hpp index 5e7567a847..c0e68c95f4 100644 --- a/include/ck/utility/span.hpp +++ b/include/ck/utility/span.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index 602e76abdb..d49817eb8f 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/statically_indexed_array.hpp b/include/ck/utility/statically_indexed_array.hpp index a2d70045a4..d0735a32f6 100644 --- a/include/ck/utility/statically_indexed_array.hpp +++ b/include/ck/utility/statically_indexed_array.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_STATICALLY_INDEXED_ARRAY_HPP #define CK_STATICALLY_INDEXED_ARRAY_HPP diff --git a/include/ck/utility/statically_indexed_array_multi_index.hpp b/include/ck/utility/statically_indexed_array_multi_index.hpp index fd11e5a150..381c70db61 100644 --- a/include/ck/utility/statically_indexed_array_multi_index.hpp +++ b/include/ck/utility/statically_indexed_array_multi_index.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP #define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp index 54391e7e86..3bd07bb59a 100644 --- a/include/ck/utility/synchronization.hpp +++ b/include/ck/utility/synchronization.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/thread_group.hpp b/include/ck/utility/thread_group.hpp index 1cd6b2f3ce..3ea834066a 100644 --- a/include/ck/utility/thread_group.hpp +++ b/include/ck/utility/thread_group.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/transpose_vectors.hpp b/include/ck/utility/transpose_vectors.hpp index e73ec03de4..de20674ef2 100644 --- a/include/ck/utility/transpose_vectors.hpp +++ b/include/ck/utility/transpose_vectors.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/tuple.hpp b/include/ck/utility/tuple.hpp index 4bd2f08944..78931407d8 100644 --- a/include/ck/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index ec055fb2a2..294d7e7c7d 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/type.hpp b/include/ck/utility/type.hpp index bde9c179ce..74e07bd580 100644 --- a/include/ck/utility/type.hpp +++ b/include/ck/utility/type.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 701b2686c7..b3e399609e 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/workgroup_barrier.hpp b/include/ck/utility/workgroup_barrier.hpp index ec9151fd1b..0e440799be 100644 --- a/include/ck/utility/workgroup_barrier.hpp +++ b/include/ck/utility/workgroup_barrier.hpp @@ -1,3 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + #pragma once #include #include diff --git a/include/ck/utility/workgroup_synchronization.hpp b/include/ck/utility/workgroup_synchronization.hpp index af5b0808fb..37b9b96863 100644 --- a/include/ck/utility/workgroup_synchronization.hpp +++ b/include/ck/utility/workgroup_synchronization.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/host_utility/hip_check_error.hpp" diff --git a/include/ck/wrapper/layout.hpp b/include/ck/wrapper/layout.hpp index 5cd1f614e6..334d5851db 100644 --- a/include/ck/wrapper/layout.hpp +++ b/include/ck/wrapper/layout.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/wrapper/operations/copy.hpp b/include/ck/wrapper/operations/copy.hpp index e8a919fdda..57d6832efb 100644 --- a/include/ck/wrapper/operations/copy.hpp +++ b/include/ck/wrapper/operations/copy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/wrapper/operations/gemm.hpp b/include/ck/wrapper/operations/gemm.hpp index 42a70239ad..d328ac7d42 100644 --- a/include/ck/wrapper/operations/gemm.hpp +++ b/include/ck/wrapper/operations/gemm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index 26cfcaa2f0..9f8278a357 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp b/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp index e5d848c404..de9e864a74 100644 --- a/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp +++ b/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/wrapper/utils/kernel_utils.hpp b/include/ck/wrapper/utils/kernel_utils.hpp index e5a31f6aa4..9f0cbd4d42 100644 --- a/include/ck/wrapper/utils/kernel_utils.hpp +++ b/include/ck/wrapper/utils/kernel_utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/wrapper/utils/layout_utils.hpp b/include/ck/wrapper/utils/layout_utils.hpp index 296ae6a2e8..8dd111b872 100644 --- a/include/ck/wrapper/utils/layout_utils.hpp +++ b/include/ck/wrapper/utils/layout_utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/wrapper/utils/tensor_partition.hpp b/include/ck/wrapper/utils/tensor_partition.hpp index 69fd502d63..5099f35cda 100644 --- a/include/ck/wrapper/utils/tensor_partition.hpp +++ b/include/ck/wrapper/utils/tensor_partition.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/wrapper/utils/tensor_utils.hpp b/include/ck/wrapper/utils/tensor_utils.hpp index ccab99fac3..11694753c5 100644 --- a/include/ck/wrapper/utils/tensor_utils.hpp +++ b/include/ck/wrapper/utils/tensor_utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/algorithm/cluster_descriptor.hpp b/include/ck_tile/core/algorithm/cluster_descriptor.hpp index c59a7c1fa1..0d3fee04ec 100644 --- a/include/ck_tile/core/algorithm/cluster_descriptor.hpp +++ b/include/ck_tile/core/algorithm/cluster_descriptor.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index 7511413bba..81eea60c2f 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/algorithm/indexing_adaptor.hpp b/include/ck_tile/core/algorithm/indexing_adaptor.hpp index ef59abdc99..c5d9434630 100644 --- a/include/ck_tile/core/algorithm/indexing_adaptor.hpp +++ b/include/ck_tile/core/algorithm/indexing_adaptor.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/algorithm/space_filling_curve.hpp b/include/ck_tile/core/algorithm/space_filling_curve.hpp index 648a1251be..1838b143db 100644 --- a/include/ck_tile/core/algorithm/space_filling_curve.hpp +++ b/include/ck_tile/core/algorithm/space_filling_curve.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp index c6a1ee0155..94f8075c7f 100644 --- a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp +++ b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. /** * @file diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 90137331f6..ba9201135c 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 3e4f6f35be..6d7de749c9 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp b/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp index 665be1b167..db1d097b24 100644 --- a/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp +++ b/include/ck_tile/core/arch/amd_transpose_load_encoding.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 70338e1185..a162195390 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp index e56bcadcba..2fcd76c5e7 100644 --- a/include/ck_tile/core/arch/generic_memory_space_atomic.hpp +++ b/include/ck_tile/core/arch/generic_memory_space_atomic.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core/numeric/vector_type.hpp" @@ -102,6 +102,9 @@ CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x); template <> CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) { +#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN + __builtin_amdgcn_global_atomic_fadd_v2bf16(c_style_pointer_cast(p_dst), x); +#else union U32BF162_ADDR { uint32_t* u32_a; @@ -128,6 +131,7 @@ CK_TILE_DEVICE void atomic_add(bf16x2_t* p_dst, const bf16x2_t& x) new_v = new_.u32; cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v); } while(cur_v.u32 != old_v); +#endif } template <> diff --git a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index 88cf189667..4c9ef7d6ba 100644 --- a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/core/arch/mma/mfma/mfma.hpp b/include/ck_tile/core/arch/mma/mfma/mfma.hpp index 34c3b11d2f..55817b5ba8 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp index 94e429d385..c7375c5e12 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp index b45da8a509..5c87419d0c 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp index b023118ab0..170e06f08c 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp index 589e6e049c..b20d243618 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/core/arch/mma/mma.hpp b/include/ck_tile/core/arch/mma/mma.hpp index 032261eb52..2a5de37550 100644 --- a/include/ck_tile/core/arch/mma/mma.hpp +++ b/include/ck_tile/core/arch/mma/mma.hpp @@ -1,7 +1,6 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once - #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/numeric/vector_type.hpp" diff --git a/include/ck_tile/core/arch/mma/mma_selector.hpp b/include/ck_tile/core/arch/mma/mma_selector.hpp index b2845e9bb2..070189373d 100644 --- a/include/ck_tile/core/arch/mma/mma_selector.hpp +++ b/include/ck_tile/core/arch/mma/mma_selector.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once namespace ck_tile::core::arch::mma { diff --git a/include/ck_tile/core/arch/mma/mma_traits.hpp b/include/ck_tile/core/arch/mma/mma_traits.hpp index 29b7e106cb..8a9092f9cb 100644 --- a/include/ck_tile/core/arch/mma/mma_traits.hpp +++ b/include/ck_tile/core/arch/mma/mma_traits.hpp @@ -1,7 +1,6 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once - #include "amdgcn_mma.hpp" #include "mfma/mfma_traits.hpp" #include "wmma/wmma_traits.hpp" diff --git a/include/ck_tile/core/arch/mma/mma_transforms.hpp b/include/ck_tile/core/arch/mma/mma_transforms.hpp index bbb0050084..4131daaced 100644 --- a/include/ck_tile/core/arch/mma/mma_transforms.hpp +++ b/include/ck_tile/core/arch/mma/mma_transforms.hpp @@ -1,7 +1,6 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once - namespace ck_tile::core::arch::mma { /** diff --git a/include/ck_tile/core/arch/mma/wmma/wmma.hpp b/include/ck_tile/core/arch/mma/wmma/wmma.hpp index 8f79478b38..ae5269dcb8 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp index 355fe6c957..58fe51abb7 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp index c41224b995..71e518d4a3 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp index 401d672126..e758ad9a5f 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp index 9e2e42a9d7..69913e67f3 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp index 2877e8f1f8..eb87c38e87 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/core/arch/utility.hpp b/include/ck_tile/core/arch/utility.hpp index 93008f8525..647f5b4435 100644 --- a/include/ck_tile/core/arch/utility.hpp +++ b/include/ck_tile/core/arch/utility.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/arch/workgroup_barrier.hpp b/include/ck_tile/core/arch/workgroup_barrier.hpp index 827a490fcb..8b18dac06b 100644 --- a/include/ck_tile/core/arch/workgroup_barrier.hpp +++ b/include/ck_tile/core/arch/workgroup_barrier.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 91e6134ac8..de97b46336 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -21,7 +21,7 @@ #endif #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \ - defined(__gfx1152__) || defined(__gfx11_generic__) + defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) #define __gfx11__ #endif #if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) diff --git a/include/ck_tile/core/container/array.hpp b/include/ck_tile/core/container/array.hpp index 352c645325..8b273b691b 100644 --- a/include/ck_tile/core/container/array.hpp +++ b/include/ck_tile/core/container/array.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/container/container_helper.hpp b/include/ck_tile/core/container/container_helper.hpp index 1a631bd95e..90579c0034 100644 --- a/include/ck_tile/core/container/container_helper.hpp +++ b/include/ck_tile/core/container/container_helper.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/container/map.hpp b/include/ck_tile/core/container/map.hpp index 7697995c92..d342235b38 100644 --- a/include/ck_tile/core/container/map.hpp +++ b/include/ck_tile/core/container/map.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/container/meta_data_buffer.hpp b/include/ck_tile/core/container/meta_data_buffer.hpp index eba60fac75..e8663d466e 100644 --- a/include/ck_tile/core/container/meta_data_buffer.hpp +++ b/include/ck_tile/core/container/meta_data_buffer.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/container/multi_index.hpp b/include/ck_tile/core/container/multi_index.hpp index 921c590df8..667589a828 100644 --- a/include/ck_tile/core/container/multi_index.hpp +++ b/include/ck_tile/core/container/multi_index.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index 1a88a98cbf..44b120cd5e 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/container/span.hpp b/include/ck_tile/core/container/span.hpp index eeb1f226a9..4cce87eb6f 100644 --- a/include/ck_tile/core/container/span.hpp +++ b/include/ck_tile/core/container/span.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/container/statically_indexed_array.hpp b/include/ck_tile/core/container/statically_indexed_array.hpp index d6da50b627..d35934ab04 100644 --- a/include/ck_tile/core/container/statically_indexed_array.hpp +++ b/include/ck_tile/core/container/statically_indexed_array.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/container/thread_buffer.hpp b/include/ck_tile/core/container/thread_buffer.hpp index d67581e7d2..8785a301fb 100644 --- a/include/ck_tile/core/container/thread_buffer.hpp +++ b/include/ck_tile/core/container/thread_buffer.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index 4c48b3d477..7f8176d5ec 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 5caee28e2e..e193c58915 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/core/config.hpp" #include "ck_tile/core/utility/bit_cast.hpp" diff --git a/include/ck_tile/core/numeric/e8m0.hpp b/include/ck_tile/core/numeric/e8m0.hpp index ba122b7f66..41aeb8ffab 100644 --- a/include/ck_tile/core/numeric/e8m0.hpp +++ b/include/ck_tile/core/numeric/e8m0.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index 890e507894..ba0a1c48a6 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/core/config.hpp" #include "ck_tile/core/utility/bit_cast.hpp" diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index 128befe90f..b6a7e86d3c 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/core/config.hpp" #include "ck_tile/core/utility/bit_cast.hpp" diff --git a/include/ck_tile/core/numeric/int8.hpp b/include/ck_tile/core/numeric/int8.hpp index 34d9a1c4b9..aa9f820c17 100644 --- a/include/ck_tile/core/numeric/int8.hpp +++ b/include/ck_tile/core/numeric/int8.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/half.hpp" diff --git a/include/ck_tile/core/numeric/integer.hpp b/include/ck_tile/core/numeric/integer.hpp index 502026c231..da7201995c 100644 --- a/include/ck_tile/core/numeric/integer.hpp +++ b/include/ck_tile/core/numeric/integer.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include diff --git a/include/ck_tile/core/numeric/integral_constant.hpp b/include/ck_tile/core/numeric/integral_constant.hpp index c22fad07f4..f5cd7d2d84 100644 --- a/include/ck_tile/core/numeric/integral_constant.hpp +++ b/include/ck_tile/core/numeric/integral_constant.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index b8a31ba8fc..57f3953514 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/numeric/mxfp_convert.hpp b/include/ck_tile/core/numeric/mxfp_convert.hpp index 9b378933d0..b3632ef4d6 100644 --- a/include/ck_tile/core/numeric/mxfp_convert.hpp +++ b/include/ck_tile/core/numeric/mxfp_convert.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/numeric/null_type.hpp b/include/ck_tile/core/numeric/null_type.hpp index 8799c0560e..20f0964588 100644 --- a/include/ck_tile/core/numeric/null_type.hpp +++ b/include/ck_tile/core/numeric/null_type.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include diff --git a/include/ck_tile/core/numeric/numeric.hpp b/include/ck_tile/core/numeric/numeric.hpp index 6b61e3f99c..b2bd628685 100644 --- a/include/ck_tile/core/numeric/numeric.hpp +++ b/include/ck_tile/core/numeric/numeric.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index 4f662095db..cc23ce71a8 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/numeric/pk_int4.hpp b/include/ck_tile/core/numeric/pk_int4.hpp index fc1caf13ff..13a43f8b5c 100644 --- a/include/ck_tile/core/numeric/pk_int4.hpp +++ b/include/ck_tile/core/numeric/pk_int4.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/half.hpp" diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index 3fee3ef96c..deaa9e0bd9 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 23786c41a1..6921210b34 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 3729a0de5c..f3aeed6e61 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 6b6cad299a..af0f81e832 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/load_tile_transpose.hpp b/include/ck_tile/core/tensor/load_tile_transpose.hpp index fb645f89e9..0ac2ded5f6 100644 --- a/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/null_tensor.hpp b/include/ck_tile/core/tensor/null_tensor.hpp index 565ff87dff..f5cabbef5a 100644 --- a/include/ck_tile/core/tensor/null_tensor.hpp +++ b/include/ck_tile/core/tensor/null_tensor.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/null_tile_window.hpp b/include/ck_tile/core/tensor/null_tile_window.hpp index f7eca73afb..73997d8ad7 100644 --- a/include/ck_tile/core/tensor/null_tile_window.hpp +++ b/include/ck_tile/core/tensor/null_tile_window.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/shuffle_tile.hpp b/include/ck_tile/core/tensor/shuffle_tile.hpp index 84c2b7d2fa..a48d4ca5ad 100644 --- a/include/ck_tile/core/tensor/shuffle_tile.hpp +++ b/include/ck_tile/core/tensor/shuffle_tile.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/slice_tile.hpp b/include/ck_tile/core/tensor/slice_tile.hpp index 3b696d8cc8..d68da1a98d 100644 --- a/include/ck_tile/core/tensor/slice_tile.hpp +++ b/include/ck_tile/core/tensor/slice_tile.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 5228ad978a..ac8b5eccab 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp index b535b40534..78fdb9c071 100644 --- a/include/ck_tile/core/tensor/store_tile.hpp +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/sweep_tile.hpp b/include/ck_tile/core/tensor/sweep_tile.hpp index 6ee1fa54f4..ffd48686b6 100644 --- a/include/ck_tile/core/tensor/sweep_tile.hpp +++ b/include/ck_tile/core/tensor/sweep_tile.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp index eb226debfd..0edf246927 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp b/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp index d7b9a466ef..2ea76a3814 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/tensor_coordinate.hpp b/include/ck_tile/core/tensor/tensor_coordinate.hpp index a51da9f844..866fb30d20 100644 --- a/include/ck_tile/core/tensor/tensor_coordinate.hpp +++ b/include/ck_tile/core/tensor/tensor_coordinate.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/tensor_descriptor.hpp b/include/ck_tile/core/tensor/tensor_descriptor.hpp index 8ee87ff9d0..57b7f75775 100644 --- a/include/ck_tile/core/tensor/tensor_descriptor.hpp +++ b/include/ck_tile/core/tensor/tensor_descriptor.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 3cdc4ff1cf..837f2b87a6 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index 52a5281cbe..426bb84e9c 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp index 90d1a2ccb2..77c4b965aa 100644 --- a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp +++ b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 4ab4c78884..bc6d7d2f5a 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -360,10 +360,12 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) (SrcTensor::get_thread_buffer_size() % 2 == 0)) return impl::cast_tile_pkrtz_fp16_fp32(src_tensor); #endif +#if 0 // currently it causes extra spills in qr_async_vr pipeline of fmha_fwd else if constexpr((std::is_same_v || std::is_same_v) && std::is_same_v && (SrcTensor::get_thread_buffer_size() % 2 == 0)) return impl::cast_tile_pk_fp16bf16_fp32(src_tensor); +#endif #if CK_TILE_USE_SUBDWORD_TILE_CAST else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4) return impl::cast_tile_opt_subdword(src_tensor); diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index e77ca805bb..97a44f38e8 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -1,6 +1,5 @@ - +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 89a0cc0f53..e80267faec 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -293,6 +293,15 @@ struct tile_window_with_static_distribution 0, dst_tensor, number{}, bool_constant{}); } + template + CK_TILE_DEVICE constexpr auto get_load_offset(offset_t = {}) const + { + constexpr auto bottom_tensor_idx_off = to_multi_index(offset_t{}); + const auto bottom_tensor_coord_off = make_tensor_coordinate( + this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off); + return amd_wave_read_first_lane(bottom_tensor_coord_off.get_offset()); + } + template ) return offset_t::value; else - { - auto bottom_tensor_idx_off = to_multi_index(offset_t{}); - auto bottom_tensor_coord_off = make_tensor_coordinate( - this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off); - return bottom_tensor_coord_off.get_offset(); - } + return get_load_offset(offset_t{}); }(); // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { diff --git a/include/ck_tile/core/tensor/tile_window_base.hpp b/include/ck_tile/core/tensor/tile_window_base.hpp index 89a928a53c..2f5eaf40b1 100644 --- a/include/ck_tile/core/tensor/tile_window_base.hpp +++ b/include/ck_tile/core/tensor/tile_window_base.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index b5a89e5f51..815c1bf158 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core/arch/arch.hpp" diff --git a/include/ck_tile/core/tensor/tile_window_utils.hpp b/include/ck_tile/core/tensor/tile_window_utils.hpp index f8b232a7af..7a05d30574 100644 --- a/include/ck_tile/core/tensor/tile_window_utils.hpp +++ b/include/ck_tile/core/tensor/tile_window_utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/utility.hpp" diff --git a/include/ck_tile/core/tensor/transpose_tile.hpp b/include/ck_tile/core/tensor/transpose_tile.hpp index d917cd5bac..e5a0664ec9 100644 --- a/include/ck_tile/core/tensor/transpose_tile.hpp +++ b/include/ck_tile/core/tensor/transpose_tile.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/tensor/update_tile.hpp b/include/ck_tile/core/tensor/update_tile.hpp index 570abde189..90e2939d4d 100644 --- a/include/ck_tile/core/tensor/update_tile.hpp +++ b/include/ck_tile/core/tensor/update_tile.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/bit_cast.hpp b/include/ck_tile/core/utility/bit_cast.hpp index 2cb91b7d47..a943665554 100644 --- a/include/ck_tile/core/utility/bit_cast.hpp +++ b/include/ck_tile/core/utility/bit_cast.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/debug.hpp b/include/ck_tile/core/utility/debug.hpp index 581b095383..74675d44cf 100644 --- a/include/ck_tile/core/utility/debug.hpp +++ b/include/ck_tile/core/utility/debug.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include diff --git a/include/ck_tile/core/utility/env.hpp b/include/ck_tile/core/utility/env.hpp index 9b148b3e0b..45de2880a0 100644 --- a/include/ck_tile/core/utility/env.hpp +++ b/include/ck_tile/core/utility/env.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp index fd0252d3ca..90740dcbe3 100644 --- a/include/ck_tile/core/utility/functional.hpp +++ b/include/ck_tile/core/utility/functional.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/functional_with_tuple.hpp b/include/ck_tile/core/utility/functional_with_tuple.hpp index 4b40403190..0324ef3eb8 100644 --- a/include/ck_tile/core/utility/functional_with_tuple.hpp +++ b/include/ck_tile/core/utility/functional_with_tuple.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/gemm_validation.hpp b/include/ck_tile/core/utility/gemm_validation.hpp index bacf7e1e2b..735b617674 100644 --- a/include/ck_tile/core/utility/gemm_validation.hpp +++ b/include/ck_tile/core/utility/gemm_validation.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/ignore.hpp b/include/ck_tile/core/utility/ignore.hpp index b15a19aa2e..1ba26f450c 100644 --- a/include/ck_tile/core/utility/ignore.hpp +++ b/include/ck_tile/core/utility/ignore.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/core/utility/literals.hpp b/include/ck_tile/core/utility/literals.hpp index 6f64f09f40..317d47fc6f 100644 --- a/include/ck_tile/core/utility/literals.hpp +++ b/include/ck_tile/core/utility/literals.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/magic_div.hpp b/include/ck_tile/core/utility/magic_div.hpp index 1715983c09..342432e80c 100644 --- a/include/ck_tile/core/utility/magic_div.hpp +++ b/include/ck_tile/core/utility/magic_div.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/philox_rand.hpp b/include/ck_tile/core/utility/philox_rand.hpp index 52b1489543..04b333a696 100644 --- a/include/ck_tile/core/utility/philox_rand.hpp +++ b/include/ck_tile/core/utility/philox_rand.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/print.hpp b/include/ck_tile/core/utility/print.hpp index b7279a1ef2..adadca379e 100644 --- a/include/ck_tile/core/utility/print.hpp +++ b/include/ck_tile/core/utility/print.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/random.hpp b/include/ck_tile/core/utility/random.hpp index 6a38ad3bde..2e5771b519 100644 --- a/include/ck_tile/core/utility/random.hpp +++ b/include/ck_tile/core/utility/random.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/reduce_operator.hpp b/include/ck_tile/core/utility/reduce_operator.hpp index 69449711e0..2820c53101 100644 --- a/include/ck_tile/core/utility/reduce_operator.hpp +++ b/include/ck_tile/core/utility/reduce_operator.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/reduce_operator_accumulate.hpp b/include/ck_tile/core/utility/reduce_operator_accumulate.hpp index b49ff41ee0..a989138428 100644 --- a/include/ck_tile/core/utility/reduce_operator_accumulate.hpp +++ b/include/ck_tile/core/utility/reduce_operator_accumulate.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/static_counter.hpp b/include/ck_tile/core/utility/static_counter.hpp index 4828e2e010..e165270624 100644 --- a/include/ck_tile/core/utility/static_counter.hpp +++ b/include/ck_tile/core/utility/static_counter.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/to_sequence.hpp b/include/ck_tile/core/utility/to_sequence.hpp index 2276ab68b7..6490eb11df 100644 --- a/include/ck_tile/core/utility/to_sequence.hpp +++ b/include/ck_tile/core/utility/to_sequence.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core/container/sequence.hpp" diff --git a/include/ck_tile/core/utility/transpose_vectors.hpp b/include/ck_tile/core/utility/transpose_vectors.hpp index f24b976b4c..5b96c38ef1 100644 --- a/include/ck_tile/core/utility/transpose_vectors.hpp +++ b/include/ck_tile/core/utility/transpose_vectors.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index 5ed49b7249..f07e25e19c 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/core/utility/unary_element_function.hpp b/include/ck_tile/core/utility/unary_element_function.hpp index 6bd6e33bd3..b195275bdc 100644 --- a/include/ck_tile/core/utility/unary_element_function.hpp +++ b/include/ck_tile/core/utility/unary_element_function.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/arg_parser.hpp b/include/ck_tile/host/arg_parser.hpp index df309f312a..8c45d2b175 100644 --- a/include/ck_tile/host/arg_parser.hpp +++ b/include/ck_tile/host/arg_parser.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 1ef6b040eb..ac388992d1 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/concat.hpp b/include/ck_tile/host/concat.hpp index e9ba9a7d7b..7c19e274d0 100644 --- a/include/ck_tile/host/concat.hpp +++ b/include/ck_tile/host/concat.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp b/include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp index 33a85b0d4b..ffd5bcc3d9 100644 --- a/include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp +++ b/include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/convolution_parameter.hpp b/include/ck_tile/host/convolution_parameter.hpp index 81ea51a94f..e54ed70745 100644 --- a/include/ck_tile/host/convolution_parameter.hpp +++ b/include/ck_tile/host/convolution_parameter.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/device_memory.hpp b/include/ck_tile/host/device_memory.hpp index 587f38987e..a4d883bc10 100644 --- a/include/ck_tile/host/device_memory.hpp +++ b/include/ck_tile/host/device_memory.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/device_prop.hpp b/include/ck_tile/host/device_prop.hpp index f86e4b889a..2d7dc7dd18 100644 --- a/include/ck_tile/host/device_prop.hpp +++ b/include/ck_tile/host/device_prop.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -57,7 +57,7 @@ inline bool is_gfx11_supported() return get_device_name() == "gfx1100" || get_device_name() == "gfx1101" || get_device_name() == "gfx1102" || get_device_name() == "gfx1103" || get_device_name() == "gfx1150" || get_device_name() == "gfx1151" || - get_device_name() == "gfx1152"; + get_device_name() == "gfx1152" || get_device_name() == "gfx1153"; } inline bool is_gfx12_supported() diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 817a46a8ea..12f43ebc5e 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/flush_icache.hpp b/include/ck_tile/host/flush_icache.hpp index f4852252be..4eb1405fe2 100644 --- a/include/ck_tile/host/flush_icache.hpp +++ b/include/ck_tile/host/flush_icache.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/hip_check_error.hpp b/include/ck_tile/host/hip_check_error.hpp index 3acdb4d874..fe731799ff 100644 --- a/include/ck_tile/host/hip_check_error.hpp +++ b/include/ck_tile/host/hip_check_error.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index 59510c8b93..d26686ec37 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -408,7 +408,6 @@ struct HostTensor return sizeof(T) * get_element_space_size(); } - // void SetZero() { ck_tile::ranges::fill(mData, 0); } void SetZero() { if constexpr(std::is_same_v) diff --git a/include/ck_tile/host/joinable_thread.hpp b/include/ck_tile/host/joinable_thread.hpp index a42b567fb4..bf84858ee2 100644 --- a/include/ck_tile/host/joinable_thread.hpp +++ b/include/ck_tile/host/joinable_thread.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index be38e92b1a..ac9e00b668 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/permute_pk_int4.hpp b/include/ck_tile/host/permute_pk_int4.hpp index b770edddca..45a571b248 100644 --- a/include/ck_tile/host/permute_pk_int4.hpp +++ b/include/ck_tile/host/permute_pk_int4.hpp @@ -1,7 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c), Advanced Micro Devices, Inc. All rights reserved. #pragma once - #include "ck_tile/core/utility/bit_cast.hpp" namespace ck_tile { diff --git a/include/ck_tile/host/ranges.hpp b/include/ck_tile/host/ranges.hpp index f6dbf5bdaa..528d056b61 100644 --- a/include/ck_tile/host/ranges.hpp +++ b/include/ck_tile/host/ranges.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_batched_contraction.hpp b/include/ck_tile/host/reference/reference_batched_contraction.hpp index a86accc778..cc42d77d43 100644 --- a/include/ck_tile/host/reference/reference_batched_contraction.hpp +++ b/include/ck_tile/host/reference/reference_batched_contraction.hpp @@ -1,11 +1,9 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include -#include -#include #include #include "ck_tile/core.hpp" @@ -13,110 +11,259 @@ namespace ck_tile { +// Helper to apply elementwise operation with variable number of D tensors +template +struct ApplyCDEElementWise +{ + template + CK_TILE_HOST_DEVICE static void apply(EDataType& result, + AccDataType sum, + const CDEElementWise& cde_elementwise, + DValues... d_vals) + { + if constexpr(sizeof...(DValues) == 0) + { + result = static_cast(sum); + } + else + { + cde_elementwise( + result, ck_tile::type_convert(sum), ck_tile::type_convert(d_vals)...); + } + } +}; + +// Helper to extract D values at a given offset using index sequence +template > +struct ExtractDValues; + +template +struct ExtractDValues> +{ + template + CK_TILE_HOST static void + apply_at_offsets(EDataType& result, + AccDataType sum, + const CDEElementWise& cde_elementwise, + const std::array, NumDTensor>& ds_tensors, + const std::array& d_offsets) + { + ApplyCDEElementWise::apply( + result, sum, cde_elementwise, ds_tensors[Is].mData[d_offsets[Is]]...); + } +}; + template + typename CDEElementWise, + ck_tile::index_t NumDTensor> -void calculate_reference_flat_indexing( +void compute_reference_batched_contraction( const ck_tile::HostTensor& a_full_dims, const ck_tile::HostTensor& b_full_dims, - const std::vector>& ds_full_dims_host, + const std::array, NumDTensor>& ds_full_dims_host, ck_tile::HostTensor& e_full_dims_host_ref, ck_tile::index_t G_total, ck_tile::index_t M_total, ck_tile::index_t N_total, ck_tile::index_t K_total, - const CDEElementWise& cde_elementwise) + const CDEElementWise& cde_elementwise, + const std::vector& G_dims, + const std::vector& M_dims, + const std::vector& N_dims, + const std::vector& K_dims) { - std::cout << "Calculating reference using optimized flat indexing with parallel processing..." + std::cout << "Calculating reference using stride-aware indexing with parallel processing..." << std::endl; - // Parallel computation over G and M dimensions using pattern from reference_batched_gemm.hpp + // Extract stride information from tensor descriptors + const auto a_strides = a_full_dims.get_strides(); + const auto b_strides = b_full_dims.get_strides(); + const auto e_strides = e_full_dims_host_ref.get_strides(); + + // Extract D tensor strides + std::array, NumDTensor> ds_strides; + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + ds_strides[d] = ds_full_dims_host[d].get_strides(); + } + + const ck_tile::index_t num_g_dims = G_dims.size(); + const ck_tile::index_t num_m_dims = M_dims.size(); + const ck_tile::index_t num_n_dims = N_dims.size(); + const ck_tile::index_t num_k_dims = K_dims.size(); + + // Helper lambda to compute linear index from flat indices using strides + auto compute_a_offset = [&](ck_tile::index_t g_flat, + ck_tile::index_t m_flat, + ck_tile::index_t k_flat) -> std::size_t { + std::size_t offset = 0; + + // Decode G dimensions + ck_tile::index_t temp = g_flat; + for(int i = num_g_dims - 1; i >= 0; --i) + { + offset += (temp % G_dims[i]) * a_strides[i]; + temp /= G_dims[i]; + } + + // Decode M dimensions + temp = m_flat; + for(int i = num_m_dims - 1; i >= 0; --i) + { + offset += (temp % M_dims[i]) * a_strides[num_g_dims + i]; + temp /= M_dims[i]; + } + + // Decode K dimensions + temp = k_flat; + for(int i = num_k_dims - 1; i >= 0; --i) + { + offset += (temp % K_dims[i]) * a_strides[num_g_dims + num_m_dims + i]; + temp /= K_dims[i]; + } + + return offset; + }; + + auto compute_b_offset = [&](ck_tile::index_t g_flat, + ck_tile::index_t n_flat, + ck_tile::index_t k_flat) -> std::size_t { + std::size_t offset = 0; + + // Decode G dimensions + ck_tile::index_t temp = g_flat; + for(int i = num_g_dims - 1; i >= 0; --i) + { + offset += (temp % G_dims[i]) * b_strides[i]; + temp /= G_dims[i]; + } + + // Decode N dimensions + temp = n_flat; + for(int i = num_n_dims - 1; i >= 0; --i) + { + offset += (temp % N_dims[i]) * b_strides[num_g_dims + i]; + temp /= N_dims[i]; + } + + // Decode K dimensions + temp = k_flat; + for(int i = num_k_dims - 1; i >= 0; --i) + { + offset += (temp % K_dims[i]) * b_strides[num_g_dims + num_n_dims + i]; + temp /= K_dims[i]; + } + + return offset; + }; + + auto compute_e_offset = [&](ck_tile::index_t g_flat, + ck_tile::index_t m_flat, + ck_tile::index_t n_flat) -> std::size_t { + std::size_t offset = 0; + + // Decode G dimensions + ck_tile::index_t temp = g_flat; + for(int i = num_g_dims - 1; i >= 0; --i) + { + offset += (temp % G_dims[i]) * e_strides[i]; + temp /= G_dims[i]; + } + + // Decode M dimensions + temp = m_flat; + for(int i = num_m_dims - 1; i >= 0; --i) + { + offset += (temp % M_dims[i]) * e_strides[num_g_dims + i]; + temp /= M_dims[i]; + } + + // Decode N dimensions + temp = n_flat; + for(int i = num_n_dims - 1; i >= 0; --i) + { + offset += (temp % N_dims[i]) * e_strides[num_g_dims + num_m_dims + i]; + temp /= N_dims[i]; + } + + return offset; + }; + + // Helper to compute D tensor offset (D tensors have same shape as E: [G, M, N]) + auto compute_d_offset = [&](ck_tile::index_t g_flat, + ck_tile::index_t m_flat, + ck_tile::index_t n_flat, + ck_tile::index_t d_idx) -> std::size_t { + std::size_t offset = 0; + const auto& d_strides = ds_strides[d_idx]; + + // Decode G dimensions + ck_tile::index_t temp = g_flat; + for(int i = num_g_dims - 1; i >= 0; --i) + { + offset += (temp % G_dims[i]) * d_strides[i]; + temp /= G_dims[i]; + } + + // Decode M dimensions + temp = m_flat; + for(int i = num_m_dims - 1; i >= 0; --i) + { + offset += (temp % M_dims[i]) * d_strides[num_g_dims + i]; + temp /= M_dims[i]; + } + + // Decode N dimensions + temp = n_flat; + for(int i = num_n_dims - 1; i >= 0; --i) + { + offset += (temp % N_dims[i]) * d_strides[num_g_dims + num_m_dims + i]; + temp /= N_dims[i]; + } + + return offset; + }; + + // Parallel computation over G and M dimensions auto f_gm = [&](auto g_flat, auto m_flat) { for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat) { AccDataType sum = 0; - // Compute dot product over K dimension + // Compute dot product over K dimension using stride-aware indexing for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat) { - auto a_val = - a_full_dims.mData[g_flat * M_total * K_total + m_flat * K_total + k_flat]; - auto b_val = - b_full_dims.mData[g_flat * N_total * K_total + n_flat * K_total + k_flat]; + const std::size_t a_offset = compute_a_offset(g_flat, m_flat, k_flat); + const std::size_t b_offset = compute_b_offset(g_flat, n_flat, k_flat); + + auto a_val = a_full_dims.mData[a_offset]; + auto b_val = b_full_dims.mData[b_offset]; sum += static_cast(a_val) * static_cast(b_val); } - // Apply elementwise operation with D tensors - EDataType result = static_cast(sum); - if(ds_full_dims_host.size() == 0) + // Compute output offset using strides + const std::size_t e_offset = compute_e_offset(g_flat, m_flat, n_flat); + + // Compute individual D tensor offsets using their respective strides + std::array d_offsets; + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) { - ; - } - else if(ds_full_dims_host.size() == 1) - { - cde_elementwise(result, - ck_tile::type_convert(sum), - ck_tile::type_convert( - ds_full_dims_host[0].mData[g_flat * M_total * N_total + - m_flat * N_total + n_flat])); - } - else if(ds_full_dims_host.size() == 2) - { - cde_elementwise( - result, - ck_tile::type_convert(sum), - ck_tile::type_convert( - ds_full_dims_host[0] - .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), - ck_tile::type_convert( - ds_full_dims_host[1] - .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat])); - } - else if(ds_full_dims_host.size() == 3) - { - cde_elementwise( - result, - ck_tile::type_convert(sum), - ck_tile::type_convert( - ds_full_dims_host[0] - .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), - ck_tile::type_convert( - ds_full_dims_host[1] - .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), - ck_tile::type_convert( - ds_full_dims_host[2] - .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat])); - } - else if(ds_full_dims_host.size() == 4) - { - cde_elementwise( - result, - ck_tile::type_convert(sum), - ck_tile::type_convert( - ds_full_dims_host[0] - .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), - ck_tile::type_convert( - ds_full_dims_host[1] - .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), - ck_tile::type_convert( - ds_full_dims_host[2] - .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), - ck_tile::type_convert( - ds_full_dims_host[3] - .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat])); - } - else - { - throw std::runtime_error("Unsupported NumDTensor for reference calculation"); + d_offsets[d] = compute_d_offset(g_flat, m_flat, n_flat, d); } - // Store result - e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] = - static_cast(result); + // Apply elementwise operation with D tensors using compile-time dispatch + EDataType result = static_cast(sum); + ExtractDValues::apply_at_offsets( + result, sum, cde_elementwise, ds_full_dims_host, d_offsets); + + // Store result using stride-aware indexing + e_full_dims_host_ref.mData[e_offset] = static_cast(result); } }; @@ -125,147 +272,4 @@ void calculate_reference_flat_indexing( make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency()); } -template -void calculate_reference_multi_dimensional( - const HostTensor& a_full_dims, - const HostTensor& b_full_dims, - const std::vector>& ds_full_dims_host, - HostTensor& e_full_dims_host_ref, - const std::vector& G_dims, - const std::vector& M_dims, - const std::vector& N_dims, - const std::vector& K_dims, - const std::vector& A_dims, - const std::vector& B_dims, - const std::vector& E_dims, - const CDEElementWise& cde_elementwise) -{ - std::cout << "Calculating reference using multi-dimensional indexing..." << std::endl; - - std::vector g_idx(G_dims.size()); - std::vector m_idx(M_dims.size()); - std::vector n_idx(N_dims.size()); - std::vector k_idx(K_dims.size()); - std::vector a_idx, b_idx, e_idx; - - a_idx.reserve(A_dims.size()); - b_idx.reserve(B_dims.size()); - e_idx.reserve(E_dims.size()); - - auto calculate_total_elements = [](const std::vector& dims) { - return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies()); - }; - - for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat) - { - ck_tile::index_t temp = g_flat; - for(int i = G_dims.size() - 1; i >= 0; --i) - { - g_idx[i] = temp % G_dims[i]; - temp /= G_dims[i]; - } - - for(ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat) - { - temp = m_flat; - for(int i = M_dims.size() - 1; i >= 0; --i) - { - m_idx[i] = temp % M_dims[i]; - temp /= M_dims[i]; - } - - for(ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat) - { - temp = n_flat; - for(int i = N_dims.size() - 1; i >= 0; --i) - { - n_idx[i] = temp % N_dims[i]; - temp /= N_dims[i]; - } - - AccDataType sum = 0; - - for(ck_tile::index_t k_flat = 0; k_flat < calculate_total_elements(K_dims); - ++k_flat) - { - temp = k_flat; - for(int i = K_dims.size() - 1; i >= 0; --i) - { - k_idx[i] = temp % K_dims[i]; - temp /= K_dims[i]; - } - - a_idx.clear(); - b_idx.clear(); - - a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end()); - a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end()); - a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end()); - - b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end()); - b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end()); - b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end()); - - auto a_val = a_full_dims(a_idx); - auto b_val = b_full_dims(b_idx); - - sum += static_cast(a_val) * static_cast(b_val); - } - - e_idx.clear(); - e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end()); - e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end()); - e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end()); - - EDataType result = static_cast(sum); - if(ds_full_dims_host.size() == 0) - { - ; - } - else if(ds_full_dims_host.size() == 1) - { - cde_elementwise(result, - ck_tile::type_convert(sum), - ck_tile::type_convert(ds_full_dims_host[0](e_idx))); - } - else if(ds_full_dims_host.size() == 2) - { - cde_elementwise(result, - ck_tile::type_convert(sum), - ck_tile::type_convert(ds_full_dims_host[0](e_idx)), - ck_tile::type_convert(ds_full_dims_host[1](e_idx))); - } - else if(ds_full_dims_host.size() == 3) - { - cde_elementwise(result, - ck_tile::type_convert(sum), - ck_tile::type_convert(ds_full_dims_host[0](e_idx)), - ck_tile::type_convert(ds_full_dims_host[1](e_idx)), - ck_tile::type_convert(ds_full_dims_host[2](e_idx))); - } - else if(ds_full_dims_host.size() == 4) - { - cde_elementwise(result, - ck_tile::type_convert(sum), - ck_tile::type_convert(ds_full_dims_host[0](e_idx)), - ck_tile::type_convert(ds_full_dims_host[1](e_idx)), - ck_tile::type_convert(ds_full_dims_host[2](e_idx)), - ck_tile::type_convert(ds_full_dims_host[3](e_idx))); - } - else - { - throw std::runtime_error("Unsupported NumDTensor for reference calculation"); - } - - e_full_dims_host_ref(e_idx) = static_cast(result); - } - } - } -} - } // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_batched_dropout.hpp b/include/ck_tile/host/reference/reference_batched_dropout.hpp index 242101bf4d..a92c6f2ab9 100644 --- a/include/ck_tile/host/reference/reference_batched_dropout.hpp +++ b/include/ck_tile/host/reference/reference_batched_dropout.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp b/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp index ec6c6009b7..3a96d6f8bc 100644 --- a/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp +++ b/include/ck_tile/host/reference/reference_batched_dropout_randval.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_batched_elementwise.hpp b/include/ck_tile/host/reference/reference_batched_elementwise.hpp index abd987c157..31fbd3676f 100644 --- a/include/ck_tile/host/reference/reference_batched_elementwise.hpp +++ b/include/ck_tile/host/reference/reference_batched_elementwise.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_batched_gemm.hpp b/include/ck_tile/host/reference/reference_batched_gemm.hpp index 826358de30..63f13b1b16 100644 --- a/include/ck_tile/host/reference/reference_batched_gemm.hpp +++ b/include/ck_tile/host/reference/reference_batched_gemm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_batched_masking.hpp b/include/ck_tile/host/reference/reference_batched_masking.hpp index eece7fc3a8..c2dd8abe23 100644 --- a/include/ck_tile/host/reference/reference_batched_masking.hpp +++ b/include/ck_tile/host/reference/reference_batched_masking.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp b/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp index 858144c8ba..0aa2519fe8 100644 --- a/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp +++ b/include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_batched_softmax.hpp b/include/ck_tile/host/reference/reference_batched_softmax.hpp index 10222cb766..7ef3565d72 100644 --- a/include/ck_tile/host/reference/reference_batched_softmax.hpp +++ b/include/ck_tile/host/reference/reference_batched_softmax.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_batched_transpose.hpp b/include/ck_tile/host/reference/reference_batched_transpose.hpp index 454ab42e32..53545d5c33 100644 --- a/include/ck_tile/host/reference/reference_batched_transpose.hpp +++ b/include/ck_tile/host/reference/reference_batched_transpose.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_elementwise.hpp b/include/ck_tile/host/reference/reference_elementwise.hpp index 3e174bf870..f55882b9d5 100644 --- a/include/ck_tile/host/reference/reference_elementwise.hpp +++ b/include/ck_tile/host/reference/reference_elementwise.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_fused_moe.hpp b/include/ck_tile/host/reference/reference_fused_moe.hpp index 4b4687d3d0..ede38d38e7 100644 --- a/include/ck_tile/host/reference/reference_fused_moe.hpp +++ b/include/ck_tile/host/reference/reference_fused_moe.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 4d0f92f3e0..883b08fcaa 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp b/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp index c8264800c9..e141d842dd 100644 --- a/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp +++ b/include/ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp b/include/ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp index 346a03d1e8..e2282ebb85 100644 --- a/include/ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp +++ b/include/ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp b/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp index 0b4995da3b..831e8b6882 100644 --- a/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp +++ b/include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_im2col.hpp b/include/ck_tile/host/reference/reference_im2col.hpp index 392d6abd47..b660def881 100644 --- a/include/ck_tile/host/reference/reference_im2col.hpp +++ b/include/ck_tile/host/reference/reference_im2col.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp index 62cd26b6ab..7b65abd6e2 100644 --- a/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp +++ b/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_moe_gemm.hpp b/include/ck_tile/host/reference/reference_moe_gemm.hpp index 13203b8f7c..18d8e94751 100644 --- a/include/ck_tile/host/reference/reference_moe_gemm.hpp +++ b/include/ck_tile/host/reference/reference_moe_gemm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_moe_sorting.hpp b/include/ck_tile/host/reference/reference_moe_sorting.hpp index b7615d0478..83ffae4468 100644 --- a/include/ck_tile/host/reference/reference_moe_sorting.hpp +++ b/include/ck_tile/host/reference/reference_moe_sorting.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_permute.hpp b/include/ck_tile/host/reference/reference_permute.hpp index 4e0f1a877e..b049c1baf3 100644 --- a/include/ck_tile/host/reference/reference_permute.hpp +++ b/include/ck_tile/host/reference/reference_permute.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_pool.hpp b/include/ck_tile/host/reference/reference_pool.hpp index 7a2848def5..874a67a1ab 100644 --- a/include/ck_tile/host/reference/reference_pool.hpp +++ b/include/ck_tile/host/reference/reference_pool.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_reduce.hpp b/include/ck_tile/host/reference/reference_reduce.hpp index 9952b7b009..07834a920e 100644 --- a/include/ck_tile/host/reference/reference_reduce.hpp +++ b/include/ck_tile/host/reference/reference_reduce.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp index 424fff4470..db4ad34f24 100644 --- a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp +++ b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp b/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp index aff5e78ff0..5dbf856670 100644 --- a/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp +++ b/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_softmax.hpp b/include/ck_tile/host/reference/reference_softmax.hpp index 4e729c437d..5333b9af64 100644 --- a/include/ck_tile/host/reference/reference_softmax.hpp +++ b/include/ck_tile/host/reference/reference_softmax.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_topk.hpp b/include/ck_tile/host/reference/reference_topk.hpp index 0fc99a983a..8c9c87e3ee 100644 --- a/include/ck_tile/host/reference/reference_topk.hpp +++ b/include/ck_tile/host/reference/reference_topk.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_transpose.hpp b/include/ck_tile/host/reference/reference_transpose.hpp index 45d3dc9efa..0847cf1f0c 100644 --- a/include/ck_tile/host/reference/reference_transpose.hpp +++ b/include/ck_tile/host/reference/reference_transpose.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/rotating_buffers.hpp b/include/ck_tile/host/rotating_buffers.hpp index 601b8f2378..baec4b45e8 100644 --- a/include/ck_tile/host/rotating_buffers.hpp +++ b/include/ck_tile/host/rotating_buffers.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/stream_config.hpp b/include/ck_tile/host/stream_config.hpp index acb861b2e7..4f1280c440 100644 --- a/include/ck_tile/host/stream_config.hpp +++ b/include/ck_tile/host/stream_config.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/stream_utils.hpp b/include/ck_tile/host/stream_utils.hpp index 25faba9bfc..d5eab7461e 100644 --- a/include/ck_tile/host/stream_utils.hpp +++ b/include/ck_tile/host/stream_utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index 8be32fa910..fb47d38478 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -1,3 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + #pragma once #include diff --git a/include/ck_tile/host/timer.hpp b/include/ck_tile/host/timer.hpp index e5519643bf..1d641d1812 100644 --- a/include/ck_tile/host/timer.hpp +++ b/include/ck_tile/host/timer.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp index b6eac45285..b9651995d5 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp index 0b9bae4e9e..c691b9e32a 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp index 64e5224780..d111ea921b 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp index 2e64060038..de25055cee 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp index 052ee4ae62..f31f27cdfe 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp index 6d8f9f3f0e..968d5d6ac2 100644 --- a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp +++ b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp @@ -1,10 +1,11 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" #include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp" +#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp" #include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" /** @@ -60,23 +61,19 @@ * Rather than implementing tensor contraction from scratch, this kernel leverages the highly * optimized `UniversalGemmKernel` as its computational backend. * - * @subsection current_limitations Current Kernel Limitations + * @subsection implementation_features Implementation Features * - * **Layout Restrictions:** - * - **Row-Major Only**: All tensors must use row-major memory layout - * - **Packed Tensors**: Only contiguous/packed tensor layouts supported - * - **Hardcoded Strides**: stride_A = K_total, stride_B = K_total, stride_E = N_total - * - **D Tensor Layout**: All D tensors must match E tensor layout (stride_Ds = N_total) + * **Stride Support:** + * - Supports arbitrary multi-dimensional stride patterns + * - Handles non-contiguous and padded tensor layouts + * - Independent strides for each auxiliary D tensor + * - Descriptor-based architecture with vectorization * - * **Implementation Constraints:** - * - **Fixed Stride Calculation**: Strides are automatically calculated and cannot be customized - * - **No Column-Major**: Column-major or custom stride patterns not supported - * - **No Strided Access**: Non-contiguous tensor slicing not supported - * - * **Future Enhancements:** - * - Support for arbitrary stride patterns - * - Column-major and mixed layout support - * - Non-contiguous tensor operation support + * **Architecture:** + * - Uses TensorDescriptorUtils for stride-aware descriptor creation + * - Custom RunGemm implementation with descriptor-based tensor views + * - Reuses GemmPipeline and EpiloguePipeline for computation + * - Split-K support via UniversalGemmKernel utilities */ namespace ck_tile { @@ -184,7 +181,10 @@ template + ck_tile::index_t NumDTensor = 0, + ck_tile::index_t VectorSizeA = 1, + ck_tile::index_t VectorSizeB = 1, + ck_tile::index_t VectorSizeE = 1> struct BatchedContractionKernelArgs { const void* a_ptr; ///< Pointer to input tensor A @@ -210,11 +210,46 @@ struct BatchedContractionKernelArgs ck_tile::index_t N_total; ///< Total N dimension: N0 * N1 * ... * N_{NumDimN-1} ck_tile::index_t K_total; ///< Total K dimension: K0 * K1 * ... * K_{NumDimK-1} - ck_tile::index_t stride_A; ///< Leading dimension stride for tensor A (row-major: K_total) - ck_tile::index_t stride_B; ///< Leading dimension stride for tensor B (row-major: K_total) + ck_tile::index_t + stride_A; ///< Leading dimension stride for tensor A (for backward compatibility) + ck_tile::index_t + stride_B; ///< Leading dimension stride for tensor B (for backward compatibility) std::array - stride_Ds; ///< Leading dimension strides for D tensors (row-major: N_total) - ck_tile::index_t stride_E; ///< Leading dimension stride for tensor E (row-major: N_total) + stride_Ds; ///< Leading dimension strides for D tensors (for backward compatibility) + ck_tile::index_t + stride_E; ///< Leading dimension stride for tensor E (for backward compatibility) + + // Tensor descriptors (encode full multi-dimensional stride information with vectorization) + using AGridDesc_M_K_ = + decltype(TensorDescriptorUtils::Make_A_GridDescriptor_M_K({}, {})); + using BGridDesc_N_K_ = + decltype(TensorDescriptorUtils::Make_B_GridDescriptor_N_K({}, {})); + using EGridDesc_M_N_ = + decltype(TensorDescriptorUtils::Make_E_GridDescriptor_M_N({}, {})); + + AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides + BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides + EGridDesc_M_N_ e_grid_desc_m_n; ///< Tensor descriptor for E[M, N] with actual strides + std::array + ds_grid_desc_m_n; ///< Descriptors for D tensors (same shape as E, independent strides) }; /// @brief GPU kernel for batched tensor contraction operations. @@ -274,10 +309,24 @@ struct BatchedContractionKernel static constexpr ck_tile::index_t kBlockSize = UniversalGemmKernel::kBlockSize; ///< GPU block size inherited from GEMM kernel - using KernelArgs = - BatchedContractionKernelArgs; ///< Kernel - ///< argument - ///< structure + // Tensor descriptor utilities with vectorization support + using DescriptorUtils = TensorDescriptorUtils; + + // Kernel arguments with vectorization support + using KernelArgs = BatchedContractionKernelArgs; /// @brief Returns the kernel name for debugging and profiling purposes. /// @return Constant string identifier for this kernel @@ -326,6 +375,104 @@ struct BatchedContractionKernel TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch); } + /// @brief Executes GEMM computation with descriptor-based tensor views for arbitrary stride + /// support + /// + /// @details This function performs the core GEMM computation using tensor descriptors to handle + /// arbitrary multi-dimensional stride patterns. It creates tensor views from + /// pre-computed descriptors (stored in kargs), applies padding, creates tile windows, + /// and executes the GemmPipeline and EpiloguePipeline. + /// + /// @param a_ptr Pointer to input tensor A data (after batch and split-K offsets applied) + /// @param b_ptr Pointer to input tensor B data (after batch and split-K offsets applied) + /// @param ds_ptr Array of pointers to auxiliary D tensor data + /// @param e_ptr Pointer to output tensor E data (after batch offset applied) + /// @param smem_ptr Pointer to shared memory for tile operations + /// @param kargs Kernel arguments containing tensor descriptors and dimension information + /// @param k_size Size of K dimension for this split (for split-K support) + /// @param i_m Starting M index for this block's tile + /// @param i_n Starting N index for this block's tile + CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, + const BDataType* b_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + void* smem_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t i_m, + const index_t i_n) + { + // Create tensor views from descriptors (supports arbitrary stride patterns) + auto a_tensor_view = + make_tensor_view(a_ptr, kargs.a_grid_desc_m_k); + auto b_tensor_view = + make_tensor_view(b_ptr, kargs.b_grid_desc_n_k); + auto e_tensor_view = + make_tensor_view(e_ptr, kargs.e_grid_desc_m_n); + + // Pad views for boundary handling and optimization (like UniversalGemmKernel) + auto a_pad_view = pad_tensor_view( + a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + auto b_pad_view = pad_tensor_view( + b_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + auto e_pad_view = pad_tensor_view( + e_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Create tile windows from PADDED views + auto a_block_window = make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), + {i_m, 0}); + + auto b_block_window = make_tile_window( + b_pad_view, + make_tuple(number{}, number{}), + {i_n, 0}); + + auto e_block_window = make_tile_window( + e_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + // Calculate number of K loops + const index_t num_loop = + __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(k_size)); + + // Run GEMM Pipeline (same as UniversalGemmKernel, but with descriptor-based windows) + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + const auto& c_block_tile = GemmPipeline{}( + a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr); + + // Create D windows from descriptors (for each D tensor) + auto ds_block_windows = generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + const DDataType* d_ptr = static_cast(ds_ptr[i]); + + auto d_tensor_view = + make_tensor_view(d_ptr, kargs.ds_grid_desc_m_n[i]); + + return make_tile_window(d_tensor_view, + make_tuple(number{}, + number{}), + {i_m, i_n}); + }, + number{}); + + // Run Epilogue Pipeline with descriptor-based D windows + EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr); + } + CK_TILE_HOST static constexpr KernelArgs MakeKernelArgs(const BatchedContractionHostArgs& host_args) { @@ -435,6 +582,22 @@ struct BatchedContractionKernel kargs.K_total *= kargs.K_dims[i]; } + // Create tensor descriptors on host using actual dims and strides + kargs.a_grid_desc_m_k = + DescriptorUtils::Make_A_GridDescriptor_M_K(host_args.A_dims, host_args.A_strides); + kargs.b_grid_desc_n_k = + DescriptorUtils::Make_B_GridDescriptor_N_K(host_args.B_dims, host_args.B_strides); + kargs.e_grid_desc_m_n = + DescriptorUtils::Make_E_GridDescriptor_M_N(host_args.E_dims, host_args.E_strides); + + // Create D descriptors with their own strides (same shape as E, independent strides) + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + kargs.ds_grid_desc_m_n[d] = DescriptorUtils::Make_E_GridDescriptor_M_N( + host_args.Ds_dims[d], host_args.Ds_strides[d]); + } + + // Keep simple strides for backward compatibility kargs.stride_A = kargs.K_total; kargs.stride_B = kargs.K_total; kargs.stride_E = kargs.N_total; @@ -468,8 +631,8 @@ struct BatchedContractionKernel const ck_tile::index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); - const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y); - const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z); + const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y); + [[maybe_unused]] const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z); // Calculate batch offsets for each tensor const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A; @@ -487,6 +650,10 @@ struct BatchedContractionKernel ds_batch_ptr[i] = static_cast(kargs.ds_ptr[i]) + batch_offset_D; }); + // Allocate shared memory + __shared__ char smem_ptr[GetSmemSize()]; + + // Use UniversalGemmKernel's SplitKBatchOffset for split-K calculation typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr}, {b_ptr}, ds_batch_ptr, @@ -503,19 +670,19 @@ struct BatchedContractionKernel const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs, i_splitk); - const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0]; - const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0]; - __shared__ char smem_ptr[GetSmemSize()]; + // Apply K-split offsets and run descriptor-based RunGemm + const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0]; + const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0]; - UniversalGemmKernel::RunGemm({a_ptr_final}, - {b_ptr_final}, - ds_batch_ptr, - e_ptr, - smem_ptr, - gemm_kargs, - splitk_batch_offset, - i_m, - i_n); + RunGemm(a_ptr_split, + b_ptr_split, + ds_batch_ptr, + e_ptr, + smem_ptr, + kargs, + splitk_batch_offset.splitted_k, + i_m, + i_n); } }; diff --git a/include/ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp b/include/ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp index 9ebaae3c97..8c2d018967 100644 --- a/include/ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp +++ b/include/ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp @@ -1,7 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once - #include "ck_tile/core.hpp" namespace ck_tile { diff --git a/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp b/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp index 6d3286ce09..4767a430ac 100644 --- a/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp +++ b/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,8 +13,8 @@ * dimensions for GEMM operations. These functions transform multi-dimensional tensors into * 2D matrix descriptors by removing batch dimensions and flattening the remaining dimensions. * - * These utilities are currently not used in the main batched contraction kernel but are preserved - * for future implementations that may require explicit tensor descriptor creation. + * These utilities are used by BatchedContractionKernel to create stride-aware descriptors + * that support arbitrary multi-dimensional non-contiguous tensor layouts. */ namespace ck_tile { @@ -30,7 +30,10 @@ namespace ck_tile { template + ck_tile::index_t NumDimK, + ck_tile::index_t VectorSizeA, + ck_tile::index_t VectorSizeB, + ck_tile::index_t VectorSizeE> struct TensorDescriptorUtils { /// @brief Creates a tensor descriptor for input tensor A with batch dimensions removed. @@ -62,9 +65,9 @@ struct TensorDescriptorUtils const auto dims_M = get_container_subset(A_dims_M_K, A_dims_M_ids); const auto dims_K = get_container_subset(A_dims_M_K, A_dims_K_ids); - // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Discriptor - const auto A_grid_desc_Ms_Ks = - ck_tile::make_naive_tensor_descriptor(A_dims_M_K, A_strides_M_K); + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Descriptor with vector size + const auto A_grid_desc_Ms_Ks = ck_tile::make_naive_tensor_descriptor( + A_dims_M_K, A_strides_M_K, number{}, number<1>{}); // transformed tensor to flatten M and K dimensions [M_total = M0 * M1 * M2 * ... , K_total // = K0 * K1 * K2 * ...] @@ -106,9 +109,9 @@ struct TensorDescriptorUtils const auto dims_N = get_container_subset(B_dims_N_K, B_dims_N_ids); const auto dims_K = get_container_subset(B_dims_N_K, B_dims_K_ids); - // naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Discriptor - const auto B_grid_desc_Ns_Ks = - ck_tile::make_naive_tensor_descriptor(B_dims_N_K, B_strides_N_K); + // naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Descriptor with vector size + const auto B_grid_desc_Ns_Ks = ck_tile::make_naive_tensor_descriptor( + B_dims_N_K, B_strides_N_K, number{}, number<1>{}); // transformed tensor to flatten N and K dimensions [N_total = N0 * N1 * N2 * ... , K_total // = K0 * K1 * K2 * ...] @@ -150,9 +153,9 @@ struct TensorDescriptorUtils const auto dims_M = get_container_subset(E_dims_M_N, E_dims_M_ids); const auto dims_N = get_container_subset(E_dims_M_N, E_dims_N_ids); - // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Discriptor - const auto E_grid_desc_Ms_Ns = - ck_tile::make_naive_tensor_descriptor(E_dims_M_N, E_strides_M_N); + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Descriptor with vector size + const auto E_grid_desc_Ms_Ns = ck_tile::make_naive_tensor_descriptor( + E_dims_M_N, E_strides_M_N, number{}, number<1>{}); // transformed tensor to flatten M and N dimensions [M_total = M0 * M1 * M2 * ... , // N_total = N0 * N1 * N2 * ...] diff --git a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp index c99571562d..3d17933fb7 100644 --- a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp +++ b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp index 9e2a67f940..e872ef77ce 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp index ef0b7fa229..44029f84d0 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp index 77c3db9c06..163c4010ef 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp index b791bf9727..66854941b8 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp index 633827f3c3..d689bae79c 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp index 137584c3e8..463863d490 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp index 2be979723b..584ac8d350 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/common/generic_2d_block_shape.hpp b/include/ck_tile/ops/common/generic_2d_block_shape.hpp index 9c5d99efc3..4c11209a9b 100644 --- a/include/ck_tile/ops/common/generic_2d_block_shape.hpp +++ b/include/ck_tile/ops/common/generic_2d_block_shape.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp index 91fa61763a..10c2a1e4df 100644 --- a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp +++ b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -48,7 +48,7 @@ CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) } else { - dst = load_tile(src); + load_tile(dst, src); } } diff --git a/include/ck_tile/ops/common/streamk_common.hpp b/include/ck_tile/ops/common/streamk_common.hpp index 5dbe6223c4..c97282a8be 100644 --- a/include/ck_tile/ops/common/streamk_common.hpp +++ b/include/ck_tile/ops/common/streamk_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/common/tensor_layout.hpp b/include/ck_tile/ops/common/tensor_layout.hpp index bb905e6ab9..6f30b48f53 100644 --- a/include/ck_tile/ops/common/tensor_layout.hpp +++ b/include/ck_tile/ops/common/tensor_layout.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp index d1e6813b01..318a1a5860 100644 --- a/include/ck_tile/ops/common/utils.hpp +++ b/include/ck_tile/ops/common/utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -11,15 +11,17 @@ namespace ck_tile { // clang-format off -template struct typeToStr; -template <> struct typeToStr { static constexpr const char * name = "fp32"; }; -template <> struct typeToStr { static constexpr const char * name = "fp16"; }; -template <> struct typeToStr { static constexpr const char * name = "bf16"; }; -template <> struct typeToStr { static constexpr const char * name = "fp8"; }; -template <> struct typeToStr { static constexpr const char * name = "bf8"; }; -template <> struct typeToStr { static constexpr const char * name = "int8"; }; -template <> struct typeToStr { static constexpr const char * name = "pk_int4"; }; -template <> struct typeToStr { static constexpr const char * name = "pk_fp4"; }; +template struct DataTypeTraits; +template <> struct DataTypeTraits { static constexpr const char * name = "fp32"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "fp64"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "int32"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "fp16"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "bf16"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "fp8"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "bf8"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "int8"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "pk_int4"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4"; }; template struct memOpToStr; template <> struct memOpToStr { static constexpr const char * name = "set"; }; @@ -31,10 +33,10 @@ template <> struct memOpToStr { static constexpr con template std::string gemm_prec_str() { - std::string base_str = std::string(typeToStr::name); + std::string base_str = std::string(DataTypeTraits::name); if(!std::is_same_v) { - base_str += "_" + std::string(typeToStr::name); + base_str += "_" + std::string(DataTypeTraits::name); } return base_str; } diff --git a/include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp b/include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp index bf56a36b1e..44eeeadfb0 100644 --- a/include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp +++ b/include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp index b1e5e01777..2078a69546 100644 --- a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp +++ b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp index 9cba43d350..f719fd8182 100644 --- a/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp index a5d00ee1d0..a4edd95970 100644 --- a/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp +++ b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp b/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp index aaad6407d4..5393a9eb27 100644 --- a/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp +++ b/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index f8f8059469..2f8d3c6053 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -214,22 +214,27 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a) uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel; + // ---- Lower 4 int4 values (even positions) ---- + // Extract dictionary indices: low 3 bits of each byte (values 0..7). uint32_t dict_sel = a & 0x07070707; - uint32_t sign = a >> 1; - asm volatile("v_and_or_b32 %0, %1, %2, %3" - : "=v"(final_sel) - : "v"(sign), "v"(0x04040404), "v"(0x03020100)); - - tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); - tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); + // sign bit is bit[2] of each nibble after bias; shift to isolate per-byte sign. + uint32_t sign = a >> 1; + // Build final selector: + // - bit 2 of each byte (0x04) selects negative vs positive table + // - 0x03020100 selects byte lanes [0,1,2,3] in order + final_sel = (sign & 0x04040404) | 0x03020100; + // Lookup positive and negative fp8 codes from the small register tables. + tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); + tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); + // Select per-lane between tmp_pos and tmp_neg using the sign-derived selector. tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel); + // ---- Upper 4 int4 values (odd positions) ---- + // Shift to bring the high-nibble int4s into place and repeat the process. a >>= 4; - dict_sel = a & 0x07070707; - sign = a >> 1; - asm volatile("v_and_or_b32 %0, %1, %2, %3" - : "=v"(final_sel) - : "v"(sign), "v"(0x04040404), "v"(0x03020100)); + dict_sel = a & 0x07070707; + sign = a >> 1; + final_sel = (sign & 0x04040404) | 0x03020100; tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); @@ -306,22 +311,29 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a) uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel; + // ---- Lower 4 int4 values (even positions) ---- + // Extract dictionary indices: low 3 bits of each byte (values 0..7). uint32_t dict_sel = a & 0x07070707; - uint32_t sign = a >> 1; - asm volatile("v_and_or_b32 %0, %1, %2, %3" - : "=v"(final_sel) - : "v"(sign), "v"(0x04040404), "v"(0x03020100)); - tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); - tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); + // sign bit is bit[2] of each nibble after bias; shift to isolate per-byte sign. + uint32_t sign = a >> 1; + // Build final selector: + // - bit 2 of each byte (0x04) selects negative vs positive table + // - 0x03020100 selects byte lanes [0,1,2,3] in order + final_sel = (sign & 0x04040404) | 0x03020100; + + // Lookup positive and negative fp8 codes from the small register tables. + tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); + tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); + // Select per-lane between tmp_pos and tmp_neg using the sign-derived selector. tmp_res_even = __builtin_amdgcn_perm(tmp_neg, tmp_pos, final_sel); + // ---- Upper 4 int4 values (odd positions) ---- + // Shift to bring the high-nibble int4s into place and repeat the process. a >>= 4; - dict_sel = a & 0x07070707; - sign = a >> 1; - asm volatile("v_and_or_b32 %0, %1, %2, %3" - : "=v"(final_sel) - : "v"(sign), "v"(0x04040404), "v"(0x03020100)); + dict_sel = a & 0x07070707; + sign = a >> 1; + final_sel = (sign & 0x04040404) | 0x03020100; tmp_pos = __builtin_amdgcn_perm(reg1, reg0, dict_sel); tmp_neg = __builtin_amdgcn_perm(reg3, reg2, dict_sel); diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 8a84f7e9bf..9a7876f6a5 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp index 6c5a2ac149..dca1dfcf6f 100644 --- a/include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 8cf47c46e7..cc2303582e 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp b/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp index c8168a1eed..b5ece33ee2 100644 --- a/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp index 85494b3a76..2b8e9e4b1a 100644 --- a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp index d5b062a1b3..946908f297 100644 --- a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp b/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp index 21ca470222..24d4b8ca18 100644 --- a/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp +++ b/include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp b/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp index 037bb7688c..386730e45c 100644 --- a/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp +++ b/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp b/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp index d645d99c9f..c040629bac 100644 --- a/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp +++ b/include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/block/flatmm_uk_config.hpp b/include/ck_tile/ops/flatmm/block/flatmm_uk_config.hpp index 003335c0e7..ac4cce21ec 100644 --- a/include/ck_tile/ops/flatmm/block/flatmm_uk_config.hpp +++ b/include/ck_tile/ops/flatmm/block/flatmm_uk_config.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc b/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc index d78f266bfd..dd57bbc1de 100644 --- a/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc +++ b/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc @@ -1,7 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. // clang-format off - // define the CK_TILE_** macro before include this file to change kernel variation // we will undef everything defined in this file diff --git a/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc b/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc index 733afdbe94..7724f2be75 100644 --- a/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc +++ b/include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc @@ -1,7 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. // clang-format off - // define the CK_TILE_** macro before include this file to change kernel variation // we will undef everything defined in this file diff --git a/include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc b/include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc index 10531b7a26..84fb8a9b21 100644 --- a/include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc +++ b/include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc @@ -1,7 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. // clang-format off - // define the CK_TILE_** macro before include this file to change kernel variation // we will undef everything defined in this file, so it's safe diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index d3ecbefd91..09204aa7ed 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp index c47c31dd8d..58d053d5ae 100644 --- a/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp index 2ee78e1fc1..05d50666a5 100644 --- a/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 8a9aa3cdd3..b3b34a6da0 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -623,7 +623,7 @@ struct MoeFlatmmKernel { return make_naive_tensor_view( e_ptr, - make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumToken, + make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens, IsGateUp ? kargs.N / 2 : kargs.N), make_tuple(1, kargs.stride_C), number<1>{}, @@ -1250,6 +1250,8 @@ struct MoeFlatmmKernel constexpr int MPerThread = TileEncodingPattern::Y2; statically_indexed_array, NumMEpiTile> c_scatter_offsets; + statically_indexed_array, NumMEpiTile> + c_scatter_valids; auto c_coord = dram_tile_distribution.calculate_index(); static_for<0, NumMEpiTile, 1>{}([&](auto mIter) { static_for<0, MPerThread, 1>{}([&](auto m0) { @@ -1262,6 +1264,7 @@ struct MoeFlatmmKernel scatter_token_id = scatter_token_id * kargs.TopK + (fused_token >> token_id_offset); c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C; + c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens); }); }); @@ -1302,7 +1305,8 @@ struct MoeFlatmmKernel c_block_window.get_window_lengths(), c_block_window.get_window_origin(), dram_tile_distribution, - c_scatter_offsets[mIter]); + c_scatter_offsets[mIter], + c_scatter_valids[mIter]); if constexpr(!IsInputGemm || EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add) diff --git a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp index 3f2560587a..d9fb144176 100644 --- a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index e6ff17952b..79b36adec4 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 57e58b24f0..76d191a40c 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 17c88e4f08..8ec23b7570 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index f34c682b0f..ea67d80e37 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp b/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp index 0987971a72..fe6d3ec830 100644 --- a/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 95b4cfeaca..ff799cb0fc 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -46,8 +46,8 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem -struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 +template +struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 { using Underlying = FlatmmPipelineAGmemBGmemCRegV1; @@ -470,17 +470,39 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1(); } + template + CK_TILE_DEVICE auto operator()(Args&&... args) const + { + auto c_warp_tensors = Run_(std::forward(args)...); + + // Block GEMM Acc register tile + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + auto c_block_tile = BlockFlatmm{}.MakeCBlockTile(); + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensors(mIter)(nIter).get_thread_buffer()); + }); + }); + return c_block_tile; + } + template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - const ScaleADramBlockWindowTmp& scale_a_window, - const ScaleBDramBlockWindowTmp& scale_b_window, - index_t num_loop, - void* __restrict__ p_smem_ping, - void* __restrict__ p_smem_pong) const + CK_TILE_DEVICE auto Run_(const ADramBlockWindowTmp& a_copy_dram_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, + index_t num_loop, + void* __restrict__ p_smem_ping, + void* __restrict__ p_smem_pong) const { #ifndef __gfx950__ static_assert(false, "Only gfx950 is supported for MXFP4 flatmm pipeline now."); @@ -497,19 +519,14 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}; - auto a_dram_window = - make_tile_window(PipelinePolicy::template MakeMXFP4_AAsyncLoadDramDescriptor( + make_tile_window(PipelinePolicy::template MakeMX_AAsyncLoadDramDescriptor( a_copy_dram_window_tmp.get_bottom_tensor_view()), a_copy_dram_window_tmp.get_window_lengths(), a_copy_dram_window_tmp.get_window_origin(), - PipelinePolicy::template MakeMXFP4_ADramTileDistribution()); + PipelinePolicy::template MakeMX_ADramTileDistribution()); __builtin_amdgcn_sched_barrier(0); @@ -518,7 +535,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1(p_smem_pong); constexpr auto a_lds_block_desc = - PipelinePolicy::template MakeMXFP4_ALdsBlockDescriptor(); + PipelinePolicy::template MakeMX_ALdsBlockDescriptor(); auto a_lds_block_ping = make_tensor_view(p_a_lds_ping, a_lds_block_desc); @@ -535,39 +552,34 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, number{}), {0, 0}, - PipelinePolicy::template MakeMXF4_ALDS_TileDistribution()); + PipelinePolicy::template MakeMX_ALDS_TileDistribution()); auto a_warp_window_pong = make_tile_window(a_lds_block_pong, make_tuple(number{}, number{}), {0, 0}, - PipelinePolicy::template MakeMXF4_ALDS_TileDistribution()); - - // Block GEMM - auto block_flatmm = BlockFlatmm(); - // Acc register tile - auto c_block_tile = block_flatmm.MakeCBlockTile(); + PipelinePolicy::template MakeMX_ALDS_TileDistribution()); // B flat DRAM window for load // pingpong buffer for B - auto b_flat_dram_windows = generate_tuple( + auto b_flat_dram_window = + make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeMX_BFlatDramTileDistribution()); + auto b_flat_dram_offsets = generate_tuple( [&](auto nIter) { constexpr auto packed_n_idx = nIter / number{}; constexpr auto packed_n_rank = nIter % number{}; - auto window_i = make_tile_window( - b_flat_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_flat_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution()); - move_tile_window( - window_i, - {number{}, - number<0>{}}); - return window_i; + return b_flat_dram_window.get_load_offset( + tuple, + number<0>>{}) + + b_flat_dram_window.get_load_offset( + tuple, number<0>>{}); }, number{}); statically_indexed_array< - statically_indexed_array, + statically_indexed_array, NIterPerWarp> b_warp_tensor_ping, b_warp_tensor_pong; @@ -576,41 +588,37 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, number<64 / WG::kM>{}), scale_a_window.get_window_origin(), - PipelinePolicy::template MakeMXFP4_ScaleA_FlatDramTileDistribution()); + PipelinePolicy::template MakeMX_ScaleA_FlatDramTileDistribution()); + const auto scale_a_dram_step_m = amd_wave_read_first_lane( + scale_a_dram_window.get_load_offset(tuple, number<0>>{})); + const auto scale_a_dram_step_k = amd_wave_read_first_lane( + scale_a_dram_window.get_load_offset(tuple, number<64 / WG::kM>>{})); auto scale_b_dram_window = make_tile_window( scale_b_window.get_bottom_tensor_view(), make_tuple(number{}, number<64 / WG::kN>{}), scale_b_window.get_window_origin(), - PipelinePolicy::template MakeMXFP4_ScaleB_DramTileDistribution()); + PipelinePolicy::template MakeMX_ScaleB_DramTileDistribution()); + const auto scale_b_dram_step_n = amd_wave_read_first_lane( + scale_b_dram_window.get_load_offset(tuple, number<0>>{})); + const auto scale_b_dram_step_k = amd_wave_read_first_lane( + scale_b_dram_window.get_load_offset(tuple, number<64 / WG::kN>>{})); + + constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack; + constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack; + constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack; // ping pong buffer for scale A statically_indexed_array< - statically_indexed_array, - MIterPerWarp / MXdlPack> - scale_a_dram_windows; - statically_indexed_array, - MIterPerWarp / MXdlPack> - scale_a_tile_tensor_ping; - statically_indexed_array, - MIterPerWarp / MXdlPack> - scale_a_tile_tensor_pong; + statically_indexed_array, + MPackIterPerWarp> + scale_a_tile_tensor_ping, scale_a_tile_tensor_pong; // ping pong buffer for scale B statically_indexed_array< - statically_indexed_array, - NIterPerWarp / NXdlPack> - scale_b_dram_windows; - statically_indexed_array, - NIterPerWarp / NXdlPack> - scale_b_tile_tensor_ping; - statically_indexed_array, - NIterPerWarp / NXdlPack> - scale_b_tile_tensor_pong; + statically_indexed_array, + NPackIterPerWarp> + scale_b_tile_tensor_ping, scale_b_tile_tensor_pong; auto async_load_tile_ = [](auto lds, auto dram) { async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{}); @@ -625,35 +633,31 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_windows(nIter), number{}); + b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); }); // move B window to next flat K - move_tile_window(b_flat_dram_windows(nIter), {0, KIterPerWarp * KFlatPerBlockPerIter}); + b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( + tuple, number>{}); }); // prefetch Scale A - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset( + scale_a_dram_window, - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k); }); }); // move Scale A window to next K move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); // prefetch Scale B - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset( + scale_b_dram_window, + nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k); }); }); // move Scale B window to next K @@ -667,7 +671,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, MIterPerWarp> + c_warp_tensors; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}( + [&](auto nIter) { clear_tile(c_warp_tensors(mIter)(nIter)); }); + }); statically_indexed_array a_warp_tensor; @@ -688,40 +697,37 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_windows(nIter), number{}); + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + + // move B window to next flat K if constexpr(kIter == KIterPerWarp - 1) - move_tile_window(b_flat_dram_windows(nIter), - {0, BlockGemmShape::flatKPerBlock}); + b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( + tuple, number>{}); }); }); // prefetch Scale A and Scale B (2i+1) - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); - - scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset( + scale_a_dram_window, + mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k); }); }); - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - - scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset( + scale_b_dram_window, + nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k); }); }); // GEMM 2i - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { static_for<0, MXdlPack, 1>{}([&](auto imxdl) { constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; @@ -729,39 +735,22 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto inxdl) { constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; - - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM WG{}.template operator()( - c_warp_tensor, + c_warp_tensors(number{})(number{}), a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), + b_warp_tensor_ping(number{})(number{}), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) .get_thread_buffer()[0]); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); }); // preload next A from lds constexpr auto addr = m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) + (nIter_pack == NPackIterPerWarp - 1)) { constexpr auto AmIter = addr % 2 + addr / 4 * 2; constexpr auto AkIter = addr / 2 % 2; @@ -802,81 +791,60 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_windows(nIter), number{}); + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + + // move B window to next flat K if constexpr(kIter == KIterPerWarp - 1) - move_tile_window(b_flat_dram_windows(nIter), - {0, BlockGemmShape::flatKPerBlock}); + b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( + tuple, number>{}); }); }); // prefetch Scale A and Scale B (2i+2) - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); - - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset( + scale_a_dram_window, + mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k); }); }); - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset( + scale_b_dram_window, + nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k); }); }); // GEMM 2i+1 - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { static_for<0, MXdlPack, 1>{}([&](auto imxdl) { constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; + constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; // warp GEMM WG{}.template operator()( - c_warp_tensor, + c_warp_tensors(number{})(number{}), a_warp_tensor(number{}), - b_warp_tensor_pong(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), + b_warp_tensor_pong(number{})(number{}), scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) .get_thread_buffer()[0]); // scale B - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); }); // preload next A from lds - constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + - (kIter_pack * KXdlPack + ikxdl) * 2 + - (mIter_pack * MXdlPack + imxdl) / 2 * 4 + - m_preload; + constexpr auto addr = + m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) + (nIter_pack == NPackIterPerWarp - 1)) { constexpr auto AmIter = addr % 2 + addr / 4 * 2; constexpr auto AkIter = addr / 2 % 2; @@ -928,78 +896,54 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_windows(nIter), - make_tuple(number<0>{}, number{})); + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); }); }); // prefetch Scale A and Scale B (2i+1) - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); - - scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset( + scale_a_dram_window, + mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k); }); }); - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - - scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset( + scale_b_dram_window, + nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k); }); }); // GEMM loopK-1 - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { static_for<0, MXdlPack, 1>{}([&](auto imxdl) { constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; + constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; // warp GEMM WG{}.template operator()( - c_warp_tensor, + c_warp_tensors(number{})(number{}), a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), + b_warp_tensor_ping(number{})(number{}), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) .get_thread_buffer()[0]); // scale B - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); }); // preload next A from lds - constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + - (kIter_pack * KXdlPack + ikxdl) * 2 + - (mIter_pack * MXdlPack + imxdl) / 2 * 4 + - m_preload; + constexpr auto addr = + m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) + (nIter_pack == NPackIterPerWarp - 1)) { constexpr auto AmIter = addr % 2 + addr / 4 * 2; constexpr auto AkIter = addr / 2 % 2; @@ -1028,50 +972,32 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { static_for<0, MXdlPack, 1>{}([&](auto imxdl) { constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; + constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; // warp GEMM WG{}.template operator()( - c_warp_tensor, + c_warp_tensors(number{})(number{}), a_warp_tensor(number{}), - b_warp_tensor_pong(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), + b_warp_tensor_pong(number{})(number{}), scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) .get_thread_buffer()[0]); // scale B - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); }); // preload next A from lds - constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + - (kIter_pack * KXdlPack + ikxdl) * 2 + - (mIter_pack * MXdlPack + imxdl) / 2 * 4 + - m_preload; + constexpr auto addr = + m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) + (nIter_pack == NPackIterPerWarp - 1)) { constexpr auto AmIter = addr % 2 + addr / 4 * 2; constexpr auto AkIter = addr / 2 % 2; @@ -1089,50 +1015,32 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { static_for<0, MXdlPack, 1>{}([&](auto imxdl) { constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; + constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; // warp GEMM WG{}.template operator()( - c_warp_tensor, + c_warp_tensors(number{})(number{}), a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), + b_warp_tensor_ping(number{})(number{}), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) .get_thread_buffer()[0]); // scale B - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); }); // preload next A from lds - constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + - (kIter_pack * KXdlPack + ikxdl) * 2 + - (mIter_pack * MXdlPack + imxdl) / 2 * 4 + - m_preload; + constexpr auto addr = + m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) + (nIter_pack == NPackIterPerWarp - 1)) { constexpr auto AmIter = addr % 2 + addr / 4 * 2; constexpr auto AkIter = addr / 2 % 2; @@ -1151,7 +1059,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}; static constexpr auto I1 = number<1>{}; @@ -58,7 +58,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy template CK_TILE_DEVICE static constexpr auto - MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view) + MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view) { using ADataType = remove_cvref_t; using ALayout = remove_cvref_t; @@ -107,7 +107,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution() { using ADataType = remove_cvref_t; @@ -140,7 +140,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor() + CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor() { using ADataType = remove_cvref_t; using ALayout = remove_cvref_t; @@ -218,7 +218,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution() { using TileShape = typename Problem::BlockGemmShape; @@ -255,7 +255,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_BFlatDramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; @@ -298,7 +298,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape @@ -335,7 +335,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_DramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape @@ -372,7 +372,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_FlatDramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; @@ -394,7 +394,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_FlatDramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; @@ -420,8 +420,8 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy { using ADataType = remove_cvref_t; constexpr index_t APackedSize = numeric_traits::PackedSize; - return sizeof(ADataType) * - MakeMXFP4_ALdsBlockDescriptor().get_element_space_size() / APackedSize; + return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor().get_element_space_size() / + APackedSize; } template diff --git a/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp b/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp index 0e98078d53..75aff55043 100644 --- a/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp index e5be21e048..067e531fc0 100644 --- a/include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp +++ b/include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp index 4d80443f35..3755a2bc71 100644 --- a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp +++ b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/block/block_dropout.hpp b/include/ck_tile/ops/fmha/block/block_dropout.hpp index 1512c6ae34..6e01ea5dda 100644 --- a/include/ck_tile/ops/fmha/block/block_dropout.hpp +++ b/include/ck_tile/ops/fmha/block/block_dropout.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 2c45945fac..1a79aebef5 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/block/block_position_encoding.hpp b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp index 703ec0967a..dbe188611e 100644 --- a/include/ck_tile/ops/fmha/block/block_position_encoding.hpp +++ b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp b/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp index 5173279299..05aa582628 100644 --- a/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp +++ b/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/block/page_block_navigator.hpp b/include/ck_tile/ops/fmha/block/page_block_navigator.hpp index f1e6101d1d..204ed26f02 100644 --- a/include/ck_tile/ops/fmha/block/page_block_navigator.hpp +++ b/include/ck_tile/ops/fmha/block/page_block_navigator.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/block/variants.hpp b/include/ck_tile/ops/fmha/block/variants.hpp index d8b0cdbb86..29d9cf2a8e 100644 --- a/include/ck_tile/ops/fmha/block/variants.hpp +++ b/include/ck_tile/ops/fmha/block/variants.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index c6fbd6945f..e63ad8252b 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index b5bd4c74ef..5b491465b3 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 02296513d8..8c346f69e3 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp index 97c9b960c2..074296f294 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 93b415b4ce..38830ee6fe 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index a2e6f08361..677ead91ad 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index 99a301f620..1ce707996b 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index a6e44c7293..19592e8bf4 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index e9115b14df..df17bdd879 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 59fa9139bf..27776453f6 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp index 02731ca8f8..33e6ad006a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp index 3da1104169..e6d7c622f7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp index c38779d1d2..f01d681002 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index ea024a0257..854e45c432 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 6393f227a2..95c9a7ad19 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp index abe024ced1..254c461d6d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 5cdb4fe1d7..9aeabaa8c2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 3d5bfcc76a..3d21928ced 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index d9e19a0c7e..e04e08258e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp index 38aff07093..a67d727077 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp index 30c2c26416..457c4bc488 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index 5c5dbb3a96..37aadc63bc 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp index cf3f7466e7..bd53d41f34 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index 7a8e9a1d47..693f81d08a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp index 13ef642b1b..6217f8475a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index f2b524fa3d..1385ffe104 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp index 0747f4e6e1..b9c8337e2c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 4d1c38e079..0b30077a29 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp index 9d8f6bc99f..c5af751cd5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index fe5e0bc345..6be6a64b1c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp index ccc4f23817..b92549ccc4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 5e2a4e898b..8bf24be386 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp index ac04f54adf..ce097b6741 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index 45a1c8f4b8..da0fa16ee1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 864d155750..b90b760a0d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 9ec82617b1..9e1eb3bdec 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index da48802c76..e07516cc27 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp index e92ba58b37..1a1c80c2dd 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index dce9583fc1..5d224a6adf 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp index 575f9f106e..cf3657a88d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp index e905037398..22d515cd6f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 74e91aac56..85526c9e24 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp index 1283782d06..8114bb96c4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index 0b9986e083..3f015a1c1a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 8309c1ec2a..d2d8bb2c7e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp index 7505dbb172..5eb5b48abb 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 692a1cfa13..4acd5d7250 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index ca82519e72..ee5238869f 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 1cc41deeb7..b9e18de1e5 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp index c69c15a2b0..2aca7527be 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp index 92f6a48648..802ae2f607 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp index 381edb650d..565bb873c7 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index f6189c7495..3445f063f5 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp index ea218b9c25..a6c7fd4f8d 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp index 17c38a2632..46a7acbb12 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp index dbd6913cdb..c04e45782f 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp index 3f0dbfb340..1a0abe83e4 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp index 6089c2558f..55b8a96611 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp index 3fb82bc099..82664458d6 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp index bbd47352d4..f70f4ddacc 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp index f5218a93e2..a7e7c2848d 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp index d50179c1a1..d9b75649a3 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp index 0a17b05353..dcd6e76ec8 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index 4652e5f20f..35b6025594 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp index 9d494c2831..3d585beae4 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp index b849c48daf..44709d67c9 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp index 8313693d3a..960a685792 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp index c2cfbc083b..a938bbb239 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp index b99466b1ea..3302d149ca 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp index 98e5538c0a..14d59ff373 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp index f998c67c95..094a6eebbb 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index 9b10d435b6..2280f6f875 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp index 0181c0eec8..0aa7509b1e 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp index 20dcf2c270..398a835c15 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp index e90500c28c..b8290c95d8 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp index b8708a91fb..2ba01d91c5 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2r1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp index d28aa9e787..b1223f8755 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp index 5a17578f69..afbe5efa93 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp index cd16f09c37..29022e764f 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp index d6fee879b1..6eedfabaf8 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp index 2436457ec1..91ace17499 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp index f336fc7470..0af69ff1a5 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp index fd5211a59a..7ef6407870 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 75a424e31e..8541ffa3a9 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp index 9036d48b08..49c26fab6c 100644 --- a/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp index 55a2fbc34c..eefad71640 100644 --- a/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index eb7e3bcf94..8adbfb9723 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index d632b1596c..d113336a3e 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp index 7b73b89ede..9fc8ef83c3 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp index b0b2905cb4..6360e868e5 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 998431f165..ac7a2966aa 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. /** * @file diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index a72b1ba544..63993c5eb6 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp index 9dfed16bc9..91f1358321 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once @@ -28,8 +28,7 @@ struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<> index_t K_, index_t stride_A_, index_t stride_B_, - index_t stride_C_, - StreamKReductionStrategy reduction_strategy_) + index_t stride_C_) : UniversalGemmHostArgs<>({a_ptr_}, {b_ptr_}, {/*ds_ptr*/}, @@ -41,12 +40,9 @@ struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<> {stride_A_}, {stride_B_}, {/*stride_Ds_*/}, - stride_C_), - reduction_strategy{reduction_strategy_} + stride_C_) { } - - ck_tile::StreamKReductionStrategy reduction_strategy; }; /** @@ -133,7 +129,6 @@ struct StreamKKernel host_args.stride_Ds, host_args.stride_E, host_args.k_batch}, - reduction_strategy{host_args.reduction_strategy}, // The workspace pointer is set to nullptr because we must first // instantiate the TilePartitioner to get the necessary size workspace_ptr{nullptr}, @@ -141,10 +136,6 @@ struct StreamKKernel { } - /** - * @brief The strategy used by work groups to compute final results in C tensor. - */ - StreamKReductionStrategy reduction_strategy; /** * @brief A pointer to a buffer in device memory for accumulating partial via reduction * strategy. diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp index 8ee1ebc51a..9ab75fbdbf 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp index 626f440119..acc1860f1f 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp @@ -1,5 +1,6 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + #pragma once #include "streamk_gemm_tile_partitioner.hpp" namespace ck_tile { diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 2aac894a46..4b28ac3f12 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index a05e07bbc4..f39d41a653 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -26,18 +26,32 @@ struct GemmPipelineAgBgCrImplBase static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; #if defined(__gfx950__) - // The combination of pk_int4_t and transposed loading causes numerical errors. + // The combination of pk_int4_t and transposed loading causes compilation errors. // Therefore do not use transposed loading in this case. + // Also, transpose load (ds_read_tr) requires specific tile distribution patterns + // that only work for certain K warp tile sizes based on data type size: + // - For 1-byte types (fp8/bf8): K warp tile <= 64 + // - For 2-byte types (fp16/bf16): K warp tile <= 32 static constexpr bool is_a_load_tr = []() { + using WarpTile = typename BlockGemmShape::WarpTile; + constexpr index_t kKWarpTile = WarpTile::at(number<2>{}); + constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32; if constexpr(std::is_same_v) return false; + else if constexpr(kKWarpTile > kMaxKWarpTile) + return false; else return std::is_same_v; }(); static constexpr bool is_b_load_tr = []() { + using WarpTile = typename BlockGemmShape::WarpTile; + constexpr index_t kKWarpTile = WarpTile::at(number<2>{}); + constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32; if constexpr(std::is_same_v) return false; + else if constexpr(kKWarpTile > kMaxKWarpTile) + return false; else return std::is_same_v; }(); @@ -93,19 +107,21 @@ struct GemmPipelineAgBgCrImplBase load_tile(dst_block_tile, lds_tile_window); } + template CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const { // A tile in LDS - ADataType* __restrict__ p_a_lds = static_cast(p_smem); - constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + OverrideADataType* __restrict__ p_a_lds = static_cast(p_smem); + constexpr auto a_lds_block_desc = + Policy::template MakeALdsBlockDescriptor(); auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); // TODO: LDS alignment should come from Policy! constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple( - sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16); + sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size(), 16); // B tile in LDS - BDataType* __restrict__ p_b_lds = static_cast( + OverrideBDataType* __restrict__ p_b_lds = static_cast( static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index b293097d89..d27f937435 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -1,5 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp index b55835ab46..ffe889af41 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -18,7 +18,8 @@ struct GemmPipelineAgBgCrCompAsyncDefaultPolicy static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked; static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked; - template + template > CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index a1bbcbe990..f83462391c 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -30,7 +30,7 @@ struct BaseGemmPipelineAgBgCrCompV3 { if(BlockHasHotloop(num_loop)) { - return TailNumber::Full; + return TailNumber::Odd; } else { @@ -52,23 +52,27 @@ struct BaseGemmPipelineAgBgCrCompV3 // Handle all the valid cases. if(has_hot_loop) { - if(tail_number == TailNumber::Full) + if(tail_number == ck_tile::TailNumber::Odd) { - return run_func(bool_constant{}, - integral_constant{}); + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } else { - if(tail_number == TailNumber::Odd) + + if(tail_number == ck_tile::TailNumber::Odd) { - return run_func(bool_constant{}, - integral_constant{}); + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } - else if(tail_number == TailNumber::Even) + else if(tail_number == ck_tile::TailNumber::Even) { - return run_func(bool_constant{}, - integral_constant{}); + return run_func( + ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } #if defined(__HIP_DEVICE_COMPILE__) @@ -76,16 +80,8 @@ struct BaseGemmPipelineAgBgCrCompV3 __builtin_unreachable(); #else // If execution reaches here, it's an invalid combination of arguments. - if(has_hot_loop) - { - throw std::logic_error("Invalid TailNumber: If has_hot_loop is true, tail_number must " - "be TailNumber::Full."); - } - else - { - throw std::logic_error("Invalid TailNumber: If has_hot_loop is false, tail_number must " - "be TailNumber::Odd or TailNumber::Even."); - } + throw std::logic_error("Invalid TailNumber value: must be " + "TailNumber::Odd or TailNumber::Even"); #endif } }; @@ -588,7 +584,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 } while(i < (num_loop - 1)); } // tail - if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) + if constexpr(TailNum == TailNumber::Odd) { // Leak last MFMA block to epilogue region, cover the potential lds-shuffle // latency diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index 238b4e2389..d448cdbb93 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -1,5 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" @@ -22,7 +23,8 @@ struct BaseGemmPipelineAgBgCrCompV4 CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop) { - return num_loop > PrefetchStages; + constexpr index_t HotLoopGlobalReads = 2; + return num_loop >= (HotLoopGlobalReads + PrefetchStages); } CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index 3164b41cc7..777537a83a 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index 6343ff9872..1d6ac207eb 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -1,4 +1,4 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include "ck_tile/core.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp index 7065e55e6d..e8eb53a601 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp index 5b57560f6e..0b846d3116 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp @@ -1,5 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp index 6ac702d38b..a1daf0f0f5 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index ba71e3b6cb..b7e5642bd1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index b18bf603a9..957cf7ab8f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 8a4fb59b51..16ed8de22f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp old mode 100755 new mode 100644 index 712a6b6ac3..d7ce08a720 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index 32217e0024..5dbcde80a6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp index 7dad55d6b9..f2bd58e1af 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index ed2fd4a5cb..79fe02cb93 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp index 9b948626f6..b8ba584ef8 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp @@ -1,5 +1,5 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 6cca15c1d8..d843916f5e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -1,9 +1,11 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" @@ -35,11 +37,22 @@ struct UniversalGemmBasePolicy #if defined(__gfx950__) // The combination of pk_int4_t and transposed loading causes numerical errors. // Therefore do not use transposed loading in this case. + // Also, transpose load (ds_read_tr) requires specific tile distribution patterns + // that only work for certain K warp tile sizes based on data type size: + // - For 1-byte types (fp8/bf8): K warp tile <= 64 + // - For 2-byte types (fp16/bf16): K warp tile <= 32 template static constexpr bool is_a_load_tr = []() { - using BDataType = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + constexpr index_t kKWarpTile = WarpTile::at(number<2>{}); + // Max K warp tile for transpose load based on data type size + constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32; if constexpr(std::is_same_v) return false; + else if constexpr(kKWarpTile > kMaxKWarpTile) + return false; else return std::is_same_v, tensor_layout::gemm::ColumnMajor>; @@ -47,9 +60,15 @@ struct UniversalGemmBasePolicy template static constexpr bool is_b_load_tr = []() { - using BDataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + constexpr index_t kKWarpTile = WarpTile::at(number<2>{}); + // Max K warp tile for transpose load based on data type size + constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32; if constexpr(std::is_same_v) return false; + else if constexpr(kKWarpTile > kMaxKWarpTile) + return false; else return std::is_same_v, tensor_layout::gemm::RowMajor>; @@ -85,13 +104,12 @@ struct UniversalGemmBasePolicy return DefaultBTileAccessPattern; } - template + template > CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor() { - using ALayout = remove_cvref_t; - using ADataType = remove_cvref_t; - - using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + using ADataType = OverrideADataType; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPack = GetSmemPackA(); diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index 25cd20ae27..8029f6a2c7 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 96203b2cd2..d76fd6dc0f 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index f1c8f2ec9b..019a828ec0 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index cae2bd0e9f..977cdbae5c 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 0d41461038..a2c320f3e6 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 896bb31b42..3c7944a427 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -115,7 +115,7 @@ struct WarpGemmAttributeMfma const BVecType& b_vec, const int32_t& b_scale) const { - auto c_vec = Impl{}.template operator()(a_vec, a_scale, b_vec, b_scale); + return Impl{}.template operator()(a_vec, a_scale, b_vec, b_scale); } }; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 1ddf0c0cf8..bd65f53383 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp index 84cdf17d66..72cbf37206 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp index cd6cd3a399..d45abae887 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index dd2931f6b7..ff2ba501fe 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp index 751ada07af..0464ffbce4 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp index 7e834d9add..992f0a8783 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_16bit_traits.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp index 81ff5af2fe..34c4dbe551 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp index ed5f0eb0a6..524215ddfa 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_base_traits.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once namespace ck_tile { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index fe9a611b55..9d928a7cfa 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp index c38175d345..ca7c32b6af 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp index 9e028ddab0..3d64e148c4 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp index dcbcb95492..a80fb0f765 100644 --- a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp @@ -1,4 +1,4 @@ -// Copyright © Advanced Micro Devices, Inc. or its affiliates. +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once diff --git a/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp b/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp index d695888b88..8a708828bd 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp index 392dc46f72..b54a93614a 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index eb59f89a69..5100de58ac 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -373,8 +373,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { // Need to multiply aquant with accumulated C // - // The accumulated C tile has the standard distribution. For example - // lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0], + // The accumulated C tile has the standard distribution. For example, a + // 32x32 C lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0], // [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0], // [26,0], [27,0]. // @@ -388,35 +388,31 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase // // These scales can be obtained using __builtin_amdgcn_ds_bpermute. - // MIters per warp - constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM; - // Reg block offset based on mIter - constexpr index_t reg_block_offset = - ((mIter / mIters_per_warp) * Traits::AQPerBlock); - - constexpr index_t lane_base_offset = - (mIter % mIters_per_warp) * WarpGemm::kM; - - // Scale tensor offset along K - constexpr index_t src_reg_offset = reg_block_offset + kQScale; - // Directly index into thread buffer corresponding to - // desired row coefficient + // Each thread stores AQPerBlock scale values per M iteration. + constexpr index_t reg_block_offset = mIter * Traits::AQPerBlock; + constexpr index_t src_reg_offset = reg_block_offset + kQScale; auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset]; - constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8; - ; - constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows; - constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane; - // Multiply by 4 because output is stored in tiles of 4 - // x CNLane - constexpr uint32_t row_base = - ((reg_offset_for_row_data / kTiledCMsPerWarp) * kTiledCMsPerWarp) + - ((reg_offset_for_row_data % kTiledCMsPerWarp) / WarpGemm::kCMLane); + // Divide M dimension of C Warp tile into groups of + // (WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane) + // m_base_offset_of_c_row indicates which group the current c_row belongs + // to. + constexpr index_t m_base_offset_of_c_row = + (c_row / WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane) * + (WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane); + + // M offset of each thread within its group (see comment above) + index_t m_base_offset_of_lane = + (get_lane_id() / WarpGemm::kN * + WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane); + + // M offset wrt. c_row in the subgroup of kCM1PerLane + constexpr index_t m_offset_of_c_row = + c_row & (WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane - 1); - // Lane index to source scale from uint32_t src_lane_idx = - lane_base_offset + row_base + (__lane_id() / WarpGemm::kN * kTileRows); + m_base_offset_of_c_row + m_base_offset_of_lane + m_offset_of_c_row; return exchange_quant_value_across_lanes(scale_reg, src_lane_idx); } @@ -439,12 +435,22 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase ALdsTile a_warp_tile_; BLdsTile b_warp_tile_; - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) { - load_int4_tile(a_warp_tile_, a_block_window); - load_int4_tile(b_warp_tile_, b_block_window); + // while ADatatype might not be the same as BDataType at the time of problem + // initialization, we can safely use BDataType here because when A would be int4 we will + // ensure A is converted to BDataType prior to loading + load_int4_tile( + a_warp_tile_, a_block_window); + load_int4_tile( + b_warp_tile_, b_block_window); } // C += A * B @@ -526,11 +532,16 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase MakeCBlockTile(); } - template + template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, - const BSmemBlockWindow& b_block_window) + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) { - block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window); + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index b7c0eb2198..d97145cbc3 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 245ce4fa89..f6cf4ce9be 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -414,7 +414,6 @@ struct QuantGemmKernel if constexpr(kQuantType == QuantType::AQuantGrouped) { - static_assert(std::is_same_v); if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) @@ -655,13 +654,24 @@ struct QuantGemmKernel } else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant) { - static_assert(std::is_same_v); - return make_naive_tensor_view( - aq_ptr, - make_tuple(kargs.M, kargs.QK_A), - make_tuple(kargs.stride_AQ, 1), - number{}, - number<1>{}); + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.M, kargs.QK_A), + make_tuple(kargs.stride_AQ, 1), + number{}, + number<1>{}); + } + else // Column major AQ + { + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.QK_A, kargs.M), // Swapped dimensions + make_tuple(kargs.stride_AQ, 1), // Same stride pattern + number{}, + number<1>{}); + } } else if constexpr(kQuantType == QuantType::RowColQuant) { @@ -786,8 +796,8 @@ struct QuantGemmKernel using QuantGroupSize = remove_cvref_t; return make_naive_tensor_view( bq_ptr, - make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, QuantGroupSize::kN)), - make_tuple(1, kargs.stride_BQ), + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), + make_tuple(kargs.stride_BQ, 1), number{}, number<1>{}); } @@ -946,14 +956,21 @@ struct QuantGemmKernel } else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant) { - static_assert(std::is_same_v); - using QuantGroupSize = remove_cvref_t; - constexpr auto block_m = TilePartitioner::MPerBlock; - constexpr auto block_k = TilePartitioner::KPerBlock; - return make_tile_window( - aq_pad_view, - make_tuple(number{}, number{}), - {i_m, 0}); + using QuantGroupSize = remove_cvref_t; + constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + constexpr auto block_m = TilePartitioner::MPerBlock; + if constexpr(std::is_same_v) + { + return make_tile_window(aq_pad_view, + make_tuple(number{}, number{}), + {i_m, 0}); + } + else // Column major AQ + { + return make_tile_window(aq_pad_view, + make_tuple(number{}, number{}), + {0, i_m}); + } } else if constexpr(kQuantType == QuantType::RowColQuant) { @@ -1030,9 +1047,9 @@ struct QuantGemmKernel using QuantGroupSize = remove_cvref_t; return make_tile_window( bq_pad_view, - make_tuple(number{}, - number{}), - {0, i_n / QuantGroupSize::kN}); + make_tuple(number{}, + number{}), + {i_n / QuantGroupSize::kN, 0}); } } else diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 32f1279e93..caa6aad363 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include "ck_tile/core/utility/literals.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp" diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp index 9bed22ba9f..e3ad883440 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -20,8 +20,6 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase; - using AQLayout = remove_cvref_t; - static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; @@ -36,8 +34,6 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase); - auto aq_copy_dram_window = make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(), aq_dram_block_window_tmp.get_window_lengths(), diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index ca8598a03f..f3c8b7a1a3 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -8,6 +8,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp" @@ -15,68 +16,9 @@ namespace ck_tile { -template -struct BaseAQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem -{ - CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) - { - if(num_loop % BaseGemmPipelineAgBgCrCompV3::PrefetchStages == 0) - { - return TailNumber::Even; - } - else - { - return TailNumber::Odd; - } - } - template - CK_TILE_HOST_DEVICE static auto - TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) - { - if(has_hot_loop) - { - if(tail_number == ck_tile::TailNumber::Odd) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Even) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Unsupported tail number for this operation !!!"); - } - } - else - { - - if(tail_number == ck_tile::TailNumber::Odd) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Even) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Unsupported tail number for this operation !!!"); - } - } - } -}; - +// ToDo: Change the Pipeline to actual memory pipeline. template -struct AQuantGemmPipelineAgBgCrMem : public BaseAQuantGemmPipelineAgBgCrMem +struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { using Base = BaseGemmPipelineAgBgCrMem; using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp index 20ebd8f3e7..9681156e1a 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -18,13 +18,11 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeAQ() { - using AQLayout = remove_cvref_t; using AQDataType = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlockAQ = KPerBlock / Problem::QuantGroupSize::kK; - static_assert(std::is_same_v); return GetABQGlobalVectorLoadSize(); } @@ -49,7 +47,6 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC WarpTile::at(I2), Problem::TransposeC>; - static_assert(std::is_same_v); if constexpr(PreshuffleQuant) { using TileEncodingPattern = tile_distribution_encoding_pattern_aq< @@ -68,6 +65,8 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC { if constexpr(Problem::TransposeC) { + static_assert(std::is_same_v, + "TransposeC currently only supports RowMajor layout"); using TileEncodingPatternTransposeC = tile_distribution_encoding_pattern_aq_transposed_c; + // !Problem::TransposeC + if constexpr(std::is_same_v) + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_aq; - return TileEncodingPattern::make_2d_static_tile_distribution(); + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + else + { + using TileEncodingPattern = + tile_distribution_encoding_pattern_aq; + return TileEncodingPattern::make_2d_static_tile_distribution_transposed(); + } } } } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 41052cb485..30b9d70eb8 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,74 +14,8 @@ namespace ck_tile { -// Compute optimized pipeline -// GlobalPrefetchStages: 2 -// LocalPreFillStages: 1 -// LocalPreFetchStages: 1 -// LocalSharedMemoryBuffer: 1 - -template -struct BaseAQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 -{ - template - CK_TILE_HOST_DEVICE static auto - TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) - { - if(has_hot_loop) - { - if(tail_number == ck_tile::TailNumber::Full) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Odd) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Even) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Unsupported tail number for this operation !!!"); - } - } - else - { - if(tail_number == ck_tile::TailNumber::Full) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Odd) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Even) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Unsupported tail number for this operation !!!"); - } - } - } -}; - template -struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV3 +struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { using Base = BaseGemmPipelineAgBgCrCompV3; using PipelineImplBase = GemmAQuantPipelineAgBgCrImplBase; @@ -143,6 +77,9 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = Problem::Scheduler; + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -227,6 +164,16 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV { using Base = PipelineImplBase; + template + CK_TILE_DEVICE static void LoadAndConvertATile(ABlockTile_& a_block_tile, + const ADramWindow& a_dram_window) + { + using DestDataType = typename ABlockTile_::DataType; + using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; + constexpr index_t UnaryOpSize = 8; + load_int4_tile(a_block_tile, a_dram_window); + } + template > && std::is_same_v; constexpr bool is_b_row_major = std::is_same_v; - static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)"); - static_assert(is_a_col_major ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) @@ -277,7 +223,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; using AQDramTileWindowStep = typename AQDramBlockWindowTmp::BottomTensorIndex; - auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + auto&& [a_lds_block, b_lds_block] = + Base::template GetABLdsTensorViews(p_smem); constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); @@ -294,8 +241,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution()); + // while ADatatype might not be the same as BDataType at the time of problem + // initialization, we can safely use BDataType here because when A would be int4 we will + // ensure A is converted to BDataType prior to loading using ABlockTile = - decltype(make_static_distributed_tensor(ABlockTileDistr{})); + decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); using AQBlockTile = @@ -317,23 +267,25 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV // only row_major for AQ const AQDramTileWindowStep aq_dram_tile_window_step = - PreshuffleQuant ? make_array(ck_tile::integer_least_multiple(m, MPerBlock) / - BlockGemm::WarpGemm::kM, - 0) - : make_array(0, KPerBlockAQ); + PreshuffleQuant + ? make_array(ck_tile::integer_least_multiple(m, MPerBlock) / + BlockGemm::WarpGemm::kM, + 0) + : (is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ)); // DRAM prefetch (global read 0) - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + LoadAndConvertATile(a_block_tile, a_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); Base::GlobalPrefetch( aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template make_shuffled_2d_static_tile_distribution()); + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); } @@ -342,10 +294,10 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template make_shuffled_2d_static_tile_distribution()); + Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } @@ -354,12 +306,14 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); } - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + LoadAndConvertATile(a_block_tile, a_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); __builtin_amdgcn_sched_barrier(0); @@ -370,9 +324,9 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV { block_sync_lds(); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -381,7 +335,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV { Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -393,7 +347,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); } - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + LoadAndConvertATile(a_block_tile, a_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], aq_copy_dram_window, @@ -406,7 +361,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); __builtin_amdgcn_sched_barrier(0); i += 1; @@ -429,9 +385,9 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV currIdx = (currIdx + 1) % 2; - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { - auto a_shuffle_tmp = make_static_distributed_tensor( + auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); @@ -440,7 +396,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV { Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -452,7 +408,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); } block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm( c_block_tile, aq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window); } @@ -471,7 +428,10 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + // Note: a_element_func takes BDataType (not ADataType) because A tiles are + // converted from ADataType (e.g., pk_int4_t) to BDataType (e.g., fp8) in + // LoadAndConvertATile before the element function is applied. + [](const BDataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp index bb6bd7fd1f..4cd343e640 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index 4f792e9de8..870326cb9d 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -71,8 +71,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC tile_distribution_encoding_pattern_bq; return TileEncodingPattern::make_2d_static_tile_distribution(); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 584a15571c..4883a30f57 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -20,68 +20,8 @@ namespace ck_tile { // LocalPreFetchStages: 1 // LocalSharedMemoryBuffer: 1 -template -struct BaseBQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 -{ - template - CK_TILE_HOST_DEVICE static auto - TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) - { - if(has_hot_loop) - { - if(tail_number == ck_tile::TailNumber::Full) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Odd) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Even) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Unsupported tail number for this operation !!!"); - } - } - else - { - if(tail_number == ck_tile::TailNumber::Full) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Odd) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_number == ck_tile::TailNumber::Even) - { - return run_func( - ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Unsupported tail number for this operation !!!"); - } - } - } -}; - template -struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV3 +struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { using Base = BaseGemmPipelineAgBgCrCompV3; using PipelineImplBase = GemmBQuantPipelineAgBgCrImplBase; @@ -318,8 +258,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV (PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) / BlockGemmShape::WarpTile::at(number<1>{}), 0) - : is_bq_col_major ? make_array(KPerBlockBQ, 0) - : make_array(0, KPerBlockBQ); + : is_bq_col_major ? make_array(0, KPerBlockBQ) + : make_array(KPerBlockBQ, 0); // DRAM prefetch (global read 0) Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); @@ -332,7 +272,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV if constexpr(is_a_col_major) { auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template make_shuffled_2d_static_tile_distribution()); + Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); } @@ -344,7 +284,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template make_shuffled_2d_static_tile_distribution()); + Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index 6cd8dc3e0f..b51dee752d 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -94,23 +94,43 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding // # of elements per thread constexpr index_t X = XPerTile; - constexpr index_t Y0 = 1; - constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1; - constexpr index_t Y2 = MWarps; - constexpr index_t Y3 = WarpGemm::kM; - static_assert(Y3 >= WarpGemm::kM, + constexpr index_t YR = 1; + constexpr index_t Y0 = MIterPerWarp ? MIterPerWarp : 1; + constexpr index_t Y1 = MWarps; + constexpr index_t Y2 = WarpGemm::kM; + static_assert(Y2 >= WarpGemm::kM, "Scales for all rows must be available within the warp."); - static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile, - "Y0, Y1, Y2, Y3 must cover the blocktile along Y."); + static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y."); return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 1>>, - tuple, sequence<0, 3>>, + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 2>>, sequence<1, 2>, - sequence<1, 0>>{}); + sequence<0, 0>>{}); } } + CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution_transposed() + { + + constexpr index_t Y0 = YPerTile; + constexpr index_t X0 = 1; + constexpr index_t X1 = MIterPerWarp ? MIterPerWarp : 1; + constexpr index_t X2 = MWarps; + constexpr index_t X3 = WarpGemm::kM; + + static_assert(X3 >= WarpGemm::kM, "Scales for all rows must be available within the warp."); + static_assert(X0 * X1 * X2 * X3 == XPerTile, + "X0, X1, X2, X3 must cover the blocktile along X."); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 2>>, + tuple, sequence<0, 3>>, + sequence<2, 1>, + sequence<1, 0>>{}); + } }; template struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern { @@ -231,39 +251,39 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding } else { - if constexpr(XPerQ < WarpGemm::kN) + if constexpr(YPerQ < WarpGemm::kN) { // Case 1: Fine-grained - multiple quantization scales within a single warp - constexpr index_t Y = YPerTile; // Full Y dimension of tile - constexpr index_t YR = 1; // No Y replication needed - constexpr index_t X0 = NIterPerWarp; // Iterations per warp in N-dim - constexpr index_t X1 = NWarps; // Number of warps in N-dim - constexpr index_t X2 = WarpGemm::kN / XPerQ; // Number of scales per warp - constexpr index_t XR = XPerQ; // Elements per quantization group + constexpr index_t X = XPerTile; // Full X dimension of tile + constexpr index_t XR = 1; // No Y replication needed + constexpr index_t Y0 = NIterPerWarp; // Iterations per warp in N-dim + constexpr index_t Y1 = NWarps; // Number of warps in N-dim + constexpr index_t Y2 = WarpGemm::kN / YPerQ; // Number of scales per warp + constexpr index_t YR = YPerQ; // Elements per quantization group - static_assert(X0 * X1 * X2 == XPerTile, - "X0, X1, X2 must cover the blocktile along X."); + static_assert(Y0 * Y1 * Y2 == YPerTile, + "Y0, Y1, Y2 must cover the blocktile along Y."); return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0, 2, 0>>, + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 1, 0>>, tuple, sequence<1, 2, 2>>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}); } - else if constexpr(XPerQ <= WarpGemm::kN * NWarps) + else if constexpr(YPerQ <= WarpGemm::kN * NWarps) { // Case 2: Medium-grained - one quantization scale per warp - constexpr auto XR = XPerQ / WarpGemm::kN; // Scale replication factor - constexpr auto X1 = NWarps / XR; // Warps per unique scale - constexpr auto X0 = XPerTile / X1; // Iterations to cover X dimension + constexpr auto YR = YPerQ / WarpGemm::kN; // Scale replication factor + constexpr auto Y1 = NWarps / YR; // Warps per unique scale + constexpr auto Y0 = YPerTile / Y1; // Iterations to cover X dimension return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<0>>, + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0>>, tuple, sequence<2>>, - sequence<2, 1>, + sequence<1, 2>, sequence<0, 0>>{}); } else // XPerQ > WarpGemm::kN * NWarps diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index 4637f7ba72..0005eab52f 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp index fd71add13c..28a06f8b3d 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp @@ -1,8 +1,9 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index b92c9ee1fd..59a5b0df4e 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -9,6 +9,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/host/concat.hpp" @@ -280,7 +281,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV } else { - move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); } // Prefill A0 auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); @@ -338,7 +339,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV } else { - move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); } // Prefill A(2i+1) @@ -390,7 +391,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV } else { - move_tile_window(bq_copy_dram_window, {KPerBlockBQ, 0}); + move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); } // Prefill A(2i+2) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp index 3a5b86382d..69a39f344b 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index 309860810c..e172e732fa 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 7942d5e6e3..6ef1d84a6e 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -643,8 +643,6 @@ struct GroupedConvolutionBackwardWeightKernel CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!"); return false; } - - // TODO: Should we also check that GemmM <= MPerBlock and GemmN <= NPerBlock? } return true; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 6d97f7b758..72ba17c5a5 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -28,7 +28,6 @@ namespace ck_tile { template struct GroupedConvFwdKernelArgs { - using ConvToGemmFwdTransformer = TransformConvFwdToGemm(); - group_stride_a = args.C_; - group_stride_b = args.K_ * args.C_ * + NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; + group_stride_a = args.C_ * NumGroupsToMerge; + group_stride_b = args.K_ * args.C_ * NumGroupsToMerge * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(), 1, std::multiplies()); - group_stride_c = args.K_; + group_stride_c = args.K_ * NumGroupsToMerge; // Initialize Split-N support fields for 1D convolution (NWGC layout) // Get the actual split N from transformer @@ -121,8 +120,20 @@ struct GroupedConvFwdKernelArgs input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0]; output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0]; - // Update GemmM to use split N (not original N) - GemmM = n_per_split * args.output_spatial_lengths_[0]; + GemmM = a_grid_desc_m_k.get_length(number<0>{}); + GemmN = b_grid_desc_n_k.get_length(number<0>{}); + GemmK = a_grid_desc_m_k.get_length(number<1>{}); + GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge); + + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK + << ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split + << ", number of N splits: " << n_splits + << ", input_batch_stride: " << input_batch_stride + << ", output_batch_stride: " << output_batch_stride + << ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl; + } } template < @@ -163,11 +174,6 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - // Note: GemmM will be set after Split-N calculation - GemmN = args.K_; - GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1]; - GemmBatch = args.G_; - in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; for(index_t d = 0; d < NumDTensor; d++) @@ -192,13 +198,14 @@ struct GroupedConvFwdKernelArgs c_grid_desc_m_n = transformer_.template MakeCDescriptor_M_N(); - group_stride_a = args.C_; - group_stride_b = args.K_ * args.C_ * + NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; + group_stride_a = args.C_ * NumGroupsToMerge; + group_stride_b = args.K_ * args.C_ * NumGroupsToMerge * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(), 1, std::multiplies()); - group_stride_c = args.K_; + group_stride_c = args.K_ * NumGroupsToMerge; // Initialize Split-N support fields for 2D convolution (NHWGC layout) // Get the actual split N from transformer @@ -213,8 +220,20 @@ struct GroupedConvFwdKernelArgs output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; - // Update GemmM to use split N (not original N) - GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; + GemmM = a_grid_desc_m_k.get_length(number<0>{}); + GemmN = b_grid_desc_n_k.get_length(number<0>{}); + GemmK = a_grid_desc_m_k.get_length(number<1>{}); + GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge); + + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK + << ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split + << ", number of N splits: " << n_splits + << ", input_batch_stride: " << input_batch_stride + << ", output_batch_stride: " << output_batch_stride + << ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl; + } } template < @@ -262,12 +281,6 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - // Note: GemmM will be set after Split-N calculation - GemmN = args.K_; - GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] * - args.filter_spatial_lengths_[2]; - GemmBatch = args.G_; - in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; for(index_t d = 0; d < NumDTensor; d++) @@ -292,13 +305,14 @@ struct GroupedConvFwdKernelArgs c_grid_desc_m_n = transformer_.template MakeCDescriptor_M_N(); - group_stride_a = args.C_; - group_stride_b = args.K_ * args.C_ * + NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; + group_stride_a = args.C_ * NumGroupsToMerge; + group_stride_b = args.K_ * args.C_ * NumGroupsToMerge * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(), 1, std::multiplies()); - group_stride_c = args.K_; + group_stride_c = args.K_ * NumGroupsToMerge; // Initialize Split-N support fields for 3D convolution (NDHWGC layout) // Get the actual split N from transformer @@ -313,11 +327,21 @@ struct GroupedConvFwdKernelArgs output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2]; - // Update GemmM to use split N (not original N) - GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] * - args.output_spatial_lengths_[2]; - } + GemmM = a_grid_desc_m_k.get_length(number<0>{}); + GemmN = b_grid_desc_n_k.get_length(number<0>{}); + GemmK = a_grid_desc_m_k.get_length(number<1>{}); + GemmBatch = integer_divide_ceil(args.G_, NumGroupsToMerge); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK + << ", GemmBatch: " << GemmBatch << ", N per split: " << n_per_split + << ", number of N splits: " << n_splits + << ", input_batch_stride: " << input_batch_stride + << ", output_batch_stride: " << output_batch_stride + << ", NumGroupsToMerge: " << NumGroupsToMerge << std::endl; + } + } using AGridDescMK = remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeADescriptor_M_K())>; @@ -343,6 +367,7 @@ struct GroupedConvFwdKernelArgs index_t GemmN; index_t GemmK; index_t GemmBatch; + index_t NumGroupsToMerge; const void* in_ptr; const void* wei_ptr; @@ -567,13 +592,25 @@ struct GroupedConvolutionForwardKernel [[nodiscard]] CK_TILE_HOST static const std::string GetName() { + constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off - return concat('_', "grouped_convolution_forward", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName()); + if (NumGroupsToMerge > 1) { + return concat('_', "grouped_convolution_forward", + gemm_prec_str(), + "gemm", + GemmPipeline::GetName(), + "epilogue", + EpiloguePipeline::GetName(), + "merge", + NumGroupsToMerge); + } else { + return concat('_', "grouped_convolution_forward", + gemm_prec_str(), + "gemm", + GemmPipeline::GetName(), + "epilogue", + EpiloguePipeline::GetName()); + } // clang-format on } @@ -742,6 +779,16 @@ struct GroupedConvolutionForwardKernel return false; } + if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1) + { + const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}]; + if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0) + { + CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!"); + return false; + } + } + return true; } diff --git a/include/ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp b/include/ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp index 4cbc5c506a..83eadd496a 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 27349a0978..71739c9083 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp index a00ea37979..deb4dcb3db 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp @@ -1,6 +1,5 @@ - +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" @@ -9,7 +8,7 @@ namespace ck_tile { template 1 - return make_naive_tensor_descriptor(make_tuple(N_, Wo_, K_), - make_tuple(NStride, WoStride, KStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor( + make_tuple(N_ * Wo_, K_), make_tuple(WoStride, KStride), number{}, I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_, Wo_, K_), + make_tuple(NStride, WoStride, KStride), + number{}, + I1); + } } template ::type = false> CK_TILE_HOST auto make_wei_grid_desc() const { // GKXC - return make_naive_tensor_descriptor( - make_tuple(K_, X_, C_), make_tuple(X_ * C_, C_, I1), number{}, I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor( + make_tuple(K_, C_), make_tuple(C_, I1), number{}, I1); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(K_, X_, C_), make_tuple(X_ * C_, C_, I1), number{}, I1); + } } template ::type = false> @@ -492,14 +507,22 @@ struct TransformConvBwdDataToGemm { // NWGC const index_t NStride = Wi_ * G_ * C_; - const index_t WiStride = G_ * C_; // GC? + const index_t WiStride = G_ * C_; constexpr auto CStride = I1; // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_), - make_tuple(NStride, WiStride, CStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor( + make_tuple(N_ * Wi_, C_), make_tuple(WiStride, CStride), number{}, I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_), + make_tuple(NStride, WiStride, CStride), + number{}, + I1); + } } template ::type = false> @@ -513,10 +536,20 @@ struct TransformConvBwdDataToGemm // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, K_), - make_tuple(NStride, HoStride, WoStride, KStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), + make_tuple(WoStride, KStride), + number{}, + I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, K_), + make_tuple(NStride, HoStride, WoStride, KStride), + number{}, + I1); + } } template ::type = false> @@ -529,20 +562,38 @@ struct TransformConvBwdDataToGemm constexpr auto CStride = I1; // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), - make_tuple(NStride, HiStride, WiStride, CStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N_ * Hi_ * Wi_, C_), + make_tuple(WiStride, CStride), + number{}, + I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStride, HiStride, WiStride, CStride), + number{}, + I1); + } } template ::type = false> CK_TILE_HOST auto make_wei_grid_desc() const { // GKYXC - return make_naive_tensor_descriptor(make_tuple(K_, Y_, X_, C_), - make_tuple(C_ * X_ * Y_, C_ * X_, C_, I1), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor( + make_tuple(K_, C_), make_tuple(C_, I1), number{}, I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(K_, Y_, X_, C_), + make_tuple(C_ * X_ * Y_, C_ * X_, C_, I1), + number{}, + I1); + } } template ::type = false> @@ -556,11 +607,21 @@ struct TransformConvBwdDataToGemm constexpr auto KStride = I1; // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor( - make_tuple(N_, Do_, Ho_, Wo_, K_), - make_tuple(NStride, DoStride, HoStride, WoStride, KStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), + make_tuple(WoStride, KStride), + number{}, + I1); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, K_), + make_tuple(NStride, DoStride, HoStride, WoStride, KStride), + number{}, + I1); + } } template ::type = false> @@ -613,103 +674,111 @@ struct TransformConvBwdDataToGemm const auto in_grid_desc = make_in_grid_desc(); const auto wei_grid_desc = make_wei_grid_desc(); - // A: output tensor comes in K_M - const auto out_n_wop_k_grid_desc = - transform_tensor_descriptor(out_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_tuple(out_grid_desc, wei_grid_desc, in_grid_desc); + } + else + { + // A: output tensor comes in K_M + const auto out_n_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_wop_k_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(XDot_, WTilde_), - make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - const auto out_n_xdotslice_wtildeslice_k_grid_desc = transform_tensor_descriptor( - out_n_xdot_wtilde_k_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_slice_transform(XDot_, I0, XDotSlice), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + const auto out_n_xdotslice_wtildeslice_k_grid_desc = transform_tensor_descriptor( + out_n_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); - const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( - out_n_xdotslice_wtildeslice_k_grid_desc, - make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), - make_merge_transform(make_tuple(N_, WTildeSlice))), - make_tuple(sequence<1, 3>{}, sequence<0, 2>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( + out_n_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), + make_merge_transform(make_tuple(N_, WTildeSlice))), + make_tuple(sequence<1, 3>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); - // B: weight tensor comes in K_N - const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_embed_transform(make_tuple(XDot_, XTilde_), - make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + // B: weight tensor comes in K_N + const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - const auto wei_k_xdotslice_c_grid_desc = transform_tensor_descriptor( - wei_k_xdot_xtilde_c_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_slice_transform(XDot_, I0, XDotSlice), - make_freeze_transform(IdxXTilde_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<>{}, sequence<2>{})); + const auto wei_k_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<>{}, sequence<2>{})); - const auto wei_gemmn_gemmkraw_grid_desc = - transform_tensor_descriptor(wei_k_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), - make_pass_through_transform(C_)), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( + wei_k_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // c: input - const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + // c: input + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(XTilde_, WTilde_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_freeze_transform(IdxXTilde_), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<>{}, sequence<1>{}, sequence<2>{})); + const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<>{}, sequence<1>{}, sequence<2>{})); - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N_, WTildeSlice)), - make_pass_through_transform(C_)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - return make_tuple(out_gemmm_gemmkraw_grid_desc, - wei_gemmn_gemmkraw_grid_desc, - in_gemmmraw_gemmnraw_grid_desc); + return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + in_gemmmraw_gemmnraw_grid_desc); + } } template ::type = false> @@ -735,39 +804,135 @@ struct TransformConvBwdDataToGemm const auto XDotSlice = integer_divide_ceil(X_ - IdxXTilde_, XTilde_); const auto out_grid_desc = make_out_grid_desc(); - const auto in_grid_desc = make_in_grid_desc(); const auto wei_grid_desc = make_wei_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); - // A: output tensor comes in K_M - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Ho_, I0, I0), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(YDot_, HTilde_), - make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, WTilde_), - make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_tuple(out_grid_desc, wei_grid_desc, in_grid_desc); + } + else + { + // A: output tensor comes in K_M + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, make_tuple(make_pass_through_transform(N_), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{})); + + const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))), + make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // B: weight tensor comes in K_N + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto wei_k_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<3>{}, + sequence<2>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<>{}, + sequence<>{}, + sequence<3>{})); + + const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 2, 0>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // c: input + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, @@ -775,111 +940,23 @@ struct TransformConvBwdDataToGemm sequence<4>{}, sequence<5>{}), make_tuple(sequence<0>{}, + sequence<>{}, sequence<1>{}, + sequence<>{}, sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{})); + sequence<3>{})); - const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), - make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))), - make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // B: weight tensor comes in K_N - const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_embed_transform(make_tuple(YDot_, YTilde_), - make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, XTilde_), - make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); - - const auto wei_k_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_freeze_transform(IdxYTilde_), - make_freeze_transform(IdxXTilde_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<3>{}, - sequence<2>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<>{}, - sequence<>{}, - sequence<3>{})); - - const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( - wei_k_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), - make_pass_through_transform(C_)), - make_tuple(sequence<1, 2, 0>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - // c: input - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); - - const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(YTilde_, HTilde_), - make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(XTilde_, WTilde_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); - - const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_freeze_transform(IdxYTilde_), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(IdxXTilde_), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<0>{}, - sequence<>{}, - sequence<1>{}, - sequence<>{}, - sequence<2>{}, - sequence<3>{})); - - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C_)), - make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tuple(out_gemmm_gemmkraw_grid_desc, - wei_gemmn_gemmkraw_grid_desc, - in_gemmmraw_gemmnraw_grid_desc); + return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + in_gemmmraw_gemmnraw_grid_desc); + } } template ::type = false> @@ -915,45 +992,174 @@ struct TransformConvBwdDataToGemm const auto in_grid_desc = make_in_grid_desc(); const auto wei_grid_desc = make_wei_grid_desc(); - // A: output tensor comes in K_M - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Do_, I0, I0), - make_pad_transform(Ho_, I0, I0), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(ZDot_, DTilde_), - make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)), - make_embed_transform(make_tuple(YDot_, HTilde_), - make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, WTilde_), - make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, - sequence<1, 2>{}, - sequence<3, 4>{}, - sequence<5, 6>{}, - sequence<7>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_tuple(out_grid_desc, wei_grid_desc, in_grid_desc); + } + else + { + // A: output tensor comes in K_M + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, make_tuple(make_pass_through_transform(N_), - make_slice_transform(ZDot_, I0, ZDotSlice), - make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pad_transform(Do_, I0, I0), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), make_pass_through_transform(K_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(ZDot_, DTilde_), + make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)), + make_embed_transform(make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}, + sequence<6>{}, + sequence<7>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}, + sequence<6>{}, + sequence<7>{})); + + const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))), + make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // B: weight tensor comes in K_N + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(ZDot_, ZTilde_), + make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)), + make_embed_transform(make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto wei_k_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxZTilde_), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<3>{}, + sequence<5>{}, + sequence<2>{}, + sequence<4>{}, + sequence<6>{}, + sequence<7>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<>{}, + sequence<>{}, + sequence<>{}, + sequence<4>{})); + + const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 2, 3, 0>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // c: input + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(ZTilde_, DTilde_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxZTilde_), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, @@ -963,138 +1169,26 @@ struct TransformConvBwdDataToGemm sequence<6>{}, sequence<7>{}), make_tuple(sequence<0>{}, + sequence<>{}, sequence<1>{}, + sequence<>{}, sequence<2>{}, + sequence<>{}, sequence<3>{}, - sequence<4>{}, - sequence<5>{}, - sequence<6>{}, - sequence<7>{})); + sequence<4>{})); - const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, - make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), - make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))), - make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // B: weight tensor comes in K_N - const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_embed_transform(make_tuple(ZDot_, ZTilde_), - make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)), - make_embed_transform(make_tuple(YDot_, YTilde_), - make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, XTilde_), - make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, - sequence<1, 2>{}, - sequence<3, 4>{}, - sequence<5, 6>{}, - sequence<7>{})); - - const auto wei_k_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_slice_transform(ZDot_, I0, ZDotSlice), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_freeze_transform(IdxZTilde_), - make_freeze_transform(IdxYTilde_), - make_freeze_transform(IdxXTilde_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<3>{}, - sequence<5>{}, - sequence<2>{}, - sequence<4>{}, - sequence<6>{}, - sequence<7>{}), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<>{}, - sequence<>{}, - sequence<>{}, - sequence<4>{})); - - const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( - wei_k_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), - make_pass_through_transform(C_)), - make_tuple(sequence<1, 2, 3, 0>{}, sequence<4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - // c: input - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Di_, InLeftPadD_, InRightPadD_), - make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); - - const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(ZTilde_, DTilde_), - make_tuple(ConvDilationD_, ConvStrideD_)), - make_embed_transform(make_tuple(YTilde_, HTilde_), - make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(XTilde_, WTilde_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, - sequence<1, 2>{}, - sequence<3, 4>{}, - sequence<5, 6>{}, - sequence<7>{})); - - const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_freeze_transform(IdxZTilde_), - make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), - make_freeze_transform(IdxYTilde_), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(IdxXTilde_), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}, - sequence<6>{}, - sequence<7>{}), - make_tuple(sequence<0>{}, - sequence<>{}, - sequence<1>{}, - sequence<>{}, - sequence<2>{}, - sequence<>{}, - sequence<3>{}, - sequence<4>{})); - - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C_)), - make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tuple(out_gemmm_gemmkraw_grid_desc, - wei_gemmn_gemmkraw_grid_desc, - in_gemmmraw_gemmnraw_grid_desc); + return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + in_gemmmraw_gemmnraw_grid_desc); + } } IndexType G_, N_, original_N_; diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp index 04024a588d..0b4744a3a1 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp @@ -1,6 +1,5 @@ - +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp index 38857c13cb..8bea7f653c 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp @@ -1,6 +1,5 @@ - +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" @@ -471,10 +470,10 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeADescriptor_M_K() const { + IndexType NStrideTensorA_ = Wi_ * G_ * C_; IndexType WiStride_ = G_ * C_; - IndexType CStrideTensorA_ = 1; - IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; IndexType GStrideTensorA_ = C_; + IndexType CStrideTensorA_ = 1; if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) { @@ -702,11 +701,11 @@ struct TransformConvFwdToGemm CK_TILE_HOST auto MakeADescriptor_M_K() const { + IndexType NStrideTensorA_ = Hi_ * Wi_ * G_ * C_; IndexType HiStride_ = Wi_ * G_ * C_; IndexType WiStride_ = G_ * C_; - IndexType CStrideTensorA_ = 1; - IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; IndexType GStrideTensorA_ = C_; + IndexType CStrideTensorA_ = 1; if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) { @@ -961,12 +960,12 @@ struct TransformConvFwdToGemm CK_TILE_HOST auto MakeADescriptor_M_K() const { + IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; IndexType DiStride_ = Hi_ * Wi_ * G_ * C_; IndexType HiStride_ = Wi_ * G_ * C_; IndexType WiStride_ = G_ * C_; - IndexType CStrideTensorA_ = 1; - IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; IndexType GStrideTensorA_ = C_; + IndexType CStrideTensorA_ = 1; if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) { @@ -1290,9 +1289,9 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeBDescriptor_N_K() const { - IndexType CStrideTensorB_ = 1; - IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_; IndexType GStrideTensorB_ = K_ * Z_ * Y_ * X_ * C_; + IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_; + IndexType CStrideTensorB_ = 1; if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) { @@ -1357,10 +1356,10 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeCDescriptor_M_N() const { + IndexType NStrideTensorC_ = Wo_ * G_ * K_; IndexType WoStride_ = G_ * K_; - IndexType KStrideTensorC_ = 1; - IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; IndexType GStrideTensorC_ = K_; + IndexType KStrideTensorC_ = 1; const IndexType NDoHoWo = N_ * Wo_; if constexpr(NumGroupsToMerge == 1) @@ -1418,11 +1417,11 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeCDescriptor_M_N() const { + IndexType NStrideTensorC_ = Ho_ * Wo_ * G_ * K_; IndexType HoStride_ = Wo_ * G_ * K_; IndexType WoStride_ = G_ * K_; - IndexType KStrideTensorC_ = 1; - IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; IndexType GStrideTensorC_ = K_; + IndexType KStrideTensorC_ = 1; const IndexType NDoHoWo = N_ * Ho_ * Wo_; if constexpr(NumGroupsToMerge == 1) @@ -1483,12 +1482,12 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeCDescriptor_M_N() const { + IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; IndexType DoStride_ = Ho_ * Wo_ * G_ * K_; IndexType HoStride_ = Wo_ * G_ * K_; IndexType WoStride_ = G_ * K_; - IndexType KStrideTensorC_ = 1; - IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; IndexType GStrideTensorC_ = K_; + IndexType KStrideTensorC_ = 1; const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_; if constexpr(NumGroupsToMerge == 1) diff --git a/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp b/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp index bc20057e7a..2fc2779edc 100644 --- a/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp +++ b/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp b/include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp index 8d50ffde6d..21f1045a5d 100644 --- a/include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp +++ b/include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp b/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp index 05490ac3ed..544c2b1747 100644 --- a/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp +++ b/include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index 0181a3291f..ea959a5cef 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp index 37f87b4fe0..0d4a8b8179 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp index 788d507bf5..f83e42f3ac 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp index 7fae9dc435..f895946ee3 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp index 422950b143..5fcae5e4c4 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp index 189ca9fe80..65982939ca 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/moe_flatmm.hpp b/include/ck_tile/ops/moe_flatmm.hpp index 484e3ca11d..40441d70f3 100644 --- a/include/ck_tile/ops/moe_flatmm.hpp +++ b/include/ck_tile/ops/moe_flatmm.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp index 88da6be86e..07d97ec4ff 100644 --- a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp index 53f5bfc6ff..44f9cfb52f 100644 --- a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp +++ b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp b/include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp index 52b253e5f7..ef83c07e44 100644 --- a/include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp +++ b/include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp b/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp index 3578e3b375..286e918296 100644 --- a/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp +++ b/include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp b/include/ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp index 17f18acb5e..a3abd13170 100644 --- a/include/ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp +++ b/include/ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/pooling/kernel/pool_kernel.hpp b/include/ck_tile/ops/pooling/kernel/pool_kernel.hpp index b91fe514e8..91be63b803 100644 --- a/include/ck_tile/ops/pooling/kernel/pool_kernel.hpp +++ b/include/ck_tile/ops/pooling/kernel/pool_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -436,12 +436,14 @@ struct PoolKernel // Main reduction loop - with index tracking for(int k_tile = amd_wave_read_first_lane(0); k_tile < num_k_tiles; ++k_tile) { - const auto x_tile = load_tile(x_window); + const auto x_tile = load_tile(x_window); + const auto& in_tensor_padded_ref = + in_tensor_padded; // structured bindings cannot be captured prior to cpp20 auto index_calculator = [&](const auto& x_indices) { // Get global coordinates in the 2D matrix space (M, N) const auto global_M = x_indices.at(number<0>{}) + iM; const auto global_N = (k_tile * S::Block_N) + x_indices.at(number<1>{}); - return in_tensor_padded.get_tensor_descriptor().calculate_offset( + return in_tensor_padded_ref.get_tensor_descriptor().calculate_offset( make_tuple(global_M, global_N)); }; diff --git a/include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp b/include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp index e08cc42e58..a6a53c970a 100644 --- a/include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp +++ b/include/ck_tile/ops/pooling/pipeline/pool_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/pooling/pipeline/pool_problem.hpp b/include/ck_tile/ops/pooling/pipeline/pool_problem.hpp index 53071b1772..85e9bf8962 100644 --- a/include/ck_tile/ops/pooling/pipeline/pool_problem.hpp +++ b/include/ck_tile/ops/pooling/pipeline/pool_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/pooling/pipeline/pool_shape.hpp b/include/ck_tile/ops/pooling/pipeline/pool_shape.hpp index 5879fe593e..a53365f927 100644 --- a/include/ck_tile/ops/pooling/pipeline/pool_shape.hpp +++ b/include/ck_tile/ops/pooling/pipeline/pool_shape.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 2fd8a48eee..5517a3fd6b 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index c666608bfd..cbf4afefb2 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp index 33cc660541..d778f6db57 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp index 83a22aaded..1503b2b18b 100644 --- a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp +++ b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp b/include/ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp index 273a764f01..8692f0b678 100644 --- a/include/ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp +++ b/include/ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp b/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp index 1570b44271..1298bff274 100644 --- a/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp +++ b/include/ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp b/include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp index 0499fe370b..267db73b24 100644 --- a/include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp +++ b/include/ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp index 32586a6343..aa8081eaf5 100644 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp index 356a2e12ca..e056a74ca5 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp index b05197b653..de27b15952 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp index 39d7c65d3e..1c1ce1c9ec 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp index 773df4f0f4..02dc234077 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp index ca3cdc37c4..e22343ee55 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp index b91f17ffdd..fe1b2f6d60 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp index f6c7c0753a..8d3f200471 100644 --- a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp index e0ea9692c5..7db8225fb0 100644 --- a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp index 4945b46071..11b7e4b5fa 100644 --- a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp index f45afe3d2a..bbddb8e2c7 100644 --- a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp index 1669fdd36d..b6fb4aebe8 100644 --- a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp +++ b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp index 8b0a7274ed..3d8ae12364 100644 --- a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp b/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp index 607ec7eb53..abb95934ff 100644 --- a/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp +++ b/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp b/include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp index 82b9a5a486..426c55b60f 100644 --- a/include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp +++ b/include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp b/include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp index 164685f980..ee1e19dee8 100644 --- a/include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp +++ b/include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp b/include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp index d47188d862..eccd7488a1 100644 --- a/include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp +++ b/include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp index 019e940a33..1b34892bb9 100644 --- a/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp +++ b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp index 677263229b..26c3902aaa 100644 --- a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp +++ b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp index a6e886bd39..3d0bccab22 100644 --- a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp +++ b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp index 1dc7e9335e..3bb3ac2897 100644 --- a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp +++ b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ref/naive_attention.hpp b/include/ck_tile/ref/naive_attention.hpp index 50e963bd72..fd7a4b31cb 100644 --- a/include/ck_tile/ref/naive_attention.hpp +++ b/include/ck_tile/ref/naive_attention.hpp @@ -1,5 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py index 2ff707e9d3..affa6d987b 100644 --- a/include/ck_tile/remod.py +++ b/include/ck_tile/remod.py @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + from datetime import datetime import pathlib from pathlib import Path diff --git a/include/ck_tile/utility/json_dump.hpp b/include/ck_tile/utility/json_dump.hpp index ed6373ae66..b5bab28cac 100644 --- a/include/ck_tile/utility/json_dump.hpp +++ b/include/ck_tile/utility/json_dump.hpp @@ -1,4 +1,4 @@ -// Copyright © Advanced Micro Devices, Inc. or its affiliates. +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #ifdef CK_ENABLE_JSON_DUMP diff --git a/library/CMakeLists.txt b/library/CMakeLists.txt index 90873fdd14..083d2e4b1e 100644 --- a/library/CMakeLists.txt +++ b/library/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_subdirectory(src/tensor_operation_instance/gpu) add_subdirectory(src/utility) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp index bf8536d268..4d9c09f597 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp @@ -15,6 +15,142 @@ namespace tensor_operation { namespace device { namespace instance { +#if defined(CK_USE_WMMA) +#if defined(CK_ENABLE_FP16) +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances); +#endif // CK_ENABLE_FP16 +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(__gfx12__) +void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); +#endif +#if defined(CK_ENABLE_BF16) +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>>& instances); +#endif // CK_ENABLE_BF16 +#endif // CK_USE_WMMA + #if defined(CK_USE_XDL) #if defined(CK_ENABLE_FP16) void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( @@ -409,6 +545,81 @@ struct DeviceOperationInstanceFactory> op_ptrs; +#if defined(CK_USE_WMMA) +#if defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } +#endif // CK_ENABLE_FP16 +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(__gfx12__) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances(op_ptrs); + } + } +#endif +#if defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs); + } + } +#endif // CK_ENABLE_BF16 +#endif // CK_USE_WMMA + #if defined(CK_USE_XDL) #if defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp new file mode 100644 index 0000000000..6d5da9208b --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/loop_scheduler.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AccDataType = F32; +using DsDataType = Empty_Tuple; + +using DsLayout = Empty_Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto PipelineV1 = BlockGemmPipelineVersion::v1; +static constexpr auto PipelineV3 = BlockGemmPipelineVersion::v3; +static constexpr auto IntrawaveScheduler = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto InterwaveScheduler = BlockGemmPipelineScheduler::Interwave; +static constexpr auto GemmMNKPadding = device::GemmSpecialization::MNKPadding; +static constexpr auto GemmDefault = device::GemmSpecialization::Default; + +// Instances for 2 byte datatypes in CRR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_km_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang`-format on + >; + +// Instances for 2 byte datatypes in CCR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_km_nk_mn_instances = std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// Instances for 2 byte datatypes in RRR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// Instances for 2 byte datatypes in RCR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// Helper function to add a list of layout instances with specific A/B/E datatypes for all supported +// padding/scheduler/pipeline version combinations +template + typename LayoutInstances, + typename ADataType, // NOTE: type parameters as last so that they can be inferred from the + typename BDataType, // vector argument + typename EDataType> +void add_device_grouped_gemm_wmma_universal_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + LayoutInstances{}); + add_device_operation_instances(instances, + LayoutInstances{}); + add_device_operation_instances(instances, + LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); +} + +// Helper function to add a list of layout instances for instances with matching A/B/E data types +// for all supported padding/scheduler/pipeline version combinations +template + typename LayoutInstances> +void add_device_grouped_gemm_wmma_universal_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 6f171191ca..eeaf269394 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + function(add_instance_library INSTANCE_NAME) message(DEBUG "adding instance ${INSTANCE_NAME}") set(result 1) @@ -54,7 +57,7 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() # Do not build XDL instances if gfx9 targets are not on the target list - if(NOT INST_TARGETS MATCHES "gfx9" AND NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "_xdl") + if(((NOT INST_TARGETS MATCHES "gfx9" AND NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12") OR FORCE_DISABLE_XDL) AND source_name MATCHES "_xdl") message(DEBUG "removing xdl instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -64,7 +67,7 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() # Do not build WMMA instances if gfx11 targets are not on the target list - if(NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma") + if(((NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12") OR FORCE_DISABLE_WMMA) AND source_name MATCHES "_wmma") message(DEBUG "removing wmma instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -79,23 +82,27 @@ function(add_instance_library INSTANCE_NAME) message(DEBUG "removing gemm_multiply_multiply_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "_f8_") + if(NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12" AND source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "_f8_") message(DEBUG "removing gemm_universal_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() + if(NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12" AND source_name MATCHES "gemm_blockscale" AND source_name MATCHES "_f8_") + message(DEBUG "removing gemm_blockscale_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() endif() # Do not build WMMA gemm_universal_f8 for any targets except gfx12+ - if(NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "_f8_") - message(DEBUG "removing gemm_universal_f8 instance ${source} ") + if((NOT INST_TARGETS MATCHES "gfx12" OR FORCE_DISABLE_WMMA) AND source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "_f8_") + message(DEBUG "removing gemm_wmma_universal_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() # Do not build gemm_universal_preshuffle_f8 for any targets except gfx94, gfx95 and gfx12 - if(NOT (INST_TARGETS MATCHES "gfx942" OR INST_TARGETS MATCHES "gfx950" OR INST_TARGETS MATCHES "gfx12") AND (source_name MATCHES "gemm_universal_preshuffle" OR source_name MATCHES "gemm_xdl_universal_preshuffle") AND (source_name MATCHES "_f8_f8_f16" OR source_name MATCHES "_f8_f8_bf16")) + if(NOT (INST_TARGETS MATCHES "gfx94|gfx95|gfx12") AND (source_name MATCHES "gemm_universal_preshuffle") AND source_name MATCHES "_f8_" ) message(DEBUG "removing gemm_universal_preshuffle_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() # Only build tf32 instances for gfx942 & gfx950 - if(NOT (INST_TARGETS MATCHES "gfx942" OR INST_TARGETS MATCHES "gfx950") AND source_name MATCHES "_tf32_") + if(NOT (INST_TARGETS MATCHES "gfx942|gfx950") AND source_name MATCHES "_tf32_") message(DEBUG "removing tf32 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -115,39 +122,21 @@ function(add_instance_library INSTANCE_NAME) elseif(source_name MATCHES "_wmma") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(source_name MATCHES "mha") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() if(source_name MATCHES "_mx") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() #only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950 and gfx1200/gfx1201 if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) - if(source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic) - endif() - if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic) - endif() - if(source_name MATCHES "gemm_universal_preshuffle" AND source_name MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic) - endif() - if(source_name MATCHES "gemm_xdl_universal_preshuffle" AND source_name MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic) + if(source_name MATCHES "gemm_xdl_universal|gemm_multiply_multiply|gemm_universal_preshuffle|gemm_blockscale" AND source_name MATCHES "_f8_") + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx10-3-generic gfx11-generic) endif() else() - if(source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic) - endif() - if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic) - endif() - if(source_name MATCHES "gemm_universal_preshuffle" AND source_name MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic) - endif() - if(source_name MATCHES "gemm_xdl_universal_preshuffle" AND source_name MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic) + if(source_name MATCHES "gemm_xdl_universal|gemm_multiply_multiply|gemm_universal_preshuffle|gemm_blockscale" AND source_name MATCHES "_f8_") + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx10-3-generic gfx11-generic) endif() endif() if(source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "f8") @@ -271,7 +260,7 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found only dl instances, but DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12")) + if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12" OR FORCE_DISABLE_XDL)) message(DEBUG "Found only xdl instances, but gfx9 is not on the targets list. Skipping.") set(add_inst 0) endif() @@ -279,7 +268,7 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found only MX instances, but gfx950 is not on the targets list. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12")) + if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (((NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12")) OR FORCE_DISABLE_WMMA)) message(DEBUG "Found only wmma instances, but gfx11 is not on the targets list. Skipping.") set(add_inst 0) endif() @@ -287,7 +276,7 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12")) + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND ((NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12") OR (FORCE_DISABLE_XDL AND FORCE_DISABLE_WMMA))) message(DEBUG "Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.") set(add_inst 0) endif() @@ -302,12 +291,8 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found gemm_multiply_multiply instances, but gfx94/gfx95/gfx11/gfx12 not on the target list. Skipping. ${cmake_instance}") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "gemm_universal_preshuffle" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)) - message(DEBUG "Found gemm_universal_preshuffle_f8 instances, but gfx94/gfx95 not on the target list. Skipping.") - set(add_inst 0) - endif() - if(("${cmake_instance}" MATCHES "gemm_xdl_universal_preshuffle" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)) - message(DEBUG "Found gemm_xdl_universal_preshuffle_f8_f8_bf16 instances, but gfx94/gfx95 not on the target list. Skipping.") + if(("${cmake_instance}" MATCHES "gemm_universal_preshuffle|gemm_blockscale" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)) + message(DEBUG "Found gemm_f8 instances, but gfx94/gfx95 not on the target list. Skipping.") set(add_inst 0) endif() if ("${cmake_instance}" MATCHES "gemm_bilinear") @@ -330,20 +315,22 @@ FOREACH(subdir_path ${dir_list}) if((add_inst EQUAL 1)) get_filename_component(target_dir ${subdir_path} NAME) add_subdirectory(${target_dir}) - if("${cmake_instance}" MATCHES "gemm") - list(APPEND CK_DEVICE_GEMM_INSTANCES $) - elseif("${cmake_instance}" MATCHES "conv") - list(APPEND CK_DEVICE_CONV_INSTANCES $) - elseif("${cmake_instance}" MATCHES "mha") - list(APPEND CK_DEVICE_MHA_INSTANCES $) - elseif("${cmake_instance}" MATCHES "contr") - list(APPEND CK_DEVICE_CONTRACTION_INSTANCES $) - elseif("${cmake_instance}" MATCHES "reduce") - list(APPEND CK_DEVICE_REDUCTION_INSTANCES $) - else() - list(APPEND CK_DEVICE_OTHER_INSTANCES $) - endif() - message(DEBUG "add_instance_directory ${subdir_path}") + if (TARGET device_${target_dir}_instance) + if("${cmake_instance}" MATCHES "gemm") + list(APPEND CK_DEVICE_GEMM_INSTANCES $) + elseif("${cmake_instance}" MATCHES "conv") + list(APPEND CK_DEVICE_CONV_INSTANCES $) + elseif("${cmake_instance}" MATCHES "mha") + list(APPEND CK_DEVICE_MHA_INSTANCES $) + elseif("${cmake_instance}" MATCHES "contr") + list(APPEND CK_DEVICE_CONTRACTION_INSTANCES $) + elseif("${cmake_instance}" MATCHES "reduce") + list(APPEND CK_DEVICE_REDUCTION_INSTANCES $) + else() + list(APPEND CK_DEVICE_OTHER_INSTANCES $) + endif() + message(DEBUG "add_instance_directory ${subdir_path}") + endif() else() message(DEBUG "skip_instance_directory ${subdir_path}") endif() diff --git a/library/src/tensor_operation_instance/gpu/avg_pool2d_bwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/avg_pool2d_bwd/CMakeLists.txt index ad69023465..76d1afe84f 100644 --- a/library/src/tensor_operation_instance/gpu/avg_pool2d_bwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/avg_pool2d_bwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(DEVICE_AVGPOOL_2D_BWD_INSTANCES) list(APPEND DEVICE_AVGPOOL_2D_BWD_INSTANCES device_avg_pool2d_bwd_nhwc_bf16_instance.cpp device_avg_pool2d_bwd_nhwc_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/CMakeLists.txt index 084714b707..352636838f 100644 --- a/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(DEVICE_AVGPOOL_BWD_INSTANCES) list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt index 519d549a3d..e2eaf8382e 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS set(BATCHED_GEMM_INSTANCES) list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt index 5c8470f7cb..5d830bb2fe 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_batched_gemm_add_relu_gemm_add_instance device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/CMakeLists.txt index 77295ed151..940d4a2f15 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_b_scale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS set(BATCHED_GEMM_B_SCALE_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt index 8082a8c275..a4f66fdd4d 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_batched_gemm_bias_permute_instance device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt index 527c40fcd9..6aa0aeed12 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_batched_gemm_gemm_instance device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt index b874bc50ee..2023c0b208 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY DL_KERNELS set(BATCHED_GEMM_MULTID_INSTANCES) list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt index 51bbdf1d7c..a098a0a7e5 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_batched_gemm_reduce_instance device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt index e43eb07fb6..135dd9c484 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_batched_gemm_softmax_gemm_instance device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt index f1fb0646e4..29dcce4bde 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES) list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES diff --git a/library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt index 19a3cc8cd1..d596217722 100644 --- a/library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_instance_library(device_batchnorm_instance device_batchnorm_forward_f16_instance.cpp device_batchnorm_forward_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/column_to_image/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/column_to_image/CMakeLists.txt index 50855babb5..0e97f31712 100644 --- a/library/src/tensor_operation_instance/gpu/column_to_image/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/column_to_image/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_instance_library(device_column_to_image_instance device_column_to_image_gnwc_1d_instance.cpp device_column_to_image_gnhwc_2d_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt index 70e4bbfe57..9850882c55 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(DEVICE_CONTRACTION_BILINEAR_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt index dd36f88c43..a45bea6460 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(DEVICE_CONTRACTION_SCALE_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt index 796a9b2402..753cbb8735 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_conv1d_bwd_data_instance device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt index 2da5155117..0fa115ed09 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_DL_KERNELS set(CONV2D_BWD_DATA_INSTANCES) list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt index 04b313d075..70b5675859 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(DEVICE_CONV2D_FWD_INSTANCES) list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt index 4304d8996c..e8662f4465 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_conv2d_fwd_bias_relu_instance device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt index 40a6b1ff09..2e562fb940 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_conv2d_fwd_bias_relu_add_instance device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt index ec4a8a2864..1524eeb93b 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_conv3d_bwd_data_instance device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/elementwise/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/elementwise/CMakeLists.txt index 47516b4162..07c200139b 100644 --- a/library/src/tensor_operation_instance/gpu/elementwise/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/elementwise/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_instance_library(device_elementwise_instance device_normalize_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt index 0c7cc2cd31..4e5fa8048c 100644 --- a/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_instance_library(device_elementwise_normalization_instance device_elementwise_normalization_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index b8ecb4557e..668bce5836 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(GEMM_INSTANCES) list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt index 4f3c2f1ff5..a315db8bdd 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GEMM_AB_SCALE_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt index 478e9a8ab8..1f7e4b6f4c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_instance device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt index ab8023d1ba..eb3d1fa28e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_add_fastgelu_instance device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt index 46f0c3b9c6..912757b82b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_fastgelu_instance device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt index 2e6bdca234..1b25b3891d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_multiply_instance device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt index 1bdf611907..024532969b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_relu_instance device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt index 87b414faf7..6a9bfa8873 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_relu_add_layernorm_instance device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt index 565096dd61..d8297b7307 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_add_silu_instance device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt index 34f51f5f58..a207ede8c7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS set(GEMM_B_SCALE_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt index f29943d93b..a82e95d8d1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_gemm_bias_add_reduce_instance device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt index 39e83495d4..5dfa827991 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_bilinear_instance device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt index 0ffe5f95b2..b37a22d895 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt @@ -1,23 +1,28 @@ -# ONLY XDL_KERNELS -set(GEMM_BLOCKSCALE_WP_INSTANCES) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT -list(APPEND GEMM_BLOCKSCALE_WP_INSTANCES - device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp - device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp - device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp - device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp - ) -check_cxx_compiler_flag("-mllvm --misched-bottomup=1" HAS_MISCHED_BOTTOMUP) -check_cxx_compiler_flag("-mllvm --misched-prera-direction=bottomup" HAS_MISCHED_PRERA_DIRECTION) -if(HAS_MISCHED_BOTTOMUP) - set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-bottomup=1") - set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-bottomup=1") - set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-bottomup=1") - set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-bottomup=1") -elseif(HAS_MISCHED_PRERA_DIRECTION) - set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-prera-direction=bottomup") - set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-prera-direction=bottomup") - set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-prera-direction=bottomup") - set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-prera-direction=bottomup") +# ONLY XDL_KERNELS +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12") + set(GEMM_BLOCKSCALE_WP_INSTANCES) + + list(APPEND GEMM_BLOCKSCALE_WP_INSTANCES + device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp + device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp + device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp + device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp + ) + check_cxx_compiler_flag("-mllvm --misched-bottomup=1" HAS_MISCHED_BOTTOMUP) + check_cxx_compiler_flag("-mllvm --misched-prera-direction=bottomup" HAS_MISCHED_PRERA_DIRECTION) + if(HAS_MISCHED_BOTTOMUP) + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-bottomup=1") + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-bottomup=1") + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-bottomup=1") + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-bottomup=1") + elseif(HAS_MISCHED_PRERA_DIRECTION) + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-prera-direction=bottomup") + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-prera-direction=bottomup") + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-prera-direction=bottomup") + set_source_files_properties(device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1;-mllvm;--misched-prera-direction=bottomup") + endif() + add_instance_library(device_gemm_blockscale_wp_instance ${GEMM_BLOCKSCALE_WP_INSTANCES}) endif() -add_instance_library(device_gemm_blockscale_wp_instance ${GEMM_BLOCKSCALE_WP_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt index f3273fb8ed..37c7eb869e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_fastgelu_instance device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt index 5ce585ad81..24b4524063 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS set(GEMM_MULTI_ABD_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt index 3a27e43dd6..5ccb83be18 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_multiply_add_instance device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt index 0e52eac0bf..a68bea98a9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS set(GEMM_MULTIPLY_MULTIPLY_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt index 37233ac5b4..dd8c47e190 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt index 67805a86b1..68b84ae73c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY MX_KERNELS set(GEMM_MX_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt index 12d1026ea1..eb15849256 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_reduce_instance device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index dac86d7707..73bd6f0e37 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GEMM_SPLITK_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt index c854b16eeb..b7d7ed4628 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_gemm_streamk_instance # device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt index c8d56f46be..012e7273a7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS set(GEMM_UNIVERSAL_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_batched/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_batched/CMakeLists.txt index 1affa12bb3..43fd55b738 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_batched/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_batched/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GEMM_UNIVERSAL_BATCHED_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt index 5967258789..a022b746ac 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GEMM_UNIVERSAL_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/CMakeLists.txt index 142ace2e42..a657c2f322 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_reduce/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS set(GEMM_UNIVERSAL_REDUCE_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt index b7391d3446..b5ec87593d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GEMM_UNIVERSAL_STREAMK_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt index b057e0c8d2..4ef6722ab5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_DL_KERNELS set(GROUPED_CONV1D_BWD_WEIGHT xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt index ca4ea515bb..f4cba07b83 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_grouped_conv1d_fwd_instance xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index a686643fb5..9da738480b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS add_instance_library( device_grouped_conv2d_bwd_data_instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index f042e09e69..7e9a26c092 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_DL_KERNELS set(GROUPED_CONV2D_BWD_WEIGHT xdl/gnhwc_gkyxc_gnhwk/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 561073e6dc..704079f181 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # XDL_DL_WMMA_KERNELS set(GROUPED_CONV2D_FWD #xdl diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt index 2c79315799..cf1eaf0e12 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP) include(ShardInstantiation) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt index a06268573d..59312849c3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt index d4efa4aaa1..69823c3246 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_grouped_conv2d_fwd_clamp_instance xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_dynamic_op/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_dynamic_op/CMakeLists.txt index 92735fcaeb..dd5c69a7c2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_dynamic_op/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_dynamic_op/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV2D_FWD_DYNAMIC_OP xdl/device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt index f5b2f0d021..a2a792e745 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONV3D_BWD_DATA xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt index ffcd1ea1f7..69ea0c5ccf 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_BWD_DATA_BILINEAR xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt index 0d2432e30c..a3837c51b9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_BWD_DATA_BILINEAR xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index f9922b1f37..24c608f4ba 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # XDL_DL_WMMA_KERNELS set(GROUPED_CONV3D_BWD_WEIGHT xdl/gndhwc_gkzyxc_gndhwk/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt index b8621e73aa..f2187485a9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_bilinear/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_BWD_WEIGHT_BILINEAR xdl/device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt index 5277b04ed4..bce32f3bdb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight_scale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_BWD_WEIGHT_SCALE xdl/device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 849c99583a..7b5138ad9e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt index e05477e2d9..9796c561c0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP) include(ShardInstantiation) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt index 1adf0fbb43..11dac1620a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt index 6a4637d6e1..bd143bc0b9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt index 3ec28d78af..84b59b4849 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt index bbbe18bea6..6b284512be 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convinvscale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_CONVINVSCALE xdl/device_grouped_conv3d_fwd_xdl_convinvscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt index e20e3f49ed..90ddaacbca 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_CONVSCALE xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_add/CMakeLists.txt index 5f7062be97..e148b19839 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_add/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_CONVSCALE_ADD xdl/device_grouped_conv3d_fwd_xdl_convscale_add_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt index 8ba52adcb8..e79da12b1a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale_relu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_CONVSCALE_RELU xdl/device_grouped_conv3d_fwd_xdl_convscale_relu_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/CMakeLists.txt index 3b8ebbffd1..715ce6630a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_dynamic_op/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_DYNAMIC_OP xdl/device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt index 47fc2655bb..0622b121b5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt index 1076249447..21ec8ecc6e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_SCALEADD_AB xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt index 1be1db7d1d..3495d88637 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/CMakeLists.txt index 65d92e3c2c..f909fe0356 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_CONVND_EXP_BWD_WEIGHT # Explicit instances are common for 2d and 3d diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index 4a3e1a4ada..ba54c6ffb3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -1,4 +1,7 @@ -# ONLY XDL_KERNELS +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -33,4 +36,17 @@ add_instance_library(device_grouped_gemm_instance device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_bf16_bf16_mk_nk_mn_instance.cpp device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_nk_mn_instance.cpp + + device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp + + device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp + + device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..6f8b31e663 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + BF16, + Col, + Row, + device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..2839890dcf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + BF16, + Col, + Col, + device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..c41dbdfc7b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + BF16, + Row, + Row, + device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..55d1163900 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + BF16, + Row, + Col, + device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..ea7eb0d615 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + F16, + Col, + Row, + device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..816188c7ff --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + F16, + Col, + Col, + device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..6680002d47 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + + add_device_grouped_gemm_wmma_universal_instances< + F16, + Row, + Row, + device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..3e82899834 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + + add_device_grouped_gemm_wmma_universal_instances< + F16, + Row, + Col, + device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..e93e9dff4a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ADataType = F16; +using BDataType = F8; +using EDataType = F16; + +template +using device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + + add_device_grouped_gemm_wmma_universal_instances< + Row, + Row, + device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..e8f043d1f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ADataType = F8; +using BDataType = F16; +using EDataType = F16; + +template +using device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + + add_device_grouped_gemm_wmma_universal_instances< + Row, + Row, + device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt index 167dfa9a6f..950784eb46 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_grouped_gemm_bias_instance device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt index 8e9693e691..1997427462 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS add_instance_library(device_grouped_gemm_fastgelu_instance device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt index bc9c711d3a..e56df524cf 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_GEMM_FIXED_NK_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt index e38c82d396..9d9a0e691c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk_multi_abd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt index 0ba84c5cdc..76156968d6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_KERNELS set(GROUPED_GEMM_TILE_LOOP_INSTANCES) diff --git a/library/src/tensor_operation_instance/gpu/image_to_column/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/image_to_column/CMakeLists.txt index 9e52a8157f..f7851c3c20 100644 --- a/library/src/tensor_operation_instance/gpu/image_to_column/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/image_to_column/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_instance_library(device_image_to_column_instance device_image_to_column_gnwc_1d_instance.cpp device_image_to_column_gnhwc_2d_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt index 6925e800b2..fee531e0fa 100644 --- a/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(DEVICE_MAXPOOL_BWD_INSTANCES) list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp device_max_pool_bwd_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt index eba234cf3f..646c2cce1a 100644 --- a/library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(FMHA_CPP_FOLDER ${CMAKE_CURRENT_BINARY_DIR}) set(FMHA_SRC_FOLDER ${CMAKE_SOURCE_DIR}/example/ck_tile/01_fmha/) set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/) diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/CMakeLists.txt index 9f3dd9d94c..3ae305e802 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(DEVICE_NORMALIZATION_bwd_data_INSTANCES) list(APPEND DEVICE_NORMALIZATION_bwd_data_INSTANCES diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/CMakeLists.txt index 686fb5e665..c5cc9ee6c5 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(DEVICE_NORMALIZATION_BWD_GAMMA_BETA_INSTANCES) list(APPEND DEVICE_NORMALIZATION_BWD_GAMMA_BETA_INSTANCES diff --git a/library/src/tensor_operation_instance/gpu/normalization_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/normalization_fwd/CMakeLists.txt index ce4c80943e..f6934d3b2e 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/normalization_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(DEVICE_NORMALIZATION_FWD_INSTANCES) list(APPEND DEVICE_NORMALIZATION_FWD_INSTANCES diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt index 427bf54ca1..8db355fbae 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_instance_library(device_permute_scale_instance device_permute_scale_1d_fp16_instances.cpp device_permute_scale_2d_fp16_instances.cpp diff --git a/library/src/tensor_operation_instance/gpu/pool2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/pool2d_fwd/CMakeLists.txt index 7372a56d9f..ec30684b9b 100644 --- a/library/src/tensor_operation_instance/gpu/pool2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/pool2d_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(DEVICE_POOL2D_FWD_INSTANCES) list(APPEND DEVICE_POOL2D_FWD_INSTANCES device_avg_pool2d_fwd_nhwc_f16_instance.cpp device_max_pool2d_fwd_nhwc_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt index a40663bf75..a8bee5c4a7 100644 --- a/library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(DEVICE_POOL3D_FWD_INSTANCES) list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp device_max_pool3d_fwd_ndhwc_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt index a5b4fb5df4..f433fcf76d 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ONLY XDL_AND_DL_KERNELS set(CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp) set(CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp) diff --git a/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt index 31ae7226f4..d034d000ef 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_instance_library(device_reduce_instance device_reduce_instance_blockwise_f16_f16_f16_min.cpp device_reduce_instance_blockwise_f16_f16_f16_max.cpp diff --git a/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt index 6daaec738a..8826e208b1 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_instance_library(device_softmax_instance device_softmax_f16_f16_instance_rank3_reduce1.cpp device_softmax_f16_f16_instance_rank3_reduce2.cpp diff --git a/library/src/tensor_operation_instance/gpu/transpose/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/transpose/CMakeLists.txt index 69e85a9c3d..cfb61a4e3f 100644 --- a/library/src/tensor_operation_instance/gpu/transpose/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/transpose/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_instance_library(device_transpose_instance device_transpose_instances_3d.cpp ) diff --git a/library/src/utility/CMakeLists.txt b/library/src/utility/CMakeLists.txt index 28883efef0..f7bc1acf59 100644 --- a/library/src/utility/CMakeLists.txt +++ b/library/src/utility/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_library(utility device_memory.cpp host_tensor.cpp diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index bdd7125ac1..15ed3c8c67 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + include_directories(BEFORE ${CMAKE_CURRENT_LIST_DIR}/include ) diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 03a2ed3186..0ee0ee4c2e 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -42,10 +42,11 @@ bool profile_grouped_gemm_impl(int do_verification, const std::vector& StrideAs, const std::vector& StrideBs, const std::vector& StrideCs, - const std::vector& kbatches = {}, - int n_warmup = 1, - int n_iter = 10, - int instance_index = -1) + const std::vector& kbatches = {}, + int n_warmup = 1, + int n_iter = 10, + int instance_index = -1, + bool fail_if_no_supported_instance = false) { bool pass = true; // TODO: Fixme - we do not pass compute data type here but need it @@ -225,6 +226,7 @@ bool profile_grouped_gemm_impl(int do_verification, } } // profile device GEMM instances + int instances_supporting_all_batch_sizes = 0; for(auto& gemm_ptr : op_ptrs) { auto argument_ptr = @@ -268,6 +270,7 @@ bool profile_grouped_gemm_impl(int do_verification, kbatch_list = kbatches; } + bool all_batch_sizes_supported = true; for(std::size_t j = 0; j < kbatch_list.size(); j++) { auto kbatch_curr = kbatch_list[j]; @@ -367,10 +370,30 @@ bool profile_grouped_gemm_impl(int do_verification, } else { + all_batch_sizes_supported = false; std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" << std::endl; } } + + // If all batch sizes were supported by this instance, the instance can be marked as + // 'supported' for this problem + if(all_batch_sizes_supported) + { + ++instances_supporting_all_batch_sizes; + } + } + + // Warn if not a single instance was supported + if(instances_supporting_all_batch_sizes == 0) + { + std::cout << "Warning! No instance found that supported all of the batch sizes." + << std::endl; + + if(fail_if_no_supported_instance) + { + return false; + } } if(time_kernel) @@ -384,6 +407,7 @@ bool profile_grouped_gemm_impl(int do_verification, std::cout << "grouped_gemm_instance (" << instance_index << "/" << num_kernel << "): Passed" << std::endl; } + return pass; } diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index b9f82af29d..71f1637653 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ckProfiler set(CK_PROFILER_OP_FILTER "" CACHE STRING "Filter for the operators to be profiled. Default is to include all") set(CK_PROFILER_INSTANCE_FILTER "" CACHE STRING "Filter for the kernels instances to be profiled. Default is to be the same as the operator filter") @@ -42,7 +45,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") endif() endif() -if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_OPS profile_gemm_reduce.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp) @@ -56,7 +59,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND PROFILER_OPS profile_grouped_gemm_tile_loop.cpp) list(APPEND PROFILER_OPS profile_grouped_gemm_multiply_tile_loop.cpp) endif() - if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") + if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12") list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp) list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp) list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp) @@ -87,7 +90,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx(9[45]|1[12])") list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp) endif() -if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND PROFILER_OPS profile_gemm_universal.cpp) list(APPEND PROFILER_OPS profile_batched_gemm.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp) @@ -161,7 +164,7 @@ list(APPEND DEVICE_INSTANCES device_column_to_image_instance) list(APPEND DEVICE_INSTANCES device_transpose_instance) list(APPEND DEVICE_INSTANCES device_permute_scale_instance) -if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) list(APPEND DEVICE_INSTANCES device_contraction_bilinear_instance) list(APPEND DEVICE_INSTANCES device_contraction_scale_instance) @@ -181,11 +184,11 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND DEVICE_INSTANCES device_grouped_gemm_tile_loop_instance) endif() list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance) - if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") + if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12") list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance) list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance) endif() - if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") + if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx1[12]") list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance) list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance) endif() @@ -225,7 +228,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx(9[45]|1[12])") list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance) endif() -if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") +if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]") list(APPEND DEVICE_INSTANCES device_gemm_universal_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_instance) list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance) diff --git a/script/infra_helper/capture_build_trace.js b/script/infra_helper/capture_build_trace.js new file mode 100644 index 0000000000..e484a815cc --- /dev/null +++ b/script/infra_helper/capture_build_trace.js @@ -0,0 +1,53 @@ +const puppeteer = require('puppeteer'); + +(async () => { + try { + // Launch the browser + const browser = await puppeteer.launch({ + args: [ + '--no-sandbox', + '--headless', + '--disable-gpu', + '--window-size=1920x1080' + ]}); + const page = await browser.newPage(); + await page.setViewport({ width: 1920, height: 1080 }); + await page.goto('https://ui.perfetto.dev'); + // Wait for the home page to be visible + console.log('Waiting for page to load...'); + await page.waitForSelector('.pf-home-page', { visible: true, timeout: 30000 }); + // Locate and click the Open trace button + const elements = await page.$$('li'); + let element = null; + for (const el of elements) { + const text = await el.evaluate(node => node.textContent); + if (text && text.includes('Open trace file')) { + element = el; + break; + } + } + if (element) { + const [fileChooser] = await Promise.all([ + page.waitForFileChooser(), + element.click() + ]); + await fileChooser.accept(['/workspace/ck_build_trace.json']); + } else { + throw new Error('Element not found'); + } + console.log('Waiting for data to load...'); + // Wait for the timeline element to be visible + await page.waitForSelector('.pf-track', { timeout: 30000 }); + // Wait for the data to finish loading + await page.waitForFunction(() => { + return !document.body.textContent.includes('Loading...'); + }, { timeout: 30000 }); + console.log('Capturing screenshot...'); + await page.screenshot({path: '/workspace/perfetto_snapshot_build.png'}); + console.log('Done capturing screenshot...'); + await browser.close(); + } catch (err) { + console.error(err); + process.exit(1); + } +})(); \ No newline at end of file diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 84c2ea090b..f8498c6c03 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + include_directories(BEFORE ${PROJECT_SOURCE_DIR}/ ${PROJECT_SOURCE_DIR}/profiler/include @@ -117,7 +120,7 @@ function(add_test_executable TEST_NAME) elseif(source_name_list MATCHES "_wmma") list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(source_name_list MATCHES "_smfmac") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx908 gfx90a gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx908 gfx90a gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) @@ -209,9 +212,9 @@ function(add_gtest_executable TEST_NAME) elseif(source_name_list MATCHES "_wmma") list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(source_name_list MATCHES "_smfmac") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx908 gfx90a gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx908 gfx90a gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) elseif(source_name_list MATCHES "_mx") #only build mx example for gfx950 - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) diff --git a/test/batched_gemm/CMakeLists.txt b/test/batched_gemm/CMakeLists.txt index 4c325b2872..926fafcc97 100644 --- a/test/batched_gemm/CMakeLists.txt +++ b/test/batched_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_batched_gemm_xdl test_batched_gemm_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_batched_gemm_xdl PRIVATE utility device_batched_gemm_instance) diff --git a/test/batched_gemm_b_scale/CMakeLists.txt b/test/batched_gemm_b_scale/CMakeLists.txt index abc3d14ee1..6d8a7a5946 100644 --- a/test/batched_gemm_b_scale/CMakeLists.txt +++ b/test/batched_gemm_b_scale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_batched_gemm_b_scale_wmma test_batched_gemm_b_scale_wmma.cpp) if(result EQUAL 0) diff --git a/test/batched_gemm_gemm/CMakeLists.txt b/test/batched_gemm_gemm/CMakeLists.txt index 70d7420992..a12d5c3435 100644 --- a/test/batched_gemm_gemm/CMakeLists.txt +++ b/test/batched_gemm_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_batched_gemm_gemm_fp16_xdl test_batched_gemm_gemm_fp16_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_batched_gemm_gemm_fp16_xdl PRIVATE utility device_batched_gemm_gemm_instance) diff --git a/test/batched_gemm_multi_d/CMakeLists.txt b/test/batched_gemm_multi_d/CMakeLists.txt index d5e4c4fbe8..cf218b79e2 100644 --- a/test/batched_gemm_multi_d/CMakeLists.txt +++ b/test/batched_gemm_multi_d/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d_dl.cpp) if(result EQUAL 0) target_link_libraries(test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance) diff --git a/test/batched_gemm_reduce/CMakeLists.txt b/test/batched_gemm_reduce/CMakeLists.txt index c5868e4d7a..4348c4b536 100644 --- a/test/batched_gemm_reduce/CMakeLists.txt +++ b/test/batched_gemm_reduce/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance) diff --git a/test/batched_gemm_softmax_gemm/CMakeLists.txt b/test/batched_gemm_softmax_gemm/CMakeLists.txt index c042d7e000..d982ae92e6 100644 --- a/test/batched_gemm_softmax_gemm/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16_xdl.cpp) if(result EQUAL 0) add_custom_target(test_batched_gemm_softmax_gemm) diff --git a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt index 2e09073540..d759d13456 100644 --- a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(test_batched_gemm_softmax_gemm_permute) add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp) if(result EQUAL 0) diff --git a/test/batchnorm/CMakeLists.txt b/test/batchnorm/CMakeLists.txt index 2a528f9c37..3d1a7f777f 100644 --- a/test/batchnorm/CMakeLists.txt +++ b/test/batchnorm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_batchnorm_fwd_rank_4 batchnorm_fwd_rank_4.cpp) add_gtest_executable(test_batchnorm_bwd_rank_4 batchnorm_bwd_rank_4.cpp) add_gtest_executable(test_batchnorm_infer_rank_4 batchnorm_infer_rank_4.cpp) diff --git a/test/block_to_ctile_map/CMakeLists.txt b/test/block_to_ctile_map/CMakeLists.txt index 97dfbb2b55..8507115efa 100644 --- a/test/block_to_ctile_map/CMakeLists.txt +++ b/test/block_to_ctile_map/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_block_to_ctile_map test_block_to_ctile_map.cpp) \ No newline at end of file diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index d4cef34ce0..6378bb8e43 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_subdirectory(image_to_column) add_subdirectory(gemm) add_subdirectory(gemm_weight_preshuffle) @@ -27,6 +30,7 @@ add_subdirectory(add_rmsnorm2d_rdquant) # add_subdirectory(rmsnorm2d) add_subdirectory(gemm_block_scale) add_subdirectory(utility) +add_subdirectory(warp_gemm) add_subdirectory(reduce) add_subdirectory(core) add_subdirectory(epilogue) diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt b/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt index 64672e200b..2a05e4a87c 100644 --- a/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt +++ b/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + function(create_tile_add_rmsnorm2d_rdquant_fwd SUFFIX) set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "test_ck_tile_add_rmsnorm2d_rdquant_fwd_${SUFFIX}") message(DEBUG "adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}") diff --git a/test/ck_tile/atomic_add_op/CMakeLists.txt b/test/ck_tile/atomic_add_op/CMakeLists.txt index 5dfb4d9db3..15343cd99e 100644 --- a/test/ck_tile/atomic_add_op/CMakeLists.txt +++ b/test/ck_tile/atomic_add_op/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_atomic test_atomic.cpp) set(CTEST_OUTPUT_ON_FAILURE ON) diff --git a/test/ck_tile/batched_gemm/CMakeLists.txt b/test/ck_tile/batched_gemm/CMakeLists.txt index 9bcbc7352e..6f29225291 100644 --- a/test/ck_tile/batched_gemm/CMakeLists.txt +++ b/test/ck_tile/batched_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_batched_gemm test_batched_gemm.cpp) endif() diff --git a/test/ck_tile/batched_transpose/CMakeLists.txt b/test/ck_tile/batched_transpose/CMakeLists.txt index 32a22a508a..c5cd9a3a5d 100644 --- a/test/ck_tile/batched_transpose/CMakeLists.txt +++ b/test/ck_tile/batched_transpose/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9") add_gtest_executable(test_ck_tile_batched_transpose test_batched_transpose.cpp) set_property(TARGET test_ck_tile_batched_transpose PROPERTY CXX_STANDARD 20) diff --git a/test/ck_tile/container/CMakeLists.txt b/test/ck_tile/container/CMakeLists.txt index f13f0dbedf..19b4cc451e 100644 --- a/test/ck_tile/container/CMakeLists.txt +++ b/test/ck_tile/container/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_tuple_apply test_tuple_apply.cpp) if(result EQUAL 0) diff --git a/test/ck_tile/core/CMakeLists.txt b/test/ck_tile/core/CMakeLists.txt index a0479470dd..a46c2b1d41 100644 --- a/test/ck_tile/core/CMakeLists.txt +++ b/test/ck_tile/core/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_subdirectory(arch) diff --git a/test/ck_tile/core/arch/CMakeLists.txt b/test/ck_tile/core/arch/CMakeLists.txt index 9e7aa0e197..76b0ffa3b1 100644 --- a/test/ck_tile/core/arch/CMakeLists.txt +++ b/test/ck_tile/core/arch/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_subdirectory(mma) set(EXAMPLE_GEMM_COMPILE_OPTIONS) diff --git a/test/ck_tile/core/arch/mma/CMakeLists.txt b/test/ck_tile/core/arch/mma/CMakeLists.txt index 07eccdcd90..f5ecbf7f8b 100644 --- a/test/ck_tile/core/arch/mma/CMakeLists.txt +++ b/test/ck_tile/core/arch/mma/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # Currently ck_tile_gemm is only built on gfx94/gfx95 set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) diff --git a/test/ck_tile/data_type/CMakeLists.txt b/test/ck_tile/data_type/CMakeLists.txt index a5713ac55c..17df115b9d 100644 --- a/test/ck_tile/data_type/CMakeLists.txt +++ b/test/ck_tile/data_type/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp) endif() diff --git a/test/ck_tile/elementwise/CMakeLists.txt b/test/ck_tile/elementwise/CMakeLists.txt index 860a23a62a..0f9d2b695e 100644 --- a/test/ck_tile/elementwise/CMakeLists.txt +++ b/test/ck_tile/elementwise/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_elementwise_1d test_elementwise_1d.cpp) endif() diff --git a/test/ck_tile/epilogue/CMakeLists.txt b/test/ck_tile/epilogue/CMakeLists.txt index 4103b9d4db..2b3ffe33cc 100644 --- a/test/ck_tile/epilogue/CMakeLists.txt +++ b/test/ck_tile/epilogue/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_ck_tile_cshuffle_epilogue test_cshuffle_epilogue.cpp) diff --git a/test/ck_tile/fmha/CMakeLists.txt b/test/ck_tile/fmha/CMakeLists.txt index 52accaf812..e591d5066f 100644 --- a/test/ck_tile/fmha/CMakeLists.txt +++ b/test/ck_tile/fmha/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # Keep in sync with example/ck_tile/01_fmha/CMakeLists.txt if(NOT SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx12") return() diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 8365b9ff45..ee23ad2f63 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # Currently ck_tile_gemm is only built on gfx94/gfx95 set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 1c4a25c8bd..8309b14f0a 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(TEST_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND TEST_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) @@ -6,11 +9,37 @@ endif() list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") - # Typed Test Suite for GEMM Quantization - add_gtest_executable(test_tile_gemm_quant_typed - test_gemm_quant_typed.cpp + # Typed Test Suite for GEMM Quantization - split into multiple files to reduce compile time + + # AQuant tests + add_gtest_executable(test_tile_gemm_quant_aquant + test_gemm_quant_aquant.cpp ) - target_compile_options(test_tile_gemm_quant_typed PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_tile_gemm_quant_aquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + # BQuant tests (without PreshuffleB) + add_gtest_executable(test_tile_gemm_quant_bquant + test_gemm_quant_bquant.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + # BQuant tests (with PreshuffleB) + add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle + test_gemm_quant_bquant_preshuffle.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + # RowColQuant tests + add_gtest_executable(test_tile_gemm_quant_rowcol + test_gemm_quant_rowcol.cpp + ) + target_compile_options(test_tile_gemm_quant_rowcol PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + # TensorQuant tests + add_gtest_executable(test_tile_gemm_quant_tensor + test_gemm_quant_tensor.cpp + ) + target_compile_options(test_tile_gemm_quant_tensor PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile quant gemm tests for current target") endif() diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp new file mode 100644 index 0000000000..9ba0b9c804 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using BQuantGrouped = std::integral_constant; +using RowColQuant = std::integral_constant; +using TensorQuant = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; +using GroupSize64 = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests +// Tuple format: +// clang-format off +using AQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // RRR layout (RowMajor A, RowMajor B, RowMajor C with RowMajor AQ) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // CRR layout (ColumnMajor A, RowMajor B, RowMajor C with RowMajor AQ) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // CCR layout (ColumnMajor A, ColumnMajor B, RowMajor C with ColumnMajor AQ) - NEW layout support + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // RCR layout - with the Prefill BlockTile Config. + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // PreshuffleQuant = false && TransposeC = true (with RowMajor AQ) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // PreshuffleQuant = true && TransposeC = false (with RowMajor AQ - PreshuffleQuant only supports RowMajor) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // PreshuffleQuant = true && TransposeC = true (with RowMajor AQ - PreshuffleQuant only supports RowMajor) + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 08232f81be..38bd59b882 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -29,13 +29,14 @@ class TestCkTileGemmQuantBase : public ::testing::Test using ALayout = std::tuple_element_t<0, Tuple>; using BLayout = std::tuple_element_t<1, Tuple>; using CLayout = std::tuple_element_t<2, Tuple>; - using ADataType = std::tuple_element_t<3, Tuple>; - using BDataType = std::tuple_element_t<4, Tuple>; - using QDataType = std::tuple_element_t<5, Tuple>; - using CDataType = std::tuple_element_t<6, Tuple>; - static constexpr auto QuantType = std::tuple_element_t<7, Tuple>::value; - using GemmConfig = std::tuple_element_t<8, Tuple>; - using QuantGroupSize = std::tuple_element_t<9, Tuple>; + using AQLayout = std::tuple_element_t<3, Tuple>; + using ADataType = std::tuple_element_t<4, Tuple>; + using BDataType = std::tuple_element_t<5, Tuple>; + using QDataType = std::tuple_element_t<6, Tuple>; + using CDataType = std::tuple_element_t<7, Tuple>; + static constexpr auto QuantType = std::tuple_element_t<8, Tuple>::value; + using GemmConfig = std::tuple_element_t<9, Tuple>; + using QuantGroupSize = std::tuple_element_t<10, Tuple>; using AccDataType = float; // accumulate always in float // Get the quant-type specific data types from traits @@ -85,6 +86,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test using TilePartitioner = ck_tile::GemmTile1DPartitioner; + // BQLayout is always ColumnMajor for BQuant + using BQLayout = ck_tile::tensor_layout::gemm::ColumnMajor; + using CodegenGemmTraits = ck_tile::TileGemmQuantTraits +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; +using GroupSize64 = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests (without PreshuffleB) +// Tuple format: +// clang-format off +using BQuantTypes = ::testing::Types< + // 1d cases with grouping only on k axis (AQLayout is always RowMajor for BQuant) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // 2d cases with grouping also on the n axis + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant (without PreshuffleB) +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp new file mode 100644 index 0000000000..59b267842f --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests with PreshuffleB +// Tuple format: +// clang-format off +using BPreshuffleBQuantTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant with PreshuffleB +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleBQuantTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 5e610cb76b..3b62d8073e 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -53,6 +53,13 @@ struct GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +struct GemmConfigPrefill : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; +}; + struct GemmConfigPreshuffleQuant : public GemmConfigBase { static constexpr bool PreshuffleQuant = true; @@ -128,6 +135,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBaseis_row_major(ALayout{})); + const ck_tile::index_t stride_B = + ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{})); + const ck_tile::index_t stride_C = + ck_tile::get_default_stride(M, N, 0, this->is_row_major(CLayout{})); // AQuant uses grouped quantization for A matrix const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK); + // AQLayout is parameterized in the test tuple (can be RowMajor or ColumnMajor for AQuant) const ck_tile::index_t stride_AQ = - ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(ALayout{})); + ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(AQLayout{})); // Generate test data ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); + // AQLayout is independently specified for each test case ck_tile::HostTensor aq_m_aqk( - ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(ALayout{}))); + ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(AQLayout{}))); ck_tile::HostTensor b_k_n( ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); @@ -400,8 +413,9 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBaseis_row_major(ALayout{}))); ck_tile::HostTensor b_k_n( ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); + // BQ is always ColumnMajor ck_tile::HostTensor bq_bqk_bqn( - ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BLayout{}))); + ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, ck_tile::bool_constant{})); // Initialize data with random values ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp new file mode 100644 index 0000000000..5a58ed886a --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using RowColQuant = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for RowColQuant tests +// Tuple format: +// clang-format off +using RowColQuantTypes = ::testing::Types< + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for RowColQuant +TYPED_TEST_SUITE(TestCkTileGemmRowColQuant, RowColQuantTypes); + +// RowColQuant tests +TYPED_TEST(TestCkTileGemmRowColQuant, RowColQuantTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp new file mode 100644 index 0000000000..0fa4048dab --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using TensorQuant = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for TensorQuant tests +// Tuple format: +// clang-format off +using TensorQuantTypes = ::testing::Types< + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for TensorQuant +TYPED_TEST_SUITE(TestCkTileGemmTensorQuant, TensorQuantTypes); + +// TensorQuant tests +TYPED_TEST(TestCkTileGemmTensorQuant, TensorQuantTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp deleted file mode 100644 index 34bdf4ea38..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include - -#include "test_gemm_quant_fixtures.hpp" - -// Type aliases for readability -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; -using AQuantGrouped = std::integral_constant; -using BQuantGrouped = std::integral_constant; -using RowColQuant = std::integral_constant; -using TensorQuant = std::integral_constant; -using GroupSize = ck_tile::QuantGroupShape>; -using GroupSize64 = ck_tile::QuantGroupShape>; - -// 2d block sizes for BQuant -using GroupSize2D8N = ck_tile::QuantGroupShape>; -using GroupSize2D16N = ck_tile::QuantGroupShape>; -using GroupSize2D32N = ck_tile::QuantGroupShape>; -using GroupSize2D64N = ck_tile::QuantGroupShape>; -using GroupSize2D128N = ck_tile::QuantGroupShape>; - -// Type combinations for each quantization type -// clang-format off -using AQuantTypes = ::testing::Types< - // PreshuffleQuant = false && TransposeC = false - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // PreshuffleQuant = false && TransposeC = true - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // PreshuffleQuant = true && TransposeC = false - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // PreshuffleQuant = true && TransposeC = true - std::tuple, - std::tuple, - std::tuple, - std::tuple ->; -// clang-format on - -// clang-format off -using BQuantTypes = ::testing::Types< - // 1d cases with grouping only on k axis - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // 2d cases with grouping also on the n axis - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple ->; -// clang-format on - -// clang-format off -using BPreshuffleBQuantTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple ->; -// clang-format on - -// clang-format off -using RowColQuantTypes = ::testing::Types< - std::tuple, - std::tuple ->; -// clang-format on - -// clang-format off -using TensorQuantTypes = ::testing::Types< - std::tuple, - std::tuple ->; -// clang-format on - -// Test suites for each quantization type -TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes); -TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes); -TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleBQuantTypes); -TYPED_TEST_SUITE(TestCkTileGemmRowColQuant, RowColQuantTypes); -TYPED_TEST_SUITE(TestCkTileGemmTensorQuant, TensorQuantTypes); - -#include "test_gemm_quant_ut_cases.inc" diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc b/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc deleted file mode 100644 index a88483fe3e..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -// AQuant tests -TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} - -// BQuant tests -TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} - -// BQuant tests -TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantGroupedTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} -// RowColQuant tests -TYPED_TEST(TestCkTileGemmRowColQuant, RowColQuantTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} - -// TensorQuant tests -TYPED_TEST(TestCkTileGemmTensorQuant, TensorQuantTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} diff --git a/test/ck_tile/gemm_multi_abd/CMakeLists.txt b/test/ck_tile/gemm_multi_abd/CMakeLists.txt index 8f9b694a3b..2dccf9cd60 100644 --- a/test/ck_tile/gemm_multi_abd/CMakeLists.txt +++ b/test/ck_tile/gemm_multi_abd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # Currently ck_tile is only built on gfx9 set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) diff --git a/test/ck_tile/gemm_multi_d/CMakeLists.txt b/test/ck_tile/gemm_multi_d/CMakeLists.txt index 143fb9dc40..1c04127b2d 100644 --- a/test/ck_tile/gemm_multi_d/CMakeLists.txt +++ b/test/ck_tile/gemm_multi_d/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index 7b1bc6f4f2..d8b4ff945f 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp index 72b4c52831..213702551a 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -250,8 +250,7 @@ class TestCkTileStreamK : public ::testing::Test K, stride_A, stride_B, - stride_C, - reduction_strategy}; + stride_C}; ck_tile::index_t num_accumulations_per_tile = invoke_streamk( diff --git a/test/ck_tile/gemm_tile_engine/CMakeLists.txt b/test/ck_tile/gemm_tile_engine/CMakeLists.txt index 8ad0f2af75..33effcc120 100644 --- a/test/ck_tile/gemm_tile_engine/CMakeLists.txt +++ b/test/ck_tile/gemm_tile_engine/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # ============================================================================ # GEMM Tile Engine Unit Tests # @@ -84,7 +87,7 @@ function(create_individual_gemm_test_target datatype layout config_name trait ti target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8) endif() - message(STATUS " Created test target: ${target_name}") + message(DEBUG " Created test target: ${target_name}") endfunction() # ============================================================================ @@ -135,11 +138,11 @@ function(build_gemm_test_targets datatype layout config_name) # Verify kernel list file was generated if(NOT EXISTS ${working_path}/gemm_kernel_list.txt) - message(STATUS "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)") + message(DEBUG "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)") return() endif() - message(STATUS "Building tests for ${datatype}_${layout}_${config_name}") + message(DEBUG "Building tests for ${datatype}_${layout}_${config_name}") # STEP 2a: Extract test parameters from config set(test_params_file "${working_path}/test_params.hpp") diff --git a/test/ck_tile/gemm_weight_preshuffle/CMakeLists.txt b/test/ck_tile/gemm_weight_preshuffle/CMakeLists.txt index 90803bd9d5..86db48335d 100644 --- a/test/ck_tile/gemm_weight_preshuffle/CMakeLists.txt +++ b/test/ck_tile/gemm_weight_preshuffle/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # Currently ck_tile_gemm is only built on gfx94/gfx95 set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) diff --git a/test/ck_tile/grouped_gemm/CMakeLists.txt b/test/ck_tile/grouped_gemm/CMakeLists.txt index 4fd5c82ae9..b30dc2a867 100644 --- a/test/ck_tile/grouped_gemm/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # Currently ck_tile is only built on gfx9 if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_grouped_gemm test_grouped_gemm.cpp) diff --git a/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt b/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt index 845da28b5d..f86da3c4d5 100644 --- a/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp index 5d23e73146..4397668a5d 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp @@ -124,7 +124,12 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test using GemmPipelineProblem = ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + using BaseGemmPipeline = std::conditional_t< + Config::Pipeline_ == (PipelineType::Memory), + ck_tile::BaseGemmPipelineAgBgCrMem, + std::conditional_t, + ck_tile::BaseGemmPipelineAgBgCrCompV4>>; const ck_tile::index_t k_grain = gemm_descs[0].k_batch * Config::K_Tile_; const ck_tile::index_t K_split = diff --git a/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt b/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt index 68120efc7e..08b413aea9 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt index c9399e54dc..2bd2571993 100644 --- a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/test/ck_tile/image_to_column/CMakeLists.txt b/test/ck_tile/image_to_column/CMakeLists.txt index 8873a846fc..0a458acb85 100644 --- a/test/ck_tile/image_to_column/CMakeLists.txt +++ b/test/ck_tile/image_to_column/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_tile_image_to_column test_tile_image_to_column.cpp) endif() diff --git a/test/ck_tile/layernorm2d/CMakeLists.txt b/test/ck_tile/layernorm2d/CMakeLists.txt index e924f39e7a..9314890447 100644 --- a/test/ck_tile/layernorm2d/CMakeLists.txt +++ b/test/ck_tile/layernorm2d/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + function(create_tile_layernorm2d_fwd SUFFIX) set(TEST_CK_TILE_LAYERNORM2D_FWD "test_ck_tile_layernorm2d_fwd_${SUFFIX}") diff --git a/test/ck_tile/memory_copy/CMakeLists.txt b/test/ck_tile/memory_copy/CMakeLists.txt index 5311e5060a..b754049848 100644 --- a/test/ck_tile/memory_copy/CMakeLists.txt +++ b/test/ck_tile/memory_copy/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx950") add_gtest_executable(test_memory_copy test_copy.cpp) endif() diff --git a/test/ck_tile/moe_smoothquant/CMakeLists.txt b/test/ck_tile/moe_smoothquant/CMakeLists.txt index 019e87323f..1af92a26f0 100644 --- a/test/ck_tile/moe_smoothquant/CMakeLists.txt +++ b/test/ck_tile/moe_smoothquant/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") function (add_moe_smoothquant_test TARGET_NAME MAIN_SRC) message(DEBUG "adding ${TARGET_NAME}") diff --git a/test/ck_tile/moe_sorting/CMakeLists.txt b/test/ck_tile/moe_sorting/CMakeLists.txt index 48d8e1392f..525e39571f 100644 --- a/test/ck_tile/moe_sorting/CMakeLists.txt +++ b/test/ck_tile/moe_sorting/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # Currently ck_tile is only built on gfx90a, gfx942, gfx950, gfx11 and gfx12 if(GPU_TARGETS MATCHES "gfx942|gfx950|gfx90a|gfx11|gfx12") diff --git a/test/ck_tile/permute/CMakeLists.txt b/test/ck_tile/permute/CMakeLists.txt index 8574813be3..f21d160acd 100644 --- a/test/ck_tile/permute/CMakeLists.txt +++ b/test/ck_tile/permute/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") function(add_permute_test TARGET_NAME MAIN_SRC) diff --git a/test/ck_tile/pooling/CMakeLists.txt b/test/ck_tile/pooling/CMakeLists.txt index 83c36cb321..2a7d6a140d 100644 --- a/test/ck_tile/pooling/CMakeLists.txt +++ b/test/ck_tile/pooling/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_pooling test_pooling.cpp) endif() diff --git a/test/ck_tile/reduce/CMakeLists.txt b/test/ck_tile/reduce/CMakeLists.txt index 0ba5974f6c..073bcd2836 100644 --- a/test/ck_tile/reduce/CMakeLists.txt +++ b/test/ck_tile/reduce/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_ck_tile_reduce2d test_reduce2d.cpp) if(result EQUAL 0) diff --git a/test/ck_tile/rmsnorm2d/CMakeLists.txt b/test/ck_tile/rmsnorm2d/CMakeLists.txt index c60d73aafd..7bd0baed90 100644 --- a/test/ck_tile/rmsnorm2d/CMakeLists.txt +++ b/test/ck_tile/rmsnorm2d/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + function(create_tile_rmsnorm2d_fwd SUFFIX) set(TILE_RMSNORM2D_FWD "test_ck_tile_rmsnorm2d_fwd_${SUFFIX}") diff --git a/test/ck_tile/slice_tile/CMakeLists.txt b/test/ck_tile/slice_tile/CMakeLists.txt index d0d1a4ee00..0291a78a1d 100644 --- a/test/ck_tile/slice_tile/CMakeLists.txt +++ b/test/ck_tile/slice_tile/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_test_executable(test_slice_tile test_slice_tile.cpp) \ No newline at end of file diff --git a/test/ck_tile/smoothquant/CMakeLists.txt b/test/ck_tile/smoothquant/CMakeLists.txt index 381923803f..183ce5e74d 100644 --- a/test/ck_tile/smoothquant/CMakeLists.txt +++ b/test/ck_tile/smoothquant/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") function (add_smoothquant_test TARGET_NAME MAIN_SRC) message(DEBUG "adding ${TARGET_NAME}") diff --git a/test/ck_tile/topk_softmax/CMakeLists.txt b/test/ck_tile/topk_softmax/CMakeLists.txt index cd524eca01..0fd59a0c72 100644 --- a/test/ck_tile/topk_softmax/CMakeLists.txt +++ b/test/ck_tile/topk_softmax/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + function(add_tile_topk_softmax_test SUFFIX) set(TEST_NAME "test_ck_tile_topk_softmax_${SUFFIX}") add_test_executable(${TEST_NAME} test_topk_softmax_${SUFFIX}.cpp test_topk_softmax_api.cpp) diff --git a/test/ck_tile/utility/CMakeLists.txt b/test/ck_tile/utility/CMakeLists.txt index c57cafca5a..aa15293411 100644 --- a/test/ck_tile/utility/CMakeLists.txt +++ b/test/ck_tile/utility/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + message("-- Adding: test/ck_tile/utility/") # Add print tests diff --git a/test/ck_tile/utility/print/CMakeLists.txt b/test/ck_tile/utility/print/CMakeLists.txt index 888f23b4c5..c339546c8d 100644 --- a/test/ck_tile/utility/print/CMakeLists.txt +++ b/test/ck_tile/utility/print/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # Print utility tests add_gtest_executable(test_print_sequence test_print_sequence.cpp) add_gtest_executable(test_print_array test_print_array.cpp) diff --git a/test/ck_tile/warp_gemm/CMakeLists.txt b/test/ck_tile/warp_gemm/CMakeLists.txt new file mode 100644 index 0000000000..664ebc003b --- /dev/null +++ b/test/ck_tile/warp_gemm/CMakeLists.txt @@ -0,0 +1,3 @@ +if(GPU_TARGETS MATCHES "gfx95") + add_gtest_executable(test_ck_tile_wg_16x16x128_fp4 test_f32_16x16x128_fp4.cpp) +endif() diff --git a/test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp b/test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp new file mode 100644 index 0000000000..7878fda618 --- /dev/null +++ b/test/ck_tile/warp_gemm/test_f32_16x16x128_fp4.cpp @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +using namespace ck_tile; + +template +struct WGDispCase +{ + using AType = A; + using BType = B; + using AccType = Acc; + static constexpr index_t MPerWave = M; + static constexpr index_t NPerWave = N; + static constexpr index_t KPerWave = K; + static constexpr bool kTransposeC = TransposeC; + static constexpr bool kSwizzleA = SwizzleA; + static constexpr bool kUSS = UseStructuredSparsity; + static constexpr WGAttrNumAccessEnum kNA = NA; +}; + +using WGDispatcherTypesList = + ::testing::Types>; + +template +struct WarpGemmKernel +{ + static constexpr int kBlockSize = 64; + __device__ void operator()(void* A, void* B, void* C, void* ScaleA, void* ScaleB) const + { + using WarpGemm = ck_tile::WarpGemmDispatcher; + // A: [M,K] row-major (packed) + const auto a_view = ck_tile::make_naive_tensor_view( + static_cast(A), + ck_tile::make_tuple(M, K), + ck_tile::make_tuple(K, ck_tile::number<1>{}), + ck_tile::number{}, + ck_tile::number<1>{}); + // B: expose as logical [N,K] with strides (1, N) over the original row-major [K,N] buffer + const auto b_view = ck_tile::make_naive_tensor_view( + static_cast(B), + ck_tile::make_tuple(N, K), + ck_tile::make_tuple(K, ck_tile::number<1>{}), + ck_tile::number{}, + ck_tile::number<1>{}); + // C: [M,N] row-major (packed) + const auto c_view = ck_tile::make_naive_tensor_view( + static_cast(C), + ck_tile::make_tuple(M, N), + ck_tile::make_tuple(N, ck_tile::number<1>{}), + ck_tile::number{}, + ck_tile::number<1>{}); + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + constexpr auto a_len = AWarpTensor::get_tile_distribution().get_lengths(); + constexpr auto b_len = BWarpTensor::get_tile_distribution().get_lengths(); + constexpr auto c_len = CWarpTensor::get_tile_distribution().get_lengths(); + + auto a_win = ck_tile::make_tile_window( + a_view, a_len, ck_tile::make_multi_index(0, 0), AWarpTensor::get_tile_distribution()); + auto b_win = ck_tile::make_tile_window( + b_view, b_len, ck_tile::make_multi_index(0, 0), BWarpTensor::get_tile_distribution()); + auto c_win = ck_tile::make_tile_window( + c_view, c_len, ck_tile::make_multi_index(0, 0), CWarpTensor::get_tile_distribution()); + + AWarpTensor a_tile; + BWarpTensor b_tile; + ck_tile::load_tile(a_tile, a_win); + ck_tile::load_tile(b_tile, b_win); + + auto scale_a = static_cast(static_cast(ScaleA)[0].get()); + auto scale_b = static_cast(static_cast(ScaleB)[0].get()); + + auto c_tile = WarpGemm{}.template operator()<0, 0>(a_tile, b_tile, scale_a, scale_b); + + ck_tile::store_tile(c_win, c_tile); + } +}; + +template +static void RunWarpGemmCase(const ck_tile::HostTensor& A, + const ck_tile::HostTensor& B, + const ck_tile::HostTensor& ScaleA, + const ck_tile::HostTensor& ScaleB, + ck_tile::HostTensor& C) +{ + ck_tile::DeviceMem Ad(A), Bd(B), Cd(C), SAd(ScaleA), SBd(ScaleB); + dim3 grid(1), block{64}; + + using Kernel = WarpGemmKernel; + + (void)ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 0, 1}, + ck_tile::make_kernel(Kernel{}, + grid, + block, + 0, + Ad.GetDeviceBuffer(), + Bd.GetDeviceBuffer(), + Cd.GetDeviceBuffer(), + SAd.GetDeviceBuffer(), + SBd.GetDeviceBuffer())); + + Cd.FromDevice(C.mData.data()); +} + +template +class WGRuntimeTest : public ::testing::Test +{ +}; + +TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList); + +TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) +{ + using Case = TypeParam; + + using AType = typename Case::AType; + using BType = typename Case::BType; + using CType = typename Case::AccType; + using ck_tile::e8m0_t; + + constexpr index_t M = Case::MPerWave; + constexpr index_t N = Case::NPerWave; + constexpr index_t K = Case::KPerWave; + + auto ScaleA = e8m0_t{2.f}; + auto ScaleB = e8m0_t{4.f}; + + ck_tile::HostTensor A({M, K}); + ck_tile::HostTensor B({N, K}); + ck_tile::HostTensor C({M, N}); + ck_tile::HostTensor sA({M, 1}); + ck_tile::HostTensor sB({N, 1}); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(A); + ck_tile::FillUniformDistribution{-5.f, 5.f}(B); + C.SetZero(); + ck_tile::FillConstant{ScaleA}(sA); + ck_tile::FillConstant{ScaleB}(sB); + + RunWarpGemmCase(A, B, sA, sB, C); + + ck_tile::HostTensor C_ref({M, N}); + C_ref.SetZero(); + ck_tile::reference_mx_gemm( + A, B.transpose(), C_ref, sA, sB.transpose()); + + EXPECT_TRUE(ck_tile::check_err(C, C_ref, "Warp gemm result error.")); +} diff --git a/test/contraction/CMakeLists.txt b/test/contraction/CMakeLists.txt index 3ba0d82f0e..90541b67a1 100644 --- a/test/contraction/CMakeLists.txt +++ b/test/contraction/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES) add_gtest_executable(test_contraction test_contraction_xdl.cpp) if(result EQUAL 0) diff --git a/test/conv_tensor_rearrange/CMakeLists.txt b/test/conv_tensor_rearrange/CMakeLists.txt index 05ca4a9ffb..361bb960df 100644 --- a/test/conv_tensor_rearrange/CMakeLists.txt +++ b/test/conv_tensor_rearrange/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_conv_tensor_rearrange test_conv_tensor_rearrange.cpp) target_link_libraries(test_conv_tensor_rearrange PRIVATE utility device_image_to_column_instance device_column_to_image_instance) diff --git a/test/conv_util/CMakeLists.txt b/test/conv_util/CMakeLists.txt index 7a46039f15..a5e2f1563b 100644 --- a/test/conv_util/CMakeLists.txt +++ b/test/conv_util/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_conv_util conv_util.cpp) target_link_libraries(test_conv_util PRIVATE utility) diff --git a/test/convnd_bwd_data/CMakeLists.txt b/test/convnd_bwd_data/CMakeLists.txt index e68a9b243c..3919c7c069 100644 --- a/test/convnd_bwd_data/CMakeLists.txt +++ b/test/convnd_bwd_data/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_convnd_bwd_data convnd_bwd_data_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance) diff --git a/test/convnd_fwd/CMakeLists.txt b/test/convnd_fwd/CMakeLists.txt index ba6d16a0d5..3da6ffd164 100644 --- a/test/convnd_fwd/CMakeLists.txt +++ b/test/convnd_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_convnd_fwd convnd_fwd_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_convnd_fwd PRIVATE utility device_conv2d_fwd_instance) diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index 32d5464e8f..69a991143c 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # temporarily disable flaky test for all architectures add_definitions(-DCK_SKIP_FLAKY_F8_TEST) set(CK_SKIP_FLAKY_F8_TEST "ON") diff --git a/test/elementwise_normalization/CMakeLists.txt b/test/elementwise_normalization/CMakeLists.txt index aed67901b5..c7007cb03b 100644 --- a/test/elementwise_normalization/CMakeLists.txt +++ b/test/elementwise_normalization/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(test_elementwise_normalization) add_gtest_executable(test_elementwise_layernorm_fp16 test_elementwise_layernorm_fp16.cpp) if(result EQUAL 0) diff --git a/test/gemm/CMakeLists.txt b/test/gemm/CMakeLists.txt index f88a134041..93432f3fc8 100644 --- a/test/gemm/CMakeLists.txt +++ b/test/gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_test_executable(test_gemm_fp32 gemm_fp32.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_fp32 PRIVATE utility device_gemm_instance) diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt index fe0a08c0c9..17bfadf95d 100644 --- a/test/gemm_add/CMakeLists.txt +++ b/test/gemm_add/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + # Implements test instances for MultipleD with xdl and wmma support. add_gtest_executable(test_gemm_add_xdl test_gemm_add_xdl.cpp) diff --git a/test/gemm_b_scale/CMakeLists.txt b/test/gemm_b_scale/CMakeLists.txt index 0bf8a024ea..517e2f01f6 100644 --- a/test/gemm_b_scale/CMakeLists.txt +++ b/test/gemm_b_scale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_gemm_b_scale_xdl test_gemm_b_scale_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_b_scale_xdl PRIVATE utility device_gemm_b_scale_instance) diff --git a/test/gemm_blockscale_wp/CMakeLists.txt b/test/gemm_blockscale_wp/CMakeLists.txt index d198db0870..a095968035 100644 --- a/test/gemm_blockscale_wp/CMakeLists.txt +++ b/test/gemm_blockscale_wp/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") add_gtest_executable(test_gemm_blockscale_wp_xdl_fp8 test_gemm_blockscale_wp_xdl_fp8.cpp) if(result EQUAL 0) diff --git a/test/gemm_layernorm/CMakeLists.txt b/test/gemm_layernorm/CMakeLists.txt index d912ce301c..e558d586ef 100644 --- a/test/gemm_layernorm/CMakeLists.txt +++ b/test/gemm_layernorm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp) if(result EQUAL 0) diff --git a/test/gemm_multi_abd/CMakeLists.txt b/test/gemm_multi_abd/CMakeLists.txt index d700414b05..9b1454ca93 100644 --- a/test/gemm_multi_abd/CMakeLists.txt +++ b/test/gemm_multi_abd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_gemm_multi_abd_wmma test_gemm_multi_abd_wmma.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_multi_abd_wmma PRIVATE utility device_gemm_multi_abd_instance) diff --git a/test/gemm_multiply_multiply_wp/CMakeLists.txt b/test/gemm_multiply_multiply_wp/CMakeLists.txt index 4302084a6f..c672fa3059 100644 --- a/test/gemm_multiply_multiply_wp/CMakeLists.txt +++ b/test/gemm_multiply_multiply_wp/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") add_gtest_executable(test_gemm_multiply_multiply_wp_xdl_fp8 test_gemm_multiply_multiply_wp_xdl_fp8.cpp) if(result EQUAL 0) diff --git a/test/gemm_mx/CMakeLists.txt b/test/gemm_mx/CMakeLists.txt index 7a04d5378f..986bf239e0 100644 --- a/test/gemm_mx/CMakeLists.txt +++ b/test/gemm_mx/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_gemm_mx test_gemm_mx.cpp) if(result EQUAL 0) target_compile_options(test_gemm_mx PRIVATE -mavx512f) diff --git a/test/gemm_reduce/CMakeLists.txt b/test/gemm_reduce/CMakeLists.txt index ae2246e628..5d4b813890 100644 --- a/test/gemm_reduce/CMakeLists.txt +++ b/test/gemm_reduce/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) if(result EQUAL 0) diff --git a/test/gemm_split_k/CMakeLists.txt b/test/gemm_split_k/CMakeLists.txt index 4b66dddef9..2e64c0fd8d 100644 --- a/test/gemm_split_k/CMakeLists.txt +++ b/test/gemm_split_k/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_gemm_splitk test_gemm_splitk_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_splitk PRIVATE utility device_gemm_splitk_instance) diff --git a/test/gemm_universal/CMakeLists.txt b/test/gemm_universal/CMakeLists.txt index 0a68622ebe..5be42aae90 100644 --- a/test/gemm_universal/CMakeLists.txt +++ b/test/gemm_universal/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_gemm_universal_wmma_fp16 test_gemm_universal_wmma_fp16.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_universal_wmma_fp16 PRIVATE utility device_gemm_universal_instance) diff --git a/test/gemm_universal_preshuffle/CMakeLists.txt b/test/gemm_universal_preshuffle/CMakeLists.txt index 0d8955f6a4..1abc4391bb 100644 --- a/test/gemm_universal_preshuffle/CMakeLists.txt +++ b/test/gemm_universal_preshuffle/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") add_gtest_executable(test_gemm_universal_preshuffle_xdl_fp8 test_gemm_universal_preshuffle_xdl_fp8.cpp) if(result EQUAL 0) diff --git a/test/gemm_universal_reduce/CMakeLists.txt b/test/gemm_universal_reduce/CMakeLists.txt index dab9de44c0..39bdc07874 100644 --- a/test/gemm_universal_reduce/CMakeLists.txt +++ b/test/gemm_universal_reduce/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_gemm_universal_reduce_bf16_wmma test_gemm_universal_reduce_bf16_wmma.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_universal_reduce_bf16_wmma PRIVATE utility device_gemm_universal_reduce_instance) diff --git a/test/gemm_universal_streamk/CMakeLists.txt b/test/gemm_universal_streamk/CMakeLists.txt index 6e42bfe396..1610f4ae09 100644 --- a/test/gemm_universal_streamk/CMakeLists.txt +++ b/test/gemm_universal_streamk/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_gemm_universal_streamk_fp16 test_gemm_universal_streamk_xdl_fp16.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_universal_streamk_fp16 PRIVATE utility device_gemm_universal_streamk_instance) diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index dfd08bc42e..1da477ebb3 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_grouped_convnd_bwd_data_xdl test_grouped_convnd_bwd_data_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_data_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index 7c2b208c6b..e46113bea0 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance device_grouped_convnd_bwd_weight_instance) diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 28583e82c7..ab52d12bb0 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp) target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) diff --git a/test/grouped_convnd_fwd_activation/CMakeLists.txt b/test/grouped_convnd_fwd_activation/CMakeLists.txt index 61e101de72..18b9114300 100644 --- a/test/grouped_convnd_fwd_activation/CMakeLists.txt +++ b/test/grouped_convnd_fwd_activation/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + if(GPU_TARGETS MATCHES "gfx9|gfx12") #Fail on gfx11 CI but fail to reproduce it in local, disable it temporary add_gtest_executable(test_grouped_convnd_fwd_bias_bnorm_clamp test_grouped_convnd_fwd_bias_bnorm_clamp.cpp) diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index f47685cf91..c6b5180013 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -1,9 +1,17 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(test_grouped_gemm) -add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) - add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) +# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary +# as these tests are universal and dont have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link +# the instance library if there's no instances present for the current arch. +if (CK_USE_XDL OR CK_USE_WMMA) + add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp) + if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) + endif() endif() add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) diff --git a/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp b/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp index 1683e16323..56fb758f89 100644 --- a/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp +++ b/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp @@ -9,6 +9,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "test_grouped_gemm_util.hpp" +#include "test_grouped_gemm_interface_xdl.hpp" class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test { diff --git a/test/grouped_gemm/test_grouped_gemm_interface_xdl.hpp b/test/grouped_gemm/test_grouped_gemm_interface_xdl.hpp new file mode 100644 index 0000000000..a04d13c1ea --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_interface_xdl.hpp @@ -0,0 +1,205 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/number.hpp" +#include "profiler/profile_grouped_gemm_impl.hpp" + +namespace ck { +namespace test { + +template +struct DeviceGroupedGemmSplitkInstanceWrapper +{ + using F16 = half_t; + using F32 = float; + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + using PassThrough = tensor_operation::element_wise::PassThrough; + + using EmptyTuple = ck::Tuple<>; + + template + using S = ck::Sequence; + + template + using I = ck::Number; + + using ABlockTransferThreadClusterArrageOrder = + std::conditional_t, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; + using ABlockTransferSrcAccessOrder = + std::conditional_t, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; + using ABlockTransferSrcVectorDim = std::conditional_t, I<3>, I<2>>; + using ABlockTransferDstScalarPerVector_K1 = + std::conditional_t, I<8>, I<2>>; + using ABlockLdsAddExtraM = std::conditional_t, I<1>, I<0>>; + + using BBlockTransferThreadClusterArrageOrder = + std::conditional_t, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; + using BBlockTransferSrcAccessOrder = + std::conditional_t, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; + using BBlockTransferSrcVectorDim = std::conditional_t, I<2>, I<3>>; + using BBlockTransferDstScalarPerVector_K1 = + std::conditional_t, I<2>, I<8>>; + using BBlockLdsAddExtraM = std::conditional_t, I<0>, I<1>>; + + using DeviceGroupedGemmSplitKInstance = + tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle< + ALayout, + BLayout, + EmptyTuple, + ELayout, + F16, + F16, + F32, + F16, + EmptyTuple, + F16, + PassThrough, + PassThrough, + PassThrough, + GemmSpec, + 1, + 128, + 128, + 128, + KPerBlock, + K1, + K1, + 16, + 16, + 8, + 4, + S<1, 4, 16, 1>, + ABlockTransferThreadClusterArrageOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim::value, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1::value, + ABlockLdsAddExtraM::value, + S<1, 4, 16, 1>, + BBlockTransferThreadClusterArrageOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim::value, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1::value, + BBlockLdsAddExtraM::value, + 1, + 1, + S<1, 16, 1, 8>, + CDEBlockTransferScalarPerVector_NPerBlock>; + + bool IsSupported(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1) const + { + std::size_t n_groups = Ms.size(); + EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && + StrideBs.size() == n_groups && StrideCs.size() == n_groups) + << "The number of groups is not consistent!"; + + std::vector gemm_descs; + + for(std::size_t i = 0; i < n_groups; ++i) + { + gemm_descs.push_back(tensor_operation::device::GemmDesc{ + Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + } + + std::vector p_As(n_groups, nullptr); + std::vector p_Bs(n_groups, nullptr); + std::vector p_Cs(n_groups, nullptr); + auto p_Ds = std::vector>{}; + + auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; + auto argument = ggemm_instance.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); + if(kbatch > 1) + { + ggemm_instance.SetKBatchSize(&argument, kbatch); + } + + return ggemm_instance.IsSupportedArgument(argument); + } + + float Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1) const + { + std::size_t n_groups = Ms.size(); + EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && + StrideBs.size() == n_groups && StrideCs.size() == n_groups) + << "The number of groups is not consistent!"; + + std::vector gemm_descs; + + for(std::size_t i = 0; i < n_groups; ++i) + { + gemm_descs.push_back(tensor_operation::device::GemmDesc{ + Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + } + + std::vector p_As(n_groups, nullptr); + std::vector p_Bs(n_groups, nullptr); + std::vector p_Cs(n_groups, nullptr); + auto p_Ds = std::vector>{}; + + auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; + auto argument = ggemm_instance.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); + if(kbatch > 1) + { + ggemm_instance.SetKBatchSize(&argument, kbatch); + } + if(kbatch > 1 && ck::is_gfx11_supported()) + { + EXPECT_FALSE(ggemm_instance.IsSupportedArgument(argument)); + return 0; + } + else + { + EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument)); + auto invoker = ggemm_instance.MakeInvoker(); + DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument)); + ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer()); + return invoker.Run(argument, StreamConfig{nullptr, false}); + } + } +}; + +} // namespace test +} // namespace ck diff --git a/test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp b/test/grouped_gemm/test_grouped_gemm_splitk.cpp similarity index 62% rename from test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp rename to test/grouped_gemm/test_grouped_gemm_splitk.cpp index c237fd562e..968bea2109 100644 --- a/test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp +++ b/test/grouped_gemm/test_grouped_gemm_splitk.cpp @@ -24,21 +24,48 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; template class TestGroupedGemm : public ck::test::TestGroupedGemm { + public: + void SetUp() override + { + ck::test::TestGroupedGemm::SetUp(); + +#if defined(CK_USE_WMMA) + // The old XDL tests didn't fail if instances were not supported, so we want to keep that + // behaviour When compiling WMMA instances and WMMA is supported, then we'll fail if a + // specific case is not supported + this->fail_if_no_supported_instances_ = + ck::is_gfx11_supported() || ck::is_gfx12_supported(); +#endif + } }; // clang-format off using KernelTypes = ::testing::Types< + +#if defined(CK_USE_WMMA) + // WWMA only. No reason to not have it for XDL, but the instance was not defined and it was not in the original test. + std::tuple< Col, Col, Row, BF16, BF16, BF16>, +#endif + +#if defined(CK_USE_XDL) && defined(__gfx9__) + // XDL only at the moment, instances for WMMA not defined + std::tuple< Row, Row, Row, BF16, I8, BF16>, + std::tuple< Row, Col, Row, BF16, I8, BF16>, +#endif + +#if (defined(CK_USE_XDL) && (defined(__gfx9__) || defined(__gfx12__))) || (defined(CK_USE_WMMA) && defined(__gfx12__)) + std::tuple< Row, Row, Row, F8, F16, F16>, + std::tuple< Row, Row, Row, F16, F8, F16>, +#endif + std::tuple< Row, Row, Row, F16, F16, F16>, std::tuple< Row, Col, Row, F16, F16, F16>, std::tuple< Col, Row, Row, F16, F16, F16>, std::tuple< Col, Col, Row, F16, F16, F16>, + std::tuple< Row, Row, Row, BF16, BF16, BF16>, std::tuple< Row, Col, Row, BF16, BF16, BF16>, - std::tuple< Col, Row, Row, BF16, BF16, BF16>, - std::tuple< Row, Row, Row, BF16, I8, BF16>, - std::tuple< Row, Col, Row, BF16, I8, BF16>, - std::tuple< Row, Row, Row, F16, F8, F16>, - std::tuple< Row, Row, Row, F8, F16, F16> + std::tuple< Col, Row, Row, BF16, BF16, BF16> >; // clang-format on diff --git a/test/grouped_gemm/test_grouped_gemm_ut_cases.inc b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc index 16c4ad5909..84558c89f9 100644 --- a/test/grouped_gemm/test_grouped_gemm_ut_cases.inc +++ b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc @@ -65,6 +65,13 @@ TYPED_TEST(TestGroupedGemm, MNKPadded) TYPED_TEST(TestGroupedGemm, TestLargeKBatch) { + // gfx11 does not support split-K due to missing atomic add for fp16/bf16 + // Technically, we could still run the tests for fp32, but we currently don't have instances for + // it so we disable it entirely + if(ck::is_gfx11_supported()) + GTEST_SKIP() << "Split-K not supported for FP16/BF16 on GFX11 due to missing atomic add " + "instructions"; + const std::vector Ms{188, 210}; constexpr int N = 768; constexpr int K = 4096; diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp index 912066ee80..6ee6465cc4 100644 --- a/test/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -11,16 +11,7 @@ #include #include "ck/ck.hpp" -#include "ck/stream_config.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/utility/data_type.hpp" -#include "ck/utility/sequence.hpp" -#include "ck/utility/tuple.hpp" -#include "ck/utility/number.hpp" #include "profiler/profile_grouped_gemm_impl.hpp" extern ck::index_t param_mask; @@ -41,7 +32,7 @@ std::string serialize_range(const Range& range) return std::string(str.begin(), str.end() - 2); } -template +template class TestGroupedGemm : public testing::Test { protected: @@ -62,9 +53,26 @@ class TestGroupedGemm : public testing::Test static constexpr bool bench_ = false; // measure kernel performance static constexpr int n_warmup_ = 0; static constexpr int n_iter_ = 1; + + bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances; std::vector k_batches_; - void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; } + void SetUp() override + { + constexpr bool require_16bit_atomic_add = + std::is_same_v || std::is_same_v; + if(require_16bit_atomic_add && ck::is_gfx11_supported()) + { + // gfx11 does not support split-K due to missing atomic add for fp16/bf16 + // Technically, we could still use split-K for fp32, but we currently don't have + // instances for it so we disable it entirely + k_batches_ = {1}; + } + else + { + k_batches_ = {1, 2, 3, 5, 8}; + } + } private: template @@ -132,204 +140,31 @@ class TestGroupedGemm : public testing::Test const std::vector& StrideCs, const std::vector& kbatches) { - bool pass = ck::profiler::profile_grouped_gemm_impl(verify_, - init_method_, - log_, - bench_, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideCs, - kbatches, - n_warmup_, - n_iter_, - instance_index); + bool pass = + ck::profiler::profile_grouped_gemm_impl(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, + n_warmup_, + n_iter_, + instance_index, + fail_if_no_supported_instances_); EXPECT_TRUE(pass); } }; -template -struct DeviceGroupedGemmSplitkInstanceWrapper -{ - using F16 = half_t; - using F32 = float; - using Row = ck::tensor_layout::gemm::RowMajor; - using Col = ck::tensor_layout::gemm::ColumnMajor; - using PassThrough = tensor_operation::element_wise::PassThrough; - - using EmptyTuple = ck::Tuple<>; - - template - using S = ck::Sequence; - - template - using I = ck::Number; - - using ABlockTransferThreadClusterArrageOrder = - std::conditional_t, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; - using ABlockTransferSrcAccessOrder = - std::conditional_t, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; - using ABlockTransferSrcVectorDim = std::conditional_t, I<3>, I<2>>; - using ABlockTransferDstScalarPerVector_K1 = - std::conditional_t, I<8>, I<2>>; - using ABlockLdsAddExtraM = std::conditional_t, I<1>, I<0>>; - - using BBlockTransferThreadClusterArrageOrder = - std::conditional_t, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; - using BBlockTransferSrcAccessOrder = - std::conditional_t, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; - using BBlockTransferSrcVectorDim = std::conditional_t, I<2>, I<3>>; - using BBlockTransferDstScalarPerVector_K1 = - std::conditional_t, I<2>, I<8>>; - using BBlockLdsAddExtraM = std::conditional_t, I<0>, I<1>>; - - using DeviceGroupedGemmSplitKInstance = - tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle< - ALayout, - BLayout, - EmptyTuple, - ELayout, - F16, - F16, - F32, - F16, - EmptyTuple, - F16, - PassThrough, - PassThrough, - PassThrough, - GemmSpec, - 1, - 128, - 128, - 128, - KPerBlock, - K1, - K1, - 16, - 16, - 8, - 4, - S<1, 4, 16, 1>, - ABlockTransferThreadClusterArrageOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim::value, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1::value, - ABlockLdsAddExtraM::value, - S<1, 4, 16, 1>, - BBlockTransferThreadClusterArrageOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim::value, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1::value, - BBlockLdsAddExtraM::value, - 1, - 1, - S<1, 16, 1, 8>, - CDEBlockTransferScalarPerVector_NPerBlock>; - - bool IsSupported(const std::vector& Ms, - const std::vector& Ns, - const std::vector& Ks, - const std::vector& StrideAs, - const std::vector& StrideBs, - const std::vector& StrideCs, - int kbatch = 1) const - { - std::size_t n_groups = Ms.size(); - EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && - StrideBs.size() == n_groups && StrideCs.size() == n_groups) - << "The number of groups is not consistent!"; - - std::vector gemm_descs; - - for(std::size_t i = 0; i < n_groups; ++i) - { - gemm_descs.push_back(tensor_operation::device::GemmDesc{ - Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); - } - - std::vector p_As(n_groups, nullptr); - std::vector p_Bs(n_groups, nullptr); - std::vector p_Cs(n_groups, nullptr); - auto p_Ds = std::vector>{}; - - auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; - auto argument = ggemm_instance.MakeArgument( - p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); - if(kbatch > 1) - { - ggemm_instance.SetKBatchSize(&argument, kbatch); - } - - return ggemm_instance.IsSupportedArgument(argument); - } - - float Run(const std::vector& Ms, - const std::vector& Ns, - const std::vector& Ks, - const std::vector& StrideAs, - const std::vector& StrideBs, - const std::vector& StrideCs, - int kbatch = 1) const - { - std::size_t n_groups = Ms.size(); - EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && - StrideBs.size() == n_groups && StrideCs.size() == n_groups) - << "The number of groups is not consistent!"; - - std::vector gemm_descs; - - for(std::size_t i = 0; i < n_groups; ++i) - { - gemm_descs.push_back(tensor_operation::device::GemmDesc{ - Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); - } - - std::vector p_As(n_groups, nullptr); - std::vector p_Bs(n_groups, nullptr); - std::vector p_Cs(n_groups, nullptr); - auto p_Ds = std::vector>{}; - - auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; - auto argument = ggemm_instance.MakeArgument( - p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); - if(kbatch > 1) - { - ggemm_instance.SetKBatchSize(&argument, kbatch); - } - if(kbatch > 1 && ck::is_gfx11_supported()) - { - EXPECT_FALSE(ggemm_instance.IsSupportedArgument(argument)); - return 0; - } - else - { - EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument)); - auto invoker = ggemm_instance.MakeInvoker(); - DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument)); - ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer()); - return invoker.Run(argument, StreamConfig{nullptr, false}); - } - } -}; - } // namespace test } // namespace ck diff --git a/test/magic_number_division/CMakeLists.txt b/test/magic_number_division/CMakeLists.txt index e7fc6ee5df..916d6b84b6 100644 --- a/test/magic_number_division/CMakeLists.txt +++ b/test/magic_number_division/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_test_executable(test_magic_number_division magic_number_division.cpp) target_link_libraries(test_magic_number_division PRIVATE utility) diff --git a/test/mx_mfma_op/CMakeLists.txt b/test/mx_mfma_op/CMakeLists.txt index 6715265ae6..043a3df76d 100644 --- a/test/mx_mfma_op/CMakeLists.txt +++ b/test/mx_mfma_op/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(test_mx_mfma) add_gtest_executable(test_mx_mfma_op mx_mfma_op.cpp) diff --git a/test/normalization_bwd_data/CMakeLists.txt b/test/normalization_bwd_data/CMakeLists.txt index fb7ad81e19..c78a24fd1b 100644 --- a/test/normalization_bwd_data/CMakeLists.txt +++ b/test/normalization_bwd_data/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(test_normalization_bwd_data) add_gtest_executable(test_layernorm2d_bwd_data_fp32 test_layernorm2d_bwd_data_fp32.cpp) diff --git a/test/normalization_bwd_gamma_beta/CMakeLists.txt b/test/normalization_bwd_gamma_beta/CMakeLists.txt index 81b6d377ce..bc6b277a62 100644 --- a/test/normalization_bwd_gamma_beta/CMakeLists.txt +++ b/test/normalization_bwd_gamma_beta/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(test_normalization_bwd_gamma_beta) add_gtest_executable(test_layernorm2d_bwd_gamma_beta_fp32 test_layernorm2d_bwd_gamma_beta_fp32.cpp) if (result EQUAL 0) diff --git a/test/normalization_fwd/CMakeLists.txt b/test/normalization_fwd/CMakeLists.txt index c309149deb..78c049c37c 100644 --- a/test/normalization_fwd/CMakeLists.txt +++ b/test/normalization_fwd/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(test_normalization_fwd) add_gtest_executable(test_layernorm2d_fwd_fp32 test_layernorm2d_fwd_fp32.cpp) if(result EQUAL 0) diff --git a/test/permute_scale/CMakeLists.txt b/test/permute_scale/CMakeLists.txt index d63cb79910..90e39b5829 100644 --- a/test/permute_scale/CMakeLists.txt +++ b/test/permute_scale/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(test_permute) add_gtest_executable(test_permute_scale test_permute_scale.cpp) target_link_libraries(test_permute_scale PRIVATE utility device_permute_scale_instance) diff --git a/test/pool/CMakeLists.txt b/test/pool/CMakeLists.txt index 06eb8b85ed..173f83b745 100644 --- a/test/pool/CMakeLists.txt +++ b/test/pool/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(test_pool) add_gtest_executable(test_avg_pool3d_bwd test_avg_pool3d_bwd.cpp) diff --git a/test/position_embedding/CMakeLists.txt b/test/position_embedding/CMakeLists.txt index e7a939bebb..13ea4fdc59 100644 --- a/test/position_embedding/CMakeLists.txt +++ b/test/position_embedding/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_test_executable(test_position_embedding position_embedding.cpp) diff --git a/test/quantization/CMakeLists.txt b/test/quantization/CMakeLists.txt index 89a99f5e5d..4c26116fc8 100644 --- a/test/quantization/CMakeLists.txt +++ b/test/quantization/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(test_quantization) add_subdirectory(gemm) diff --git a/test/quantization/gemm/CMakeLists.txt b/test/quantization/gemm/CMakeLists.txt index 0eb08f9a5b..6a47c599c6 100644 --- a/test/quantization/gemm/CMakeLists.txt +++ b/test/quantization/gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(test_gemm_quantization_targets) # Only build test if the quantization instance library exists diff --git a/test/reduce/CMakeLists.txt b/test/reduce/CMakeLists.txt index bf05795063..0d5acbb465 100644 --- a/test/reduce/CMakeLists.txt +++ b/test/reduce/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_reduce_no_index reduce_no_index.cpp) add_gtest_executable(test_reduce_with_index reduce_with_index.cpp) target_link_libraries(test_reduce_no_index PRIVATE utility device_reduce_instance) diff --git a/test/reference_conv_fwd/CMakeLists.txt b/test/reference_conv_fwd/CMakeLists.txt index b40b9a1ed0..4d238aa177 100644 --- a/test/reference_conv_fwd/CMakeLists.txt +++ b/test/reference_conv_fwd/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_reference_conv_fwd reference_conv_fwd.cpp) target_link_libraries(test_reference_conv_fwd PRIVATE utility) diff --git a/test/s_prefetch_op/CMakeLists.txt b/test/s_prefetch_op/CMakeLists.txt index 1b598cc952..1c55598b7d 100644 --- a/test/s_prefetch_op/CMakeLists.txt +++ b/test/s_prefetch_op/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_test_executable(test_s_prefetch_op s_prefetch_op.cpp) target_link_libraries(test_s_prefetch_op PRIVATE utility) diff --git a/test/scatter_gather/CMakeLists.txt b/test/scatter_gather/CMakeLists.txt index cc327d42db..78eb844a53 100644 --- a/test/scatter_gather/CMakeLists.txt +++ b/test/scatter_gather/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_test_executable(test_scatter_gather scatter_gather.cpp) # target_compile_options(test_scatter_gather PRIVATE -v --save-temps -Wno-gnu-line-marker) diff --git a/test/smfmac_op/CMakeLists.txt b/test/smfmac_op/CMakeLists.txt index 4ffc423f54..f994bbd62d 100644 --- a/test/smfmac_op/CMakeLists.txt +++ b/test/smfmac_op/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_smfmac_op smfmac_op_xdl.cpp) target_link_libraries(test_smfmac_op PRIVATE utility) diff --git a/test/softmax/CMakeLists.txt b/test/softmax/CMakeLists.txt index 4ba4012625..a3deab35c9 100644 --- a/test/softmax/CMakeLists.txt +++ b/test/softmax/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(test_softmax) add_gtest_executable(test_softmax_rank3 test_softmax_rank3.cpp) diff --git a/test/space_filling_curve/CMakeLists.txt b/test/space_filling_curve/CMakeLists.txt index a527268042..d964daae94 100644 --- a/test/space_filling_curve/CMakeLists.txt +++ b/test/space_filling_curve/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_test_executable(test_space_filling_curve space_filling_curve.cpp) diff --git a/test/transpose/CMakeLists.txt b/test/transpose/CMakeLists.txt index fb9379bea9..34b75163e0 100644 --- a/test/transpose/CMakeLists.txt +++ b/test/transpose/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_gtest_executable(test_transpose test_transpose_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_transpose PRIVATE utility device_transpose_instance) diff --git a/test/wmma_op/CMakeLists.txt b/test/wmma_op/CMakeLists.txt index e553253c62..d5ea46055b 100644 --- a/test/wmma_op/CMakeLists.txt +++ b/test/wmma_op/CMakeLists.txt @@ -1,2 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_test_executable(test_wmma_op wmma_op.cpp) target_link_libraries(test_wmma_op PRIVATE utility) diff --git a/test/wrapper/CMakeLists.txt b/test/wrapper/CMakeLists.txt index 1eb6c35db2..bb1a81b5e2 100644 --- a/test/wrapper/CMakeLists.txt +++ b/test/wrapper/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_custom_target(test_wrapper) add_gtest_executable(test_wrapper_layout test_wrapper_layout.cpp) diff --git a/tile_engine/CMakeLists.txt b/tile_engine/CMakeLists.txt index cd1a192a74..7f5e2fa298 100644 --- a/tile_engine/CMakeLists.txt +++ b/tile_engine/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + include_directories(BEFORE ${CMAKE_CURRENT_LIST_DIR}/include ) diff --git a/tile_engine/include/CMakeLists.txt b/tile_engine/include/CMakeLists.txt index 53d97aafae..ee27d6734a 100644 --- a/tile_engine/include/CMakeLists.txt +++ b/tile_engine/include/CMakeLists.txt @@ -1 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + message(STATUS "Add include directory") diff --git a/tile_engine/include/utility/validation.hpp b/tile_engine/include/utility/validation.hpp new file mode 100644 index 0000000000..dc57e6cc6a --- /dev/null +++ b/tile_engine/include/utility/validation.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c), Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +/// @brief Function to compare the results of the device and host computations +bool compare(std::string instanceName, + ck_tile::index_t K, + ck_tile::index_t kbatch, + ck_tile::HostTensor& c_m_n_dev_result, + ck_tile::HostTensor& c_m_n_host_result) +{ + const float max_accumulated_value = + *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_result, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "For " << instanceName << " Relative error threshold is " + << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is " + << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl; + + return pass; +} diff --git a/tile_engine/ops/CMakeLists.txt b/tile_engine/ops/CMakeLists.txt index db100553f3..6f82e1b07a 100644 --- a/tile_engine/ops/CMakeLists.txt +++ b/tile_engine/ops/CMakeLists.txt @@ -1,3 +1,7 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + add_subdirectory(gemm) add_subdirectory(gemm_multi_d) -add_subdirectory(gemm_preshuffle) \ No newline at end of file +add_subdirectory(gemm_preshuffle) +add_subdirectory(gemm_streamk) diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index a72b6c40ab..ff18291c00 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)") set(GEMM_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM (semicolon-separated)") set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") @@ -64,6 +67,7 @@ function(create_individual_gemm_target datatype layout trait tile_config config_ # Create the executable add_executable(${target_name} + # to save build time, exclude the target from "all" target of "gemm" directory and its ancestors EXCLUDE_FROM_ALL ${GEMM_SOURCE_DIR}/gemm_benchmark_single.cpp ${instance_header} diff --git a/tile_engine/ops/gemm/gemm_benchmark_single.cpp b/tile_engine/ops/gemm/gemm_benchmark_single.cpp index 6323c066a1..26f3a3928a 100644 --- a/tile_engine/ops/gemm/gemm_benchmark_single.cpp +++ b/tile_engine/ops/gemm/gemm_benchmark_single.cpp @@ -11,12 +11,12 @@ #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" +#include "ck_tile/ops/common/utils.hpp" #include "gemm_profiler.hpp" #include "gemm_common.hpp" // The kernel header is included via the compile command line with -include flag // It defines SelectedKernel struct and KERNEL_NAME -// DataTypeTraits are now defined in gemm_common.hpp // Create argument parser inline auto create_args(int argc, char* argv[]) @@ -77,12 +77,12 @@ inline auto create_args(int argc, char* argv[]) void benchmark_single(const ck_tile::ArgParser& arg_parser) { - // Use DataTypeTraits to get the actual type names from the generated header + // Use ck_tile::DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType - std::string dtype_a = DataTypeTraits::name; - std::string dtype_b = DataTypeTraits::name; - std::string dtype_acc = DataTypeTraits::name; - std::string dtype_c = DataTypeTraits::name; + std::string dtype_a = ck_tile::DataTypeTraits::name; + std::string dtype_b = ck_tile::DataTypeTraits::name; + std::string dtype_acc = ck_tile::DataTypeTraits::name; + std::string dtype_c = ck_tile::DataTypeTraits::name; // Layout names from the layout types std::string layout_a = ALayout::name; diff --git a/tile_engine/ops/gemm/gemm_common.hpp b/tile_engine/ops/gemm/gemm_common.hpp index 899221547f..1fdc63b33b 100644 --- a/tile_engine/ops/gemm/gemm_common.hpp +++ b/tile_engine/ops/gemm/gemm_common.hpp @@ -9,65 +9,6 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" -//[TODO] This can be moved to commons -// DataTypeTraits for all supported types -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - // Helper function to determine if a layout is row-major template constexpr auto is_row_major(Layout) diff --git a/tile_engine/ops/gemm_multi_d/CMakeLists.txt b/tile_engine/ops/gemm_multi_d/CMakeLists.txt index 8d9c087e24..43164cd73c 100644 --- a/tile_engine/ops/gemm_multi_d/CMakeLists.txt +++ b/tile_engine/ops/gemm_multi_d/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(GEMM_MULTI_D_DATATYPE "fp16" CACHE STRING "List of datatypes for GEMM Multi D (semicolon-separated)") set(GEMM_MULTI_D_LAYOUT "rcrr;rrrr;crrr;ccrr" CACHE STRING "List of layout for GEMM Multi D (semicolon-separated)") set(GEMM_MULTI_D_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") @@ -67,6 +70,7 @@ function(create_individual_gemm_multi_d_target datatype layout trait tile_config # Create the executable add_executable(${target_name} + # to save build time, exclude the target from "all" target of "gemm_multi_d" directory and its ancestors EXCLUDE_FROM_ALL ${GEMM_MULTI_D_SOURCE_DIR}/gemm_multi_d_benchmark_single.cpp ${instance_header} diff --git a/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp b/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp index 899221547f..1fdc63b33b 100644 --- a/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp +++ b/tile_engine/ops/gemm_multi_d/gemm_multi_d_common.hpp @@ -9,65 +9,6 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" -//[TODO] This can be moved to commons -// DataTypeTraits for all supported types -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - // Helper function to determine if a layout is row-major template constexpr auto is_row_major(Layout) diff --git a/tile_engine/ops/gemm_preshuffle/CMakeLists.txt b/tile_engine/ops/gemm_preshuffle/CMakeLists.txt index e3bee6ff52..c89fe236dd 100644 --- a/tile_engine/ops/gemm_preshuffle/CMakeLists.txt +++ b/tile_engine/ops/gemm_preshuffle/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + set(GEMM_PRESHUFFLE_DATATYPE "fp16;fp8;bf16;bf8" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)") set(GEMM_PRESHUFFLE_LAYOUT "rcr" CACHE STRING "List of layout for GEMM Preshuffle (semicolon-separated)") set(GEMM_PRESHUFFLE_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") @@ -64,6 +67,7 @@ function(create_individual_gemm_preshuffle_target datatype layout trait tile_con # Create the executable add_executable(${target_name} + # to save build time, exclude the target from "all" target of "gemm_preshuffle" directory and its ancestors EXCLUDE_FROM_ALL ${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_benchmark_single.cpp ${instance_header} diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp index 4fbb25f0c9..0d5de02750 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark_single.cpp @@ -11,12 +11,12 @@ #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" +#include "ck_tile/ops/common/utils.hpp" #include "gemm_preshuffle_profiler.hpp" #include "gemm_preshuffle_common.hpp" // The kernel header is included via the compile command line with -include flag // It defines SelectedKernel struct and KERNEL_NAME -// DataTypeTraits are now defined in gemm_common.hpp // Create argument parser inline auto create_args(int argc, char* argv[]) @@ -77,12 +77,12 @@ inline auto create_args(int argc, char* argv[]) void benchmark_single(const ck_tile::ArgParser& arg_parser) { - // Use DataTypeTraits to get the actual type names from the generated header + // Use ck_tile::DataTypeTraits to get the actual type names from the generated header // The generated header defines ADataType, BDataType, AccDataType, CDataType - std::string dtype_a = DataTypeTraits::name; - std::string dtype_b = DataTypeTraits::name; - std::string dtype_acc = DataTypeTraits::name; - std::string dtype_c = DataTypeTraits::name; + std::string dtype_a = ck_tile::DataTypeTraits::name; + std::string dtype_b = ck_tile::DataTypeTraits::name; + std::string dtype_acc = ck_tile::DataTypeTraits::name; + std::string dtype_c = ck_tile::DataTypeTraits::name; // Layout names from the layout types std::string layout_a = ALayout::name; diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp index 1b2cfe3735..bb0b8090fa 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp @@ -9,65 +9,6 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" -//[TODO] This can be moved to commons -// DataTypeTraits for all supported types -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - // Helper function to determine if a layout is row-major template constexpr auto is_row_major(Layout) diff --git a/tile_engine/ops/gemm_streamk/CMakeLists.txt b/tile_engine/ops/gemm_streamk/CMakeLists.txt new file mode 100644 index 0000000000..acfd78edc5 --- /dev/null +++ b/tile_engine/ops/gemm_streamk/CMakeLists.txt @@ -0,0 +1,295 @@ +set(GEMM_STREAMK_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)") +set(GEMM_STREAMK_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)") +set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)") +option(ENABLE_CCACHE_GEMM "Enable ccache for GEMM ops compilation" OFF) + +# Store the directory path for use in functions +set(GEMM_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) + +# Function to create individual GEMM targets +function(create_individual_gemm_target datatype layout trait tile_config config_json) + # Use the parent scope GEMM_GPU_TARGETS_INDIVIDUAL variable + if(NOT GEMM_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping individual GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets") + return() + endif() + + # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k + # First split by underscore to get three groups + string(REPLACE "_" ";" config_groups ${tile_config}) + list(GET config_groups 0 tile_dims) # e.g., 256x256x32 + list(GET config_groups 1 warp_dims) # e.g., 4x1x1 + list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16 + + # Parse tile dimensions + string(REPLACE "x" ";" tile_parts ${tile_dims}) + list(GET tile_parts 0 tile_m) + list(GET tile_parts 1 tile_n) + list(GET tile_parts 2 tile_k) + + # Parse warp dimensions + string(REPLACE "x" ";" warp_parts ${warp_dims}) + list(GET warp_parts 0 warp_m) + list(GET warp_parts 1 warp_n) + list(GET warp_parts 2 warp_k) + + # Parse warp tile dimensions + string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims}) + list(GET warp_tile_parts 0 warp_tile_m) + list(GET warp_tile_parts 1 warp_tile_n) + list(GET warp_tile_parts 2 warp_tile_k) + + set(target_name "benchmark_gemm_streamk_${datatype}_${layout}_${trait}_${tile_config}") + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Generate the single instance header for this kernel + set(instance_header "${working_path}/gemm_streamk_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") + + # Add custom command to generate the header file at build time + add_custom_command( + OUTPUT ${instance_header} + COMMAND ${Python3_EXECUTABLE} ${GEMM_SOURCE_DIR}/gemm_streamk_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${config_json} + --gen_single + --kernel_name "${datatype}_${layout}_${trait}_${tile_config}" + --tile_config "${tile_config}" + --trait_combo "${trait}" + DEPENDS ${GEMM_SOURCE_DIR}/gemm_streamk_instance_builder.py ${config_json} + COMMENT "Generating ${instance_header}" + ) + + # Create the executable + add_executable(${target_name} + ${GEMM_SOURCE_DIR}/gemm_streamk_benchmark_single.cpp + ${instance_header} + ) + + # Set GPU architectures + set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS_INDIVIDUAL}) + + # Set compile definitions + target_compile_definitions(${target_name} PRIVATE + GEMM_SINGLE_INSTANCE_HPP="${instance_header}" + ) + + # Include directories + target_include_directories(${target_name} PRIVATE + ${GEMM_SOURCE_DIR} + ${working_path} + ) + + # Compile options + target_compile_options(${target_name} PRIVATE + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress + -include ${instance_header} + ) + + # Add to collection targets + add_dependencies(benchmark_gemm_streamk_all ${target_name}) + add_dependencies(benchmark_gemm_streamk_${datatype} ${target_name}) + add_dependencies(benchmark_gemm_streamk_${layout} ${target_name}) + add_dependencies(benchmark_gemm_streamk_${datatype}_${layout} ${target_name}) + + # Add to trait-specific targets + string(REPLACE "_" ";" trait_parts ${trait}) + list(GET trait_parts 0 pipeline) + list(GET trait_parts 1 epilogue) + list(GET trait_parts 2 scheduler) + + add_dependencies(benchmark_gemm_streamk_${pipeline} ${target_name}) + add_dependencies(benchmark_gemm_streamk_${epilogue} ${target_name}) + add_dependencies(benchmark_gemm_streamk_${scheduler} ${target_name}) +endfunction() + +# Function to build individual GEMM targets +function(build_individual_gemm_targets datatype layout) + set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") + + # Choose config file + # Priority order: + # 1. Environment variable GEMM_CONFIG_FILE + # 2. CMake variable GEMM_CONFIG_FILE + # 3. Default based on layout + + # Check environment variable first + if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "") + set(config_filename "$ENV{GEMM_CONFIG_FILE}") + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}") + message(STATUS " Using config from environment variable: ${config_filename}") + elseif(NOT "${GEMM_CONFIG_FILE}" STREQUAL "") + # Use CMake variable if set + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_CONFIG_FILE}") + message(STATUS " Using custom config: ${GEMM_CONFIG_FILE}") + else() + # Use default config for all layouts + set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") + message(STATUS " Using default config for layout ${layout}") + endif() + + # Check if config file exists + if(NOT EXISTS ${json_blob}) + message(FATAL_ERROR "Config file not found: ${json_blob}") + endif() + + # Determine number of workers for parallel generation + if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL}) + else() + # Use processor count but limit to avoid memory issues + cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES) + math(EXPR num_workers "${num_cores}") + if(num_workers GREATER 8) + set(num_workers 8) + endif() + endif() + + # Generate individual kernel files using parallel version + message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") + message(STATUS " Working path: ${working_path}") + message(STATUS " Config file: ${json_blob}") + message(STATUS " Python executable: ${Python3_EXECUTABLE}") + message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_streamk_instance_builder.py") + + # Create working directory first + file(MAKE_DIRECTORY ${working_path}) + + # First, just list the kernels (fast operation) + message(STATUS " Listing kernel configurations...") + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_streamk_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --list_kernels + WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR} + RESULT_VARIABLE ret + OUTPUT_VARIABLE list_output + ERROR_VARIABLE list_error + ) + + if(NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}") + endif() + + # Read kernel count + if(EXISTS ${working_path}/gemm_kernel_count.txt) + file(READ ${working_path}/gemm_kernel_count.txt kernel_count) + string(STRIP "${kernel_count}" kernel_count) + message(STATUS " Found ${kernel_count} kernel configurations") + else() + message(FATAL_ERROR "Kernel count file not found") + endif() + + # Read kernel list and create targets + if(EXISTS ${working_path}/gemm_kernel_list.txt) + file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) + foreach(line IN LISTS kernel_lines) + # Parse line: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + + # Create individual target + create_individual_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}") + endforeach() + else() + message(FATAL_ERROR "Kernel list file not found") + endif() +endfunction() + +# Main build logic - Only individual builds supported +message(STATUS "=== Starting Tile Engine StreamK GEMM Configuration ===") +message(STATUS "GEMM_STREAMK_DATATYPE: ${GEMM_STREAMK_DATATYPE}") +message(STATUS "GEMM_STREAMK_LAYOUT: ${GEMM_STREAMK_LAYOUT}") +message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + +# Filter GPU targets to only gfx90a, gfx942 +set(GEMM_GPU_TARGETS_INDIVIDUAL "") +set(DESIRED_TARGETS "gfx90a;gfx942") # TODO: Add gfx950 when supported + +foreach(target IN LISTS SUPPORTED_GPU_TARGETS) + if(target IN_LIST DESIRED_TARGETS) + list(APPEND GEMM_GPU_TARGETS_INDIVIDUAL ${target}) + message(STATUS " Adding GPU target: ${target}") + endif() +endforeach() + +# Skip build if no matching targets found +if(NOT GEMM_GPU_TARGETS_INDIVIDUAL) + message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") +else() + message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}") + + # Enable parallel compilation optimizations + # Set up job pools for better parallel compilation control + set_property(GLOBAL PROPERTY JOB_POOLS + compile_heavy=4 # Limit heavy compilations to prevent OOM + compile_normal=16 # Allow more parallel normal compilations + ) + + # Enable compiler cache if available and explicitly requested + # Disabled by default due to permission issues in CI environments + if(ENABLE_CCACHE_GEMM) + find_program(CCACHE_PROGRAM ccache) + if(CCACHE_PROGRAM) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) + message(STATUS "Using ccache for faster compilation") + else() + message(WARNING "ccache requested but not found") + endif() + else() + message(STATUS "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)") + endif() + + # Create master collection targets + add_custom_target(benchmark_gemm_streamk_all) + + # Create datatype collection targets + foreach(dt IN LISTS GEMM_STREAMK_DATATYPE) + add_custom_target(benchmark_gemm_streamk_${dt}) + endforeach() + + # Create layout collection targets + foreach(l IN LISTS GEMM_STREAMK_LAYOUT) + add_custom_target(benchmark_gemm_streamk_${l}) + endforeach() + + # Create combined collection targets + foreach(dt IN LISTS GEMM_STREAMK_DATATYPE) + foreach(l IN LISTS GEMM_STREAMK_LAYOUT) + add_custom_target(benchmark_gemm_streamk_${dt}_${l}) + endforeach() + endforeach() + + # Create trait-based collection targets + # These are common trait components used across all GEMM kernels + set(GEMM_PIPELINES "mem;compv3;compv4") + set(GEMM_EPILOGUES "default;cshuffle") + set(GEMM_SCHEDULERS "intrawave;interwave") + + foreach(pipeline IN LISTS GEMM_PIPELINES) + add_custom_target(benchmark_gemm_streamk_${pipeline}) + endforeach() + + foreach(epilogue IN LISTS GEMM_EPILOGUES) + add_custom_target(benchmark_gemm_streamk_${epilogue}) + endforeach() + + foreach(scheduler IN LISTS GEMM_SCHEDULERS) + add_custom_target(benchmark_gemm_streamk_${scheduler}) + endforeach() + + # Build individual targets for each datatype/layout combination + foreach(dt IN LISTS GEMM_STREAMK_DATATYPE) + foreach(l IN LISTS GEMM_STREAMK_LAYOUT) + build_individual_gemm_targets(${dt} ${l}) + endforeach() + endforeach() +endif() diff --git a/tile_engine/ops/gemm_streamk/configs/default_config.json b/tile_engine/ops/gemm_streamk/configs/default_config.json new file mode 100644 index 0000000000..f6b92feee3 --- /dev/null +++ b/tile_engine/ops/gemm_streamk/configs/default_config.json @@ -0,0 +1,105 @@ +{ + "problem": { + }, + "tile_config": { + "tile_m": { + "max": 256, + "min": 64, + "step": 64 + }, + "tile_n": { + "max": 256, + "min": 64, + "step": 64 + }, + "tile_k": { + "max": 256, + "min": 64, + "step": 64 + }, + "warp_m": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_n": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 4, + 16, + 32 + ] + }, + "warp_tile_n": { + "values": [ + 16, + 32, + 64 + ] + }, + "warp_tile_k": { + "values": [ + 8, + 16, + 32, + 64, + 128 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "compv3" + ] + }, + "scheduler": { + "values": [ + "intrawave" + ] + }, + "epilogue": { + "values": [ + "cshuffle" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + }, + "persistent": { + "values": [ + false, true + ] + }, + "reduction_strategy": { + "values": [ + "reduction", "atomic" + ] + } + } +} diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp new file mode 100644 index 0000000000..fa8a019be5 --- /dev/null +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "gemm_streamk_common.hpp" +#include "utility/validation.hpp" + +// Data types and Layouts are defined by the generated kernel headers +// No hardcoded type definitions here to avoid conflicts + +enum class Metric +{ + LATENCY = 0, + TFLOPS = 1, + BANDWIDTH = 2 +}; + +inline constexpr auto get_metric_name(Metric m) +{ + switch(m) + { + case Metric::LATENCY: return "latency"; + case Metric::TFLOPS: return "tflops"; + case Metric::BANDWIDTH: return "bandwidth"; + default: throw std::invalid_argument("Unsupported metric type"); + } +} + +struct GemmProblem +{ + int split_k_; + int m_, n_, k_; + int stride_a_, stride_b_, stride_c_; + + std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_; + std::string layout_a_, layout_b_, layout_c_; + + bool structured_sparsity_; + + friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem) + { + os << "{\n" + << " \"split_k\":" << problem.split_k_ << ",\n" + << " \"m\":" << problem.m_ << ",\n" + << " \"n\":" << problem.n_ << ",\n" + << " \"k\":" << problem.k_ << ",\n" + << " \"stride_a\":" << problem.stride_a_ << ",\n" + << " \"stride_b\":" << problem.stride_b_ << ",\n" + << " \"stride_c\":" << problem.stride_c_ << ",\n" + << " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n" + << " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n" + << " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n" + << " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n" + << " \"layout_a\":\"" << problem.layout_a_ << "\",\n" + << " \"layout_b\":\"" << problem.layout_b_ << "\",\n" + << " \"layout_c\":\"" << problem.layout_c_ << "\",\n" + << " \"structured_sparsity\":" << (problem.structured_sparsity_ ? "true" : "false") + << "\n" + << "}"; + return os; + } +}; + +struct PerformanceResult +{ + double latency_; + double tflops_; + double bandwidth_; + + static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m) + { + switch(m) + { + case Metric::LATENCY: return a.latency_ < b.latency_; + case Metric::TFLOPS: return a.tflops_ > b.tflops_; + case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_; + default: throw std::invalid_argument("Unsupported metric type"); + } + } + + friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result) + { + os << "{\n" + << " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_ + << ",\n" + << " \"tflops(TFlops)\": " << result.tflops_ << ",\n" + << " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n" + << "}"; + return os; + } +}; + +struct KernelInstance +{ + std::string name_; + std::string dp_persistent_; + std::string reduction_strategy_; + GemmProblem problem_; + PerformanceResult perf_result_; + + static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m) + { + return PerformanceResult::compare(a.perf_result_, b.perf_result_, m); + } + + friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj) + { + os << "{\n" + << " \"name\": \"" << obj.name_ << "\",\n" + << " \"dp_persistent\": \"" << obj.dp_persistent_ << "\",\n" + << " \"reduction_strategy\": \"" << obj.reduction_strategy_ << "\",\n" + << " \"problem\": " << obj.problem_ << ",\n" + << " \"perf_result\": " << obj.perf_result_ << "\n" + << "}"; + return os; + } +}; + +struct Setting +{ + int n_warmup_; + int n_repeat_; + bool is_gpu_timer_; + int verify_; + int init_method_; + bool log_; + std::string csv_filename_; + bool flush_cache_; + int rotating_count_; + bool json_output_; +}; + +inline std::string get_rocm_version() +{ + std::ifstream version_file("/opt/rocm/.info/version"); + if(version_file.is_open()) + { + std::string version; + std::getline(version_file, version); + return version; + } + return "Unknown"; +} + +/// @brief Function to get the kernel output with reference implementation on CPU/GPU +void gemm_host_reference(int verify, + ck_tile::HostTensor& a_m_k, + ck_tile::HostTensor& b_k_n, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C) +{ + if(verify == 1) + { + c_m_n_host_result.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_result); + } + else if(verify == 2) + { + if constexpr(std::is_same_v) + { + // Restore input for B for gpu reference + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes()); + c_m_n_host_result.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); + } +} diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp new file mode 100644 index 0000000000..13cadcd55a --- /dev/null +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "gemm_streamk_profiler.hpp" +#include "gemm_streamk_common.hpp" + +// The kernel header is included via the compile command line with -include flag +// It defines SelectedKernel struct and KERNEL_NAME +// DataTypeTraits are now defined in gemm_streamk_common.hpp + +// Create argument parser +inline auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "The value for m dimension.") + .insert("n", "4096", "The value for n dimension.") + .insert("k", "2048", "The value for k dimension.") + .insert("stride_a", "0", "The stride value for tensor A.") + .insert("stride_b", "0", "The stride value for tensor B.") + .insert("stride_c", "0", "The stride value for tensor C.") + .insert("split_k", "1", "The split value for k dimension.") + .insert("verify", + "0", + "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 " + "for validation on GPU.") + .insert("log", + "false", + "Whether output kernel instance information or not. Possible values are true or " + "false.") + .insert("warmup", "50", "The number of iterations before benchmark the kernel.") + .insert("repeat", "100", "The number of iterations to benchmark the kernel.") + .insert("timer", + "true", + "Whether if the timer is gpu timer or not. Possible values are false or true. " + "") + .insert("init", + "0", + "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 " + "for constant(1).") + .insert("flush_cache", "true", "To flush cache, possible values are true or false.") + .insert("rotating_count", "1000", "number of iterations to rotate the cache. default is 5.") + .insert("metric", + "0", + "Metric with which to measure kernel performance. Set to 0 for latency, 1 for " + "tflops, or 2 for bandwidth.") + .insert("csv_filename", + "", + "The filename of benchmark result. Default is empty (no CSV output).") + .insert("structured_sparsity", + "false", + "Whether use sparsity kernel or not. Possible values are true or false.") + .insert( + "json_output", + "false", + "Whether to output results in JSON format only. Possible values are true or false."); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +void repeat_once_if_verify(Setting& setting) +{ + // The output buffer will be reset after each run, which means the gemm result will be + // accumulated in the output buffer. So limit the repeat to 1 if verify is true. + if(setting.verify_) + { + setting.n_repeat_ = 1; + setting.n_warmup_ = 0; + } +} + +void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser) +{ + // Use DataTypeTraits to get the actual type names from the generated header + // The generated header defines ADataType, BDataType, AccDataType, CDataType + std::string dtype_a = DataTypeTraits::name; + std::string dtype_b = DataTypeTraits::name; + std::string dtype_acc = DataTypeTraits::name; + std::string dtype_c = DataTypeTraits::name; + + // Layout names from the layout types + std::string layout_a = ALayout::name; + std::string layout_b = BLayout::name; + std::string layout_c = CLayout::name; + + // Create GemmProblem struct + GemmProblem gemm_problem{arg_parser.get_int("split_k"), + arg_parser.get_int("m"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("stride_a"), + arg_parser.get_int("stride_b"), + arg_parser.get_int("stride_c"), + dtype_a, + dtype_b, + dtype_acc, + dtype_c, + layout_a, + layout_b, + layout_c, + arg_parser.get_bool("structured_sparsity")}; + + // Create Setting struct + Setting setting{arg_parser.get_int("warmup"), + arg_parser.get_int("repeat"), + arg_parser.get_bool("timer"), + arg_parser.get_int("verify"), + arg_parser.get_int("init"), + arg_parser.get_bool("log"), + arg_parser.get_str("csv_filename"), + arg_parser.get_bool("flush_cache"), + arg_parser.get_int("rotating_count"), + arg_parser.get_bool("json_output")}; + + repeat_once_if_verify(setting); + + // Get the profiler instance + auto& profiler = GemmProfiler::instance(setting); + + try + { + // Create a lambda that wraps the kernel launch + auto kernel_func = [](const ck_tile::StreamKHostArgs& args, + const ck_tile::stream_config& stream) { + return SelectedKernel::launch(args, stream); + }; + + // Benchmark the kernel + profiler.benchmark(gemm_problem, kernel_func); + + // Select best instance based on metric + profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); + } + catch(const std::exception& e) + { + std::cerr << "Benchmark failed: " << e.what() << std::endl; + } +} + +int main(int argc, char* argv[]) +{ + try + { + auto [result, parser] = create_args(argc, argv); + if(!result) + { + parser.print(); + return EXIT_FAILURE; + } + + benchmark_gemm_single(parser); + return EXIT_SUCCESS; + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << "\n"; + return EXIT_FAILURE; + } +} diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp new file mode 100644 index 0000000000..179aeb7307 --- /dev/null +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" + +// DataTypeTraits for all supported types +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +// Helper function to determine if a layout is row-major +template +constexpr auto is_row_major(Layout) +{ + return ck_tile::bool_constant>{}; +} + +// Structure to hold kernel traits for dispatcher +struct KernelTraits +{ + std::string pipeline; // compv3, compv4, mem + std::string scheduler; // intrawave, interwave + std::string epilogue; // cshuffle, default + bool pad_m; + bool pad_n; + bool pad_k; + bool persistent; + + // Constructor with defaults + KernelTraits() + : pipeline("compv3"), + scheduler("intrawave"), + epilogue("cshuffle"), + pad_m(false), + pad_n(false), + pad_k(false), + persistent(false) + { + } +}; + +// Helper to extract traits from kernel name +inline KernelTraits extract_traits_from_name(const std::string& kernel_name) +{ + KernelTraits traits; + + // Extract pipeline + if(kernel_name.find("compv3") != std::string::npos) + { + traits.pipeline = "compv3"; + } + else if(kernel_name.find("compv4") != std::string::npos) + { + traits.pipeline = "compv4"; + } + else if(kernel_name.find("mem") != std::string::npos) + { + traits.pipeline = "mem"; + } + + // Extract scheduler + if(kernel_name.find("interwave") != std::string::npos) + { + traits.scheduler = "interwave"; + } + else + { + traits.scheduler = "intrawave"; + } + + // Extract epilogue + if(kernel_name.find("default") != std::string::npos && + kernel_name.find("default_") == std::string::npos) + { + traits.epilogue = "default"; + } + else + { + traits.epilogue = "cshuffle"; + } + + // Padding flags would need to be extracted from the kernel configuration + // For now, we'll leave them as false + + return traits; +} diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py new file mode 100644 index 0000000000..6aebc54564 --- /dev/null +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -0,0 +1,905 @@ +#!/usr/bin/env python + +import os +import json +import argparse +import itertools +import multiprocessing +import concurrent.futures +from pathlib import Path +import logging +from typing import Optional +from gemm_streamk_validation_utils import ( + is_tile_config_valid, + is_trait_combination_valid, +) + +logging.basicConfig(level=logging.INFO) + + +class GemmKernelBuilder: + def __init__(self, working_path, datatype, layout, config_json=None): + self.working_path = Path(working_path) + self.datatype = datatype + self.layout = layout + self.config_json = config_json + + # Create working directory if it doesn't exist + self.working_path.mkdir(parents=True, exist_ok=True) + + # Load configuration + if config_json and os.path.exists(config_json): + with open(config_json, "r") as f: + self.config = json.load(f) + else: + self.config = self._get_default_config() + + def _get_default_config(self): + """Return default configuration if no config file is provided""" + # Define base tile configurations that work for all layouts + base_fp16_configs = [ + { + "tile_m": 256, + "tile_n": 256, + "tile_k": 32, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 32, + }, + { + "tile_m": 256, + "tile_n": 128, + "tile_k": 32, + "warp_m": 2, + "warp_n": 2, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 16, + }, + ] + + base_fp8_configs = [ + { + "tile_m": 256, + "tile_n": 256, + "tile_k": 32, + "warp_m": 4, + "warp_n": 1, + "warp_k": 1, + "warp_tile_m": 32, + "warp_tile_n": 32, + "warp_tile_k": 32, + }, + { + "tile_m": 256, + "tile_n": 128, + "tile_k": 32, + "warp_m": 1, + "warp_n": 4, + "warp_k": 1, + "warp_tile_m": 16, + "warp_tile_n": 16, + "warp_tile_k": 32, + }, + ] + + # Create configurations for all supported layouts + all_layouts = ["rcr", "rrr", "ccr", "crr"] + tile_configs = {} + + for datatype, base_configs in [ + ("fp16", base_fp16_configs), + ("fp8", base_fp8_configs), + ]: + tile_configs[datatype] = {} + for layout in all_layouts: + tile_configs[datatype][layout] = base_configs + + return { + "tile_configs": tile_configs, + "traits": { + "pipelines": ["mem", "compv3", "compv4"], + "epilogues": ["default", "cshuffle"], + "schedulers": ["intrawave", "interwave"], + }, + "structured_sparsity": ["false"], + "padding": {"pad_m": ["false"], "pad_n": ["false"], "pad_k": ["false"]}, + "persistent": ["false"], + "reduction_strategy": ["reduction"], + } + + def _get_tile_configs(self, fast_mode=False): + """Get tile configurations for the current datatype and layout""" + if "tile_configs" in self.config: + # Old format + return ( + self.config["tile_configs"].get(self.datatype, {}).get(self.layout, []) + ) + elif "tile_config" in self.config: + # New format - generate combinations from individual parameter values + tile_config = self.config["tile_config"] + + # Get all possible values for each parameter + tile_m_values = tile_config.get("tile_m", {}).get("values", [256]) + tile_n_values = tile_config.get("tile_n", {}).get("values", [256]) + tile_k_values = tile_config.get("tile_k", {}).get("values", [32]) + warp_m_values = tile_config.get("warp_m", {}).get("values", [2]) + warp_n_values = tile_config.get("warp_n", {}).get("values", [2]) + warp_k_values = tile_config.get("warp_k", {}).get("values", [1]) + warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32]) + warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32]) + warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32]) + + # Generate all combinations + configs = [] + for tile_m in tile_m_values: + for tile_n in tile_n_values: + for tile_k in tile_k_values: + for warp_m in warp_m_values: + for warp_n in warp_n_values: + for warp_k in warp_k_values: + for warp_tile_m in warp_tile_m_values: + for warp_tile_n in warp_tile_n_values: + for warp_tile_k in warp_tile_k_values: + # Validate configuration + if self._validate_tile_config( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + fast_mode=fast_mode, + ): + configs.append( + { + "tile_m": tile_m, + "tile_n": tile_n, + "tile_k": tile_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "warp_tile_m": warp_tile_m, + "warp_tile_n": warp_tile_n, + "warp_tile_k": warp_tile_k, + } + ) + return configs + else: + # Fallback to default + return [] + + def _validate_tile_config( + self, + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + pipeline="mem", # Default pipeline for validation + fast_mode=False, # Add fast mode option + ): + """Validate that tile configuration is reasonable""" + if fast_mode: + # Fast validation for listing - only basic sanity checks + if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: + return False + if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: + return False + if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: + return False + + # Basic divisibility check + if tile_m % (warp_m * warp_tile_m) != 0: + return False + if tile_n % (warp_n * warp_tile_n) != 0: + return False + if tile_k % (warp_k * warp_tile_k) != 0: + return False + + return True + else: + # Full validation for generation + # Determine data types for validation + a_datatype = self.datatype + b_datatype = self.datatype + c_datatype = self.datatype + + # Special handling for certain data types + if self.datatype in ["fp8", "bf8"]: + c_datatype = "fp16" + + # Use the comprehensive validation function + return is_tile_config_valid( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + a_datatype, + b_datatype, + c_datatype, + pipeline, + ) + + def _generate_trait_combinations(self): + """Generate all combinations of traits""" + if "trait_config" in self.config: + # New format + trait_config = self.config["trait_config"] + + pipelines = trait_config.get("pipeline", {}).get("values", ["mem"]) + epilogues = trait_config.get("epilogue", {}).get("values", ["default"]) + schedulers = trait_config.get("scheduler", {}).get("values", ["intrawave"]) + pad_m_values = trait_config.get("pad_m", {}).get("values", [False]) + pad_n_values = trait_config.get("pad_n", {}).get("values", [False]) + pad_k_values = trait_config.get("pad_k", {}).get("values", [False]) + persistent_values = trait_config.get("persistent", {}).get( + "values", [False] + ) + reduction_strategy_value = trait_config.get("reduction_strategy", {}).get( + "values", ["reduction"] + ) + + all_combinations = list( + itertools.product( + pipelines, + epilogues, + schedulers, + reduction_strategy_value, + pad_m_values, + pad_n_values, + pad_k_values, + persistent_values, + ) + ) + + # Filter out unsupported trait combinations + combinations = [] + for combo in all_combinations: + pipeline, epilogue, scheduler, reduction_strategy = combo[:4] + if is_trait_combination_valid( + pipeline, epilogue, scheduler, reduction_strategy + ): + combinations.append(combo) + else: + logging.debug( + f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}-{reduction_strategy}" + ) + else: + # Fallback to minimal default + combinations = [ + ( + "compv3", + "cshuffle", + "intrawave", + "reduction_strategy", + False, + False, + False, + False, + ) + ] + + return combinations + + def _get_dtype_string(self): + """Get C++ type string for datatype""" + dtype_map = { + "fp16": "ck_tile::fp16_t", + "fp8": "ck_tile::fp8_t", + "bf16": "ck_tile::bf16_t", + "fp32": "float", + "fp64": "double", + } + return dtype_map.get(self.datatype, "float") + + _LAYOUT_MAP = { + "r": "ck_tile::tensor_layout::gemm::RowMajor", + "c": "ck_tile::tensor_layout::gemm::ColumnMajor", + } + + def _get_abc_layouts(self, layout_code: Optional[str] = None): + """ + Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'. + If layout_code is None, use self.layout. + """ + if layout_code is None: + # fall back to the instance field + layout_code = getattr(self, "layout", "") + + code = str(layout_code).strip().lower() + + if len(code) != 3 or any(ch not in self._LAYOUT_MAP for ch in code): + raise ValueError( + f"Invalid layout '{layout_code}'. " + "Use a 3-letter code with 'r'/'c' (e.g., rcr, ccr, crr, rrr)." + ) + + a_layout = self._LAYOUT_MAP[code[0]] + b_layout = self._LAYOUT_MAP[code[1]] + c_layout = self._LAYOUT_MAP[code[2]] + return a_layout, b_layout, c_layout + + def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True): + """Generate a single kernel instance""" + ( + pipeline, + epilogue, + scheduler, + reduction_strategy, + pad_m, + pad_n, + pad_k, + persistent, + ) = trait_combo + + # Create kernel name with proper boolean capitalization + kernel_name = f"{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{reduction_strategy}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}" + + # Create tile configuration string + tile_str = ( + f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + ) + tile_str += ( + f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + ) + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + kernel_name += f"_{tile_str}" + + # Map pipeline names to the correct pipeline implementation + pipeline_impl_map = { + "mem": "ck_tile::GemmPipelineAgBgCrMem", + "compv3": "ck_tile::GemmPipelineAgBgCrCompV3", + "compv4": "ck_tile::GemmPipelineAgBgCrCompV4", + } + + reduction_strategy_map = { + "atomic": "ck_tile::StreamKReductionStrategy::Atomic", + "reduction": "ck_tile::StreamKReductionStrategy::Reduction", + } + + # Determine accumulator type based on datatype + acc_type = "float" + if self.datatype in ["int8", "int4"]: + acc_type = "ck_tile::int32_t" + + # Determine output type + c_type = self._get_dtype_string() + if self.datatype in ["fp8", "bf8"]: + c_type = "ck_tile::fp16_t" + + # Determine layouts based on self.layout + a_layout, b_layout, c_layout = self._get_abc_layouts() + + # Generate kernel instance code using the correct API + pragma_line = "#pragma once\n" if is_header else "" + instance_code = f"""// Generated kernel instance for {kernel_name} +{pragma_line} +#include +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" +#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" + +using ADataType = {self._get_dtype_string()}; +using BDataType = {self._get_dtype_string()}; +using AccDataType = {acc_type}; +using CDataType = {c_type}; + +using ALayout = {a_layout}; +using BLayout = {b_layout}; +using CLayout = {c_layout}; + +// Kernel name for display +constexpr const char* KERNEL_NAME = "{kernel_name}"; + +// Wrapper for simplified launch interface +struct SelectedKernel {{ + // Tile configuration + static constexpr ck_tile::index_t BlockSize = 256; + static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]}; + static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]}; + static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]}; + static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]}; + static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]}; + static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]}; + static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]}; + static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]}; + static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]}; + + // Traits + static constexpr bool kPadM = {"true" if pad_m == "true" else "false"}; + static constexpr bool kPadN = {"true" if pad_n == "true" else "false"}; + static constexpr bool kPadK = {"true" if pad_k == "true" else "false"}; + static constexpr bool Preshuffle = false; + + static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"}; + static constexpr int kBlockPerCu = 1; + static constexpr bool StructuredSparsity = false; + static constexpr bool NumWaveGroup = 1; + + static constexpr bool TransposeC = false; + static constexpr bool UsePersistentKernel = {"true" if str(persistent).lower() == "true" else "false"}; + static constexpr bool UseStructuredSparsity = false; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr ck_tile::StreamKReductionStrategy reduction_strategy = {reduction_strategy_map.get(reduction_strategy, "ck_tile::StreamKReductionStrategy::Reduction")}; + + // Tile shape + using TileShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + // Tile partitioner + using TilePartitioner = ck_tile::StreamKTilePartitioner; + + // Traits + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + // Pipeline problem + using GemmPipelineProblem = ck_tile::GemmPipelineProblem< + ADataType, + BDataType, + AccDataType, + TileShape, + GemmUniversalTraits>; + + static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{ + const auto Run = [&](const auto memory_operation_) {{ + constexpr auto memory_operation = memory_operation_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::GemmPipelineAgBgCrCompV3")}; + + // Epilogue + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem< + ADataType, + BDataType, + ck_tile::tuple<>, // DsDataType + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, // kM_ + TilePartitioner::NPerBlock, // kN_ + WarpPerBlock_M, // MWave_ + WarpPerBlock_N, // NWave_ + WarpTileM, // MPerXdl_ + WarpTileN, // NPerXdl_ + WarpTileK, // KPerXdl_ + TransposeC, // isCTransposed_ + memory_operation, // MemoryOperation_ + NumWaveGroups>; // kNumWaveGroups_ + + using GemmEpilogue = ck_tile::CShuffleEpilogue; + + // Kernel type + using GemmKernel = ck_tile::StreamKKernel; + + // Make kernel arguments + auto kargs = GemmKernel::MakeKernelArgs(args); + const auto workspace_size = GemmKernel::GetWorkSpaceSize(kargs); + ck_tile::DeviceMem workspace_data(workspace_size); + workspace_data.SetZero(); + kargs.workspace_ptr = workspace_data.GetDeviceBuffer(); + + if (!GemmKernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); + }} + + // Get grid and block sizes + const dim3 grids = GemmKernel::GridSize(kargs.tile_partitioner); + const dim3 blocks = GemmKernel::BlockSize(); + + if(stream.log_level_ > 0) {{ + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << "\\n" + << "shape: " << TileShape::GetName() << "\\n" + << "problem: " << UniversalGemmProblem::GetName() << "\\n" + << "pipeline: " << GemmPipeline::GetName() << "\\n" + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + }} + + auto reset_data_buffers = [&]() {{ + if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic) + {{ + // Clear the output C tensor results after each repetition of the kernel + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_)); + }} + else if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction) + {{ + // Reset sk flags to zero before each repetition of the kernel + workspace_data.SetZero(); + }} + }}; + + + // Launch kernel + float ave_time = ck_tile::launch_kernel_time_mask( + stream, + reset_data_buffers, + ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); + return ave_time; + + // ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile(); + // return std::make_tuple(ave_time, num_wgs_per_tile); + }}; + + + if constexpr(ck_tile::StreamKReductionStrategy::Atomic == reduction_strategy) + {{ + return Run(ck_tile::integral_constant{{}}); + }} + else // We are using ck_tile::StreamKReductionStrategy::Reduction + {{ + return Run(ck_tile::integral_constant{{}}); + }} + }} +}}; +""" + + return kernel_name, instance_code + + def generate_individual(self, num_workers=None): + """Generate individual kernel files for separate compilation with parallel processing""" + if num_workers is None: + num_workers = min( + multiprocessing.cpu_count(), 8 + ) # Limit to avoid memory issues + + tile_configs = self._get_tile_configs() + trait_combos = self._generate_trait_combinations() + + # Prepare work items for parallel processing + work_items = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + work_items.append( + ( + tile_config, + trait_combo, + self.working_path, + self.datatype, + self.layout, + ) + ) + + print( + f"Generating {len(work_items)} individual kernel files using {num_workers} workers..." + ) + print(f" Tile configs: {len(tile_configs)}") + print(f" Trait combinations: {len(trait_combos)}") + print(f" Total kernels: {len(work_items)}") + + # Show first few work items for debugging + if work_items: + print(" First work item example:") + tile_config, trait_combo = work_items[0][:2] + print(f" Tile config: {tile_config}") + print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits + + # Process work items in parallel + kernel_list = [] + completed = 0 + + with concurrent.futures.ProcessPoolExecutor( + max_workers=num_workers + ) as executor: + # Submit all work items + print(f" Submitting {len(work_items)} tasks to executor...") + future_to_item = { + executor.submit(_generate_single_kernel_individual, item): item + for item in work_items + } + print(" All tasks submitted, waiting for completion...") + + # Collect results with progress reporting + for future in concurrent.futures.as_completed(future_to_item): + completed += 1 + if completed % 100 == 0 or completed == len(work_items): + print( + f" Progress: {completed}/{len(work_items)} kernels generated" + ) + + try: + result = future.result() + if result: + kernel_list.append(result) + except Exception as exc: + item = future_to_item[future] + print(f"Kernel generation failed for {item}: {exc}") + + # Sort kernel list for consistent ordering + kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name + + # Generate CMake include file for individual targets + self._generate_cmake_individual_targets(kernel_list) + + print( + f"Generated {len(kernel_list)} individual kernel files in {self.working_path}" + ) + + def _generate_cmake_individual_targets(self, kernel_list): + """Generate CMake include file that creates individual targets""" + cmake_code = f"""# Generated CMake file for individual GEMM targets +# Datatype: {self.datatype}, Layout: {self.layout} + +""" + + for kernel_name, trait_combo, tile_config in kernel_list: + pipeline, epilogue, scheduler = trait_combo[:3] + + # Format tile config for CMake function + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + trait_str = f"{pipeline}_{epilogue}_{scheduler}_" + "_".join( + str(x) for x in trait_combo[3:] + ) + + cmake_code += f'create_individual_gemm_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n' + + # Write CMake include file + with open(self.working_path / "gemm_individual_targets.cmake", "w") as f: + f.write(cmake_code) + + def write_kernel_list(self): + """Write kernel list to file for CMake to read (with comprehensive validation)""" + # Get configurations using comprehensive validation + tile_configs = self._get_tile_configs(fast_mode=False) + trait_combos = self._generate_trait_combinations() + + kernel_list = [] + for tile_config in tile_configs: + for trait_combo in trait_combos: + ( + pipeline, + epilogue, + scheduler, + pad_m, + pad_n, + pad_k, + persistent, + reduction_strategy, + ) = trait_combo + + # Create kernel name with proper boolean capitalization + kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}_{reduction_strategy}" + + # Create tile configuration string + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + kernel_name += f"_{tile_str}" + + kernel_list.append( + { + "name": kernel_name, + "tile_config": tile_config, + "trait_combo": trait_combo, + } + ) + + # Write kernel count + with open(self.working_path / "gemm_kernel_count.txt", "w") as f: + f.write(str(len(kernel_list))) + + # Write kernel list + with open(self.working_path / "gemm_kernel_list.txt", "w") as f: + for kernel in kernel_list: + # Format: kernel_name|tile_config|trait_combo + tile_config = kernel["tile_config"] + trait_combo = kernel["trait_combo"] + + tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_" + tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_" + tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}" + + trait_str = ( + f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_" + + "_".join(str(x) for x in trait_combo[3:]) + ) + + f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n") + + print(f"Listed {len(kernel_list)} kernel configurations") + + def run(self, num_workers=None): + """Run the builder to generate individual kernel files""" + # Generate individual kernel files + self.generate_individual(num_workers) + + +def _generate_single_kernel_individual(work_item): + """Worker function to generate a single individual kernel file""" + tile_config, trait_combo, working_path, datatype, layout = work_item + + # Create a temporary builder instance for this worker + builder = GemmKernelBuilder(working_path, datatype, layout) + + try: + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo + ) + + # Create simplified filename without the "gemm_" prefix + # Remove "gemm_" from the beginning of kernel_name for the filename + simplified_name = kernel_name + if simplified_name.startswith("gemm_"): + simplified_name = simplified_name[5:] # Remove "gemm_" prefix + + # Write individual header file + header_file = working_path / f"gemm_streamk_single_{simplified_name}.hpp" + with open(header_file, "w") as f: + f.write(instance_code) + + return (kernel_name, trait_combo, tile_config) + except Exception as e: + print(f"Error generating individual kernel: {e}") + return None + + +def main(): + parser = argparse.ArgumentParser( + description="GEMM kernel instance builder with parallel support" + ) + parser.add_argument("--working_path", required=True, help="Working directory path") + parser.add_argument( + "--datatype", + required=True, + choices=["fp16", "fp8", "bf16", "fp32", "fp64"], + help="Data type", + ) + parser.add_argument( + "--layout", + required=True, + choices=["rcr", "rrr", "ccr", "crr"], + help="Matrix layout", + ) + parser.add_argument("--config_json", help="Configuration JSON file") + parser.add_argument( + "--num_workers", type=int, help="Number of parallel workers (default: auto)" + ) + parser.add_argument( + "--gen_individual", action="store_true", help="Generate individual kernel files" + ) + parser.add_argument( + "--gen_single", action="store_true", help="Generate a single kernel file" + ) + parser.add_argument("--kernel_name", help="Kernel name for single generation") + parser.add_argument( + "--tile_config", help="Tile configuration string for single generation" + ) + parser.add_argument( + "--trait_combo", help="Trait combination string for single generation" + ) + parser.add_argument( + "--list_kernels", + action="store_true", + help="List kernel configurations without generating files", + ) + + args = parser.parse_args() + + # Create builder + builder = GemmKernelBuilder( + args.working_path, args.datatype, args.layout, args.config_json + ) + + if args.list_kernels: + # Fast listing mode - just write kernel list without generating files + builder.write_kernel_list() + elif args.gen_single: + # Generate a single kernel file + if not args.kernel_name or not args.tile_config or not args.trait_combo: + parser.error( + "--gen_single requires --kernel_name, --tile_config, and --trait_combo" + ) + + # Parse tile config + tile_parts = args.tile_config.split("_") + tile_dims = tile_parts[0].split("x") + warp_dims = tile_parts[1].split("x") + warp_tile_dims = tile_parts[2].split("x") + + tile_config = { + "tile_m": int(tile_dims[0]), + "tile_n": int(tile_dims[1]), + "tile_k": int(tile_dims[2]), + "warp_m": int(warp_dims[0]), + "warp_n": int(warp_dims[1]), + "warp_k": int(warp_dims[2]), + "warp_tile_m": int(warp_tile_dims[0]), + "warp_tile_n": int(warp_tile_dims[1]), + "warp_tile_k": int(warp_tile_dims[2]), + } + + # Parse trait combo + trait_parts = args.trait_combo.split("_") + trait_combo = ( + trait_parts[0], # pipeline + trait_parts[1], # epilogue + trait_parts[2], # scheduler + trait_parts[3], # reduction_strategy + trait_parts[4] == "false", # pad_m + trait_parts[5] == "false", # pad_n + trait_parts[6] == "false", # pad_k + trait_parts[7], # persistent + ) + + # Generate the kernel + kernel_name, instance_code = builder._generate_kernel_instance( + tile_config, trait_combo + ) + + # Write the file + simplified_name = kernel_name + if simplified_name.startswith("gemm_"): + simplified_name = simplified_name[5:] + + header_file = ( + builder.working_path / f"gemm_streamk_single_{simplified_name}.hpp" + ) + with open(header_file, "w") as f: + f.write(instance_code) + + print(f"Generated {header_file}") + + elif args.gen_individual: + # Generate all individual kernel files + builder.run(args.num_workers) + else: + parser.error( + "Must specify one of: --list_kernels, --gen_individual, or --gen_single" + ) + + +if __name__ == "__main__": + main() diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp b/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp new file mode 100644 index 0000000000..256e0b9ca4 --- /dev/null +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp @@ -0,0 +1,296 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck_tile/host/device_prop.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "gemm_streamk_benchmark.hpp" + +class GemmProfiler +{ + public: + static GemmProfiler& instance(Setting setting) + { + static GemmProfiler instance{setting}; + return instance; + } + + // Overload for single kernel benchmarking + void benchmark(GemmProblem& gemm_problem, + std::function kernel_func) + { + // Create a vector with a single callable that returns both name and time + std::vector(ck_tile::StreamKHostArgs&, + const ck_tile::stream_config&)>> + callables; + + callables.push_back( + [kernel_func](ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) { + float time = kernel_func(args, stream); + return std::make_tuple(std::string(KERNEL_NAME), time); + }); + + benchmark(gemm_problem, callables); + } + + void benchmark(GemmProblem& gemm_problem, + std::vector( + ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>& callables) + { + const ALayout layout_a = ALayout{}; + const BLayout layout_b = BLayout{}; + const CLayout layout_c = CLayout{}; + + gemm_problem.stride_a_ = ck_tile::get_default_stride( + gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a)); + gemm_problem.stride_b_ = ck_tile::get_default_stride( + gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b)); + gemm_problem.stride_c_ = ck_tile::get_default_stride( + gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)); + + ck_tile::HostTensor a_m_k(ck_tile::host_tensor_descriptor( + gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a))); + ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( + gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b))); + ck_tile::HostTensor c_m_n_dev_result(ck_tile::host_tensor_descriptor( + gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c))); + + if(setting_.init_method_ == 0) + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + } + else if(setting_.init_method_ == 1) + { + ck_tile::FillMonotonicSeq{}(a_m_k); + ck_tile::FillMonotonicSeq{}(b_k_n); + } + else if(setting_.init_method_ == 2) + { + ck_tile::FillConstant{static_cast(1)}(a_m_k); + ck_tile::FillConstant{static_cast(1)}(b_k_n); + } + else + { + a_m_k.SetZero(); + b_k_n.SetZero(); + } + + if(gemm_problem.structured_sparsity_) + { + ck_tile::AdjustToStructuredSparsity{}(a_m_k); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor b_k_n_dev = b_k_n; + // permute_tensor_b(b_k_n_dev); + permute_vectors_i4x4_b(b_k_n_dev); + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + } + else + { + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + ck_tile::StreamKHostArgs gemm_args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + gemm_problem.m_, + gemm_problem.n_, + gemm_problem.k_, + gemm_problem.stride_a_, + gemm_problem.stride_b_, + gemm_problem.stride_c_}; + + ck_tile::HostTensor c_m_n_host_result(ck_tile::host_tensor_descriptor( + gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c))); + + if(setting_.verify_) + { + gemm_host_reference(setting_.verify_, + a_m_k, + b_k_n, + c_m_n_host_result, + a_m_k_dev_buf, + b_k_n_dev_buf, + gemm_problem.m_, + gemm_problem.n_, + gemm_problem.k_, + gemm_problem.stride_a_, + gemm_problem.stride_b_, + gemm_problem.stride_c_); + } + + for(auto& callable : callables) + { + auto kernel_run_result = callable(gemm_args, + ck_tile::stream_config{nullptr, + true, + setting_.log_, + setting_.n_warmup_, + setting_.n_repeat_, + setting_.is_gpu_timer_, + setting_.flush_cache_, + setting_.rotating_count_}); + process_result(gemm_problem, + c_m_n_dev_buf, + c_m_n_host_result, + c_m_n_dev_result, + kernel_run_result); + } + } + + void process_result(const GemmProblem& gemm_problem, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + const std::tuple& kernel_run_result) + { + auto [name, avg_time] = kernel_run_result; + auto dp_persistent = + SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel"; + auto reduction_strategy = + SelectedKernel::reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic + ? "Atomic" + : "Reduction"; + + KernelInstance kernel_instance{ + name, dp_persistent, reduction_strategy, gemm_problem, {-1.0f, -1.0f, -1.0f}}; + + // compute performance metric + std::size_t flop = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_; + std::size_t num_byte = sizeof(ADataType) * gemm_problem.m_ * gemm_problem.k_ + + sizeof(BDataType) * gemm_problem.n_ * gemm_problem.k_ + + sizeof(CDataType) * gemm_problem.m_ * gemm_problem.n_; + + // update + kernel_instance.perf_result_.latency_ = avg_time; + kernel_instance.perf_result_.tflops_ = static_cast(flop) / 1.E9 / avg_time; + kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time; + + if(setting_.log_ > 0 && !setting_.json_output_) + { + std::cout << kernel_instance << std::endl; + } + + // verify result + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool verified_correct = + !setting_.verify_ || + compare( + name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_host_result); + + if(verified_correct) + { + kernel_instances_.emplace_back(kernel_instance); + } + else + { + std::cout << "Verification failed, skip kernel: " << name << std::endl; + } + + // clear tensor + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + } + + KernelInstance select_best_instance(Metric metric) + { + if(kernel_instances_.empty()) + throw std::runtime_error("Empty instances"); + + auto kernel_instance = *std::max_element(kernel_instances_.begin(), + kernel_instances_.end(), + [metric](const auto& a, const auto& b) { + return PerformanceResult::compare( + b.perf_result_, a.perf_result_, metric); + }); + + if(setting_.json_output_) + { + // Output clean JSON only + std::cout << kernel_instance << std::endl; + } + else + { + std::cout << "**********************************" << std::endl; + std::cout << "According to given metrics: " << get_metric_name(metric) << "\n" + << "Current kernel performance is: " << kernel_instance << std::endl; + std::cout << "**********************************" << std::endl; + } + + if(!setting_.csv_filename_.empty()) + { + std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app); + + if(!file.is_open()) + { + std::cerr << "Warning: Failed to open CSV file for writing." << std::endl; + } + else + { + if(file.tellp() == 0) + { + file << "rocm_version,device_name," + << "split_k,m,n,k,stride_a,stride_b,stride_c," + << "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c," + << "structured_sparsity," << "dp_persistent," << "reduction_strategy," + << "name," << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n"; + } + + const auto& problem = kernel_instance.problem_; + const auto& name = kernel_instance.name_; + const auto& dp_persistent = kernel_instance.dp_persistent_; + const auto& reduction_strategy = kernel_instance.reduction_strategy_; + const auto& perf = kernel_instance.perf_result_; + + file << get_rocm_version() << "," << ck_tile::get_device_name() << "," + << problem.split_k_ << "," << problem.m_ << "," << problem.n_ << "," + << problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << "," + << problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_ + << "," << problem.dtype_acc_ << "," << problem.dtype_c_ << "," + << problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_ + << "," << problem.structured_sparsity_ << "," << dp_persistent << "," + << reduction_strategy << "," << name << "," << std::fixed + << std::setprecision(4) << perf.latency_ << "," << std::fixed + << std::setprecision(4) << perf.tflops_ << "," << std::fixed + << std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric) + << "\n"; + + if(!file) + { + std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl; + } + } + } + + return kernel_instance; + } + + GemmProfiler(const GemmProfiler&) = delete; + GemmProfiler& operator=(const GemmProfiler&) = delete; + + private: + ~GemmProfiler() { kernel_instances_.clear(); } + GemmProfiler(Setting setting) : setting_(setting) {} + + Setting setting_; + + std::vector kernel_instances_; +}; diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py b/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py new file mode 100644 index 0000000000..2288d7752f --- /dev/null +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python +# SPDX-License-Identifier: MIT +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +""" +Validation utilities for GEMM kernel generation. +Extracted from tile_engine_develop for consistency. +""" + +import subprocess +import re +from functools import lru_cache +import logging +from typing import Tuple, List + +# Element size mapping for different data types +ELEMENT_SIZE_MAP = { + "fp16": 2, + "bf16": 2, + "int8": 1, + "fp8": 1, + "bf8": 1, + "int4": 0.5, + "int32": 4, + "fp32": 4, + "fp64": 8, +} + +# Supported warp tile combinations for different GPU architectures and data types +WARP_TILE_SUPPORTED_COMBINATIONS = { + "gfx90a": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]], + }, + "gfx942": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]], + "int8_int8_int32": [[16, 16, 32], [32, 32, 16]], + }, + "gfx950": { + "fp16_fp16_fp16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "bf16_bf16_bf16": [ + [32, 32, 8], + [16, 16, 16], + [32, 32, 16], + [16, 16, 32], + [4, 64, 16], + [64, 4, 16], + ], + "fp8_fp8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 32], + [16, 16, 64], + [16, 16, 128], + [32, 32, 64], + ], + "bf8_bf8_fp16": [ + [32, 32, 16], + [32, 32, 32], + [16, 16, 64], + [16, 16, 32], + [16, 16, 128], + [32, 32, 64], + ], + }, +} + +# Unsupported trait combinations +TRAIT_UNSUPPORTED_COMBINATIONS = { + ("compv3", "cshuffle", "interwave", "reduction"), + ("compv3", "default", "interwave", "reduction"), + ("compv3", "cshuffle", "interwave", "atomic"), + ("compv3", "default", "interwave", "atomic"), +} + + +def element_size(data_type: str) -> float: + """Calculate the size (in bytes) of a single element for given data type.""" + data_type = data_type.lower() + if data_type not in ELEMENT_SIZE_MAP: + raise ValueError(f"Unsupported data type: {data_type}") + return ELEMENT_SIZE_MAP[data_type] + + +GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)") + + +@lru_cache(maxsize=1) +def get_gpu_name_by_id(gpu_id: int = 0) -> str: + """Retrieve GPU name (e.g. gfx90a) by device ID""" + try: + output = subprocess.check_output( + ["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5 + ) + if matches := GPU_NAME_PATTERN.finditer(output): + gpu_list = [m.group(1) for m in matches] + return gpu_list[gpu_id] if gpu_id < len(gpu_list) else "" + + return "" + + except subprocess.CalledProcessError as e: + logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}") + except FileNotFoundError: + logging.debug("ROCm tools not installed (requires rocminfo)") + except subprocess.TimeoutExpired: + logging.debug("GPU query timeout (5s)") + except Exception as e: + logging.debug(f"GPU detection error: {str(e)}") + + return "" + + +def is_trait_combination_valid( + pipeline: str, epilogue: str, scheduler: str, reduction_strategy: str +) -> bool: + """Check if a trait combination is valid.""" + return ( + pipeline, + epilogue, + scheduler, + reduction_strategy, + ) not in TRAIT_UNSUPPORTED_COMBINATIONS + + +def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool: + """Validate warp configuration.""" + return (warp_m, warp_n, warp_k) in [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + + +def validate_dimension_alignment( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, +) -> Tuple[bool, List[str]]: + """Check if tile dimensions are properly aligned with warp dimensions.""" + alignment_issues = [] + + if tile_m % (warp_m * warp_tile_m) != 0: + alignment_issues.append( + f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}" + ) + if tile_n % (warp_n * warp_tile_n) != 0: + alignment_issues.append( + f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}" + ) + if tile_k % (warp_k * warp_tile_k) != 0: + alignment_issues.append( + f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}" + ) + + return len(alignment_issues) == 0, alignment_issues + + +def validate_lds_capacity( + tile_m: int, + tile_n: int, + tile_k: int, + a_datatype: str, + b_datatype: str, + pipeline: str, +) -> Tuple[bool, str]: + """Validate LDS capacity requirements.""" + matrix_a_size = (tile_m * tile_k) * element_size(a_datatype) + matrix_b_size = (tile_n * tile_k) * element_size(b_datatype) + total_tile_in_lds = matrix_a_size + matrix_b_size + + max_tile_size = 2**15 if pipeline == "compv4" else 2**16 + + if total_tile_in_lds > max_tile_size: + error_msg = ( + f"LDS capacity exceeded: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > " + f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n" + f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" + f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {matrix_b_size:,}B" + ) + return False, error_msg + + return True, "" + + +def validate_warp_tile_combination( + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + gpu_name: str = None, +) -> Tuple[bool, str]: + """Validate warp tile combination against GPU-specific supported combinations.""" + if gpu_name is None: + gpu_name = get_gpu_name_by_id(0) + + # Construct the key for looking up supported combinations + warp_tile_key = f"{a_datatype}_{b_datatype}_{c_datatype}" + current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] + + # Check if we have GPU-specific combinations + gpu_warp_tile_combinations = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {}) + if not gpu_warp_tile_combinations: + # If GPU not recognized, try to be permissive but log warning + logging.warning(f"No warp tile combinations found for GPU: {gpu_name}") + return True, "" + + # Check if we have combinations for this data type combination + allowed_combinations = gpu_warp_tile_combinations.get(warp_tile_key, []) + if not allowed_combinations: + # For data type combinations not in the list, be permissive + logging.debug( + f"No warp tile combinations found for data types: {warp_tile_key}" + ) + return True, "" + + # Check if current combination is in the allowed list + if current_combination not in allowed_combinations: + error_msg = ( + f"Invalid warp tile combination: {current_combination} not in allowed list. " + f"Valid combinations for '{warp_tile_key}' on {gpu_name}: {allowed_combinations}" + ) + return False, error_msg + + return True, "" + + +def is_tile_config_valid( + tile_m: int, + tile_n: int, + tile_k: int, + warp_m: int, + warp_n: int, + warp_k: int, + warp_tile_m: int, + warp_tile_n: int, + warp_tile_k: int, + a_datatype: str, + b_datatype: str, + c_datatype: str, + pipeline: str, + trait_name: str = None, +) -> bool: + """ + Comprehensive tile configuration validation. + Returns True if configuration is valid, False otherwise. + """ + # Basic sanity checks + if tile_m <= 0 or tile_n <= 0 or tile_k <= 0: + return False + if warp_m <= 0 or warp_n <= 0 or warp_k <= 0: + return False + if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0: + return False + + # Check that warp tiles fit within block tiles + if warp_m * warp_tile_m > tile_m: + return False + if warp_n * warp_tile_n > tile_n: + return False + if warp_k * warp_tile_k > tile_k: + return False + + # Validate warp configuration + if not validate_warp_configuration(warp_m, warp_n, warp_k): + logging.debug( + f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})" + ) + return False + + # Validate dimension alignment + is_aligned, alignment_issues = validate_dimension_alignment( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_k, + warp_tile_m, + warp_tile_n, + warp_tile_k, + ) + if not is_aligned: + logging.debug( + f"Dimension alignment failed: {', '.join(alignment_issues)}. " + f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " + f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + ) + return False + + # Validate LDS capacity + lds_valid, lds_error = validate_lds_capacity( + tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline + ) + if not lds_valid: + logging.debug(f"LDS validation failed: {lds_error}") + return False + + # Validate warp tile combination + warp_tile_valid, warp_tile_error = validate_warp_tile_combination( + warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype + ) + if not warp_tile_valid: + logging.debug(f"Warp tile validation failed: {warp_tile_error}") + return False + + return True diff --git a/tutorial/CMakeLists.txt b/tutorial/CMakeLists.txt index a2f35ca53f..9be59919ad 100644 --- a/tutorial/CMakeLists.txt +++ b/tutorial/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + include_directories(BEFORE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/library/include diff --git a/tutorial/ck_tile/00_copy_kernel/CMakeLists.txt b/tutorial/ck_tile/00_copy_kernel/CMakeLists.txt index 91dd036eff..2374eae5c9 100644 --- a/tutorial/ck_tile/00_copy_kernel/CMakeLists.txt +++ b/tutorial/ck_tile/00_copy_kernel/CMakeLists.txt @@ -1,4 +1,7 @@ -add_executable(tile_tutorial_copy_kernel EXCLUDE_FROM_ALL copy_basic.cpp) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_executable(tile_tutorial_copy_kernel copy_basic.cpp) # Impact: This flag ensures that the compiler doesn't make # assumptions about memory aliasing that could interfere with Composable Kernel's explicit memory access patterns. diff --git a/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt b/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt index e16977921a..1954980532 100644 --- a/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt +++ b/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt @@ -1,4 +1,7 @@ -add_executable(tile_tutorial_naive_gemm EXCLUDE_FROM_ALL practice_gemm.cpp) +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_executable(tile_tutorial_naive_gemm practice_gemm.cpp) target_compile_options(tile_tutorial_naive_gemm PRIVATE -mllvm -enable-noalias-to-md-conversion=0 diff --git a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp index 9bb1961cce..dd72f08d99 100644 --- a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp +++ b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp @@ -28,9 +28,9 @@ struct PracticeGemmHostPipeline { // Size of the entire problem - const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K - const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N - const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K + const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K + const auto N = c_dram_ref.get_tensor_descriptor().get_length(number<1>{}); // M x N + const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K // Size of the block tile const auto MPerBlock = BlockTile::at(number<0>{}); @@ -83,7 +83,7 @@ struct PracticeGemmHostPipeline __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()]; const auto c_block_tile = block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char); - auto c_window = make_tile_window(c_dram, + auto c_window = make_tile_window(c_dram_ref, make_tuple(number{}, number{}), {tile_origin_m, tile_origin_n}); store_tile(c_window, c_block_tile); diff --git a/tutorial/ck_tile/CMakeLists.txt b/tutorial/ck_tile/CMakeLists.txt index 9895f5a71d..f9073acffc 100644 --- a/tutorial/ck_tile/CMakeLists.txt +++ b/tutorial/ck_tile/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + include_directories(AFTER ${CMAKE_CURRENT_LIST_DIR} )