mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Merge remote-tracking branch 'origin/develop' into ck_tile/fa_train
This commit is contained in:
9
.github/CODEOWNERS
vendored
9
.github/CODEOWNERS
vendored
@@ -1,7 +1,8 @@
|
||||
* @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex
|
||||
# Documentation files
|
||||
docs/* @ROCm/rocm-documentation
|
||||
*.md @ROCm/rocm-documentation
|
||||
*.rst @ROCm/rocm-documentation
|
||||
docs/* @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex
|
||||
*.md @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex
|
||||
*.rst @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex
|
||||
.readthedocs.yaml @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex
|
||||
# Header directory for Doxygen documentation
|
||||
library/include/* @ROCm/rocm-documentation
|
||||
library/include/* @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex
|
||||
|
||||
@@ -15,4 +15,4 @@ python:
|
||||
build:
|
||||
os: ubuntu-22.04
|
||||
tools:
|
||||
python: "3.8"
|
||||
python: "3.10"
|
||||
|
||||
@@ -26,7 +26,7 @@ set(version 1.1.0)
|
||||
project(composable_kernel VERSION ${version} LANGUAGES CXX)
|
||||
include(CTest)
|
||||
|
||||
find_package(Python3 3.8 COMPONENTS Interpreter REQUIRED)
|
||||
find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||
|
||||
@@ -202,7 +202,7 @@ endif()
|
||||
|
||||
|
||||
option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
|
||||
option(USE_OPT_NAVI3X "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF)
|
||||
option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
|
||||
@@ -210,10 +210,10 @@ if(USE_BITINT_EXTENSION_INT4)
|
||||
message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
|
||||
endif()
|
||||
|
||||
if(USE_OPT_NAVI3X)
|
||||
if(USE_OPT_GFX11)
|
||||
add_compile_options(-mcumode)
|
||||
add_compile_options(-mno-wavefrontsize64)
|
||||
message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}")
|
||||
message("CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}")
|
||||
endif()
|
||||
|
||||
## Threads
|
||||
|
||||
128
Jenkinsfile
vendored
128
Jenkinsfile
vendored
@@ -515,38 +515,33 @@ def Build_CK(Map conf=[:]){
|
||||
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
|
||||
timeout(time: 24, unit: 'HOURS')
|
||||
{
|
||||
//check whether running on Navi or MI300 node
|
||||
def navi_node = 0
|
||||
def mi300_node = 0
|
||||
//check whether to run performance tests on this node
|
||||
def do_perf_tests = 0
|
||||
sh 'rocminfo | tee rocminfo.log'
|
||||
if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') ){
|
||||
navi_node = 1
|
||||
echo "This is a Navi node"
|
||||
}
|
||||
if ( runShell('grep -n "gfx942" rocminfo.log') ){
|
||||
mi300_node = 1
|
||||
echo "This is MI300 node"
|
||||
if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') || runShell('grep -n "gfx942" rocminfo.log') ){
|
||||
do_perf_tests = 1
|
||||
echo "Stash profiler and run performance tests"
|
||||
}
|
||||
cmake_build(conf)
|
||||
dir("build"){
|
||||
//run tests and examples
|
||||
sh 'make -j check'
|
||||
if (params.RUN_PERFORMANCE_TESTS && navi_node == 0 && mi300_node == 0 ){
|
||||
if (params.RUN_PERFORMANCE_TESTS && do_perf_tests == 0 ){
|
||||
//we only need the ckProfiler to run the performance tests, so we pack and stash it
|
||||
//do not stash profiler on Navi or MI300 nodes
|
||||
sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler'
|
||||
stash name: "ckProfiler.tar.gz"
|
||||
//do not stash profiler on nodes where we don't need to run performance tests
|
||||
sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler'
|
||||
stash name: "ckProfiler.tar.gz"
|
||||
}
|
||||
if (params.RUN_FULL_QA && mi300_node == 0 ){
|
||||
// build deb packages for all MI100/200/300 targets and prepare to export
|
||||
sh 'make -j package'
|
||||
archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb'
|
||||
archiveArtifacts artifacts: 'composablekernel-tests_*.deb'
|
||||
sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb'
|
||||
stash name: "ckprofiler_0.2.0_amd64.deb"
|
||||
if (params.RUN_FULL_QA && do_perf_tests == 0 ){
|
||||
// build deb packages for all gfx9 targets and prepare to export
|
||||
sh 'make -j package'
|
||||
archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb'
|
||||
archiveArtifacts artifacts: 'composablekernel-tests_*.deb'
|
||||
sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb'
|
||||
stash name: "ckprofiler_0.2.0_amd64.deb"
|
||||
}
|
||||
}
|
||||
if (params.hipTensor_test && navi_node == 0 ){
|
||||
if (params.hipTensor_test && do_perf_tests == 0 ){
|
||||
//build and test hipTensor
|
||||
sh """#!/bin/bash
|
||||
rm -rf "${params.hipTensor_branch}".zip
|
||||
@@ -660,7 +655,8 @@ def process_results(Map conf=[:]){
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.1;COMPILER_VERSION=
|
||||
0 21 * * * % ROCMVERSION=6.1;COMPILER_VERSION=;COMPILER_COMMIT=
|
||||
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;COMPILER_COMMIT=;USE_SCCACHE=false
|
||||
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : ""
|
||||
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false''' : ""
|
||||
|
||||
pipeline {
|
||||
agent none
|
||||
@@ -727,6 +723,10 @@ pipeline {
|
||||
name: "RUN_CODEGEN_TESTS",
|
||||
defaultValue: true,
|
||||
description: "Run the codegen tests (default: ON)")
|
||||
booleanParam(
|
||||
name: "BUILD_INSTANCES_ONLY",
|
||||
defaultValue: false,
|
||||
description: "Test building instances for various architectures simultaneously (default: OFF)")
|
||||
}
|
||||
environment{
|
||||
dbuser = "${dbuser}"
|
||||
@@ -809,22 +809,22 @@ pipeline {
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Run Codegen Tests on MI100/MI200")
|
||||
stage("Run Codegen Tests on gfx90a")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_CODEGEN_TESTS.toBoolean() }
|
||||
}
|
||||
options { retry(2) }
|
||||
agent{ label rocmnode("gfx908 || gfx90a")}
|
||||
agent{ label rocmnode("gfx90a")}
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ cd ../codegen && rm -rf build && mkdir build && cd build && \
|
||||
cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx908;gfx90a" \
|
||||
-DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j check"""
|
||||
-D GPU_TARGETS="gfx90a" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j check"""
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -837,30 +837,30 @@ pipeline {
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Build CK and run Tests on MI100/MI200/MI300")
|
||||
stage("Build CK for all gfx9 targets")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_FULL_QA.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx908 || gfx90a") }
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \
|
||||
-DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \
|
||||
-DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert " \
|
||||
-DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK and run Tests on MI300")
|
||||
stage("Build CK and run Tests on gfx942")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
@@ -868,45 +868,65 @@ pipeline {
|
||||
}
|
||||
agent{ label rocmnode("gfx942") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx942" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK and run Tests on MI100/MI200")
|
||||
stage("Build CK and run Tests on gfx90a")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() }
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx908 || gfx90a") }
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx908;gfx90a" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK and run Tests on Navi21")
|
||||
stage("Build CK instances for different targets")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() }
|
||||
expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("navi21") }
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx90a;gfx1030;gfx1101" \
|
||||
-D INSTANCES_ONLY=ON \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j32 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK and run Tests on gfx1030")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx1030") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
@@ -920,13 +940,13 @@ pipeline {
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK and run Tests on Navi32")
|
||||
stage("Build CK and run Tests on gfx1101")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() }
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("navi32") }
|
||||
agent{ label rocmnode("gfx1101") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
@@ -947,27 +967,11 @@ pipeline {
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Run ckProfiler: gfx90*")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() }
|
||||
}
|
||||
options { retry(2) }
|
||||
agent{ label rocmnode("gfx908 || gfx90a")}
|
||||
environment{
|
||||
setup_args = """ -DGPU_TARGETS="gfx908;gfx90a" -DBUILD_DEV=On """
|
||||
}
|
||||
steps{
|
||||
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Run ckProfiler: gfx90a")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() }
|
||||
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() }
|
||||
}
|
||||
options { retry(2) }
|
||||
agent{ label rocmnode("gfx90a")}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iomanip>
|
||||
@@ -160,6 +160,10 @@ bool run_grouped_conv_bwd_weight(
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
SimpleDeviceMem workspace_dev(workspace_sz);
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
@@ -181,4 +181,3 @@ int main(int argc, char* argv[])
|
||||
{1, 1, 1} /*filter_dilations*/);
|
||||
return 0;
|
||||
}
|
||||
// MI100 Perf: 0.255178 ms, 1698.9 GB/s,
|
||||
|
||||
@@ -10,4 +10,7 @@ if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf
|
||||
|
||||
add_executable(client_gemm_bf16_i8_bf16 gemm_xdl_bf16_i8.cpp)
|
||||
target_link_libraries(client_gemm_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_gemm_multiply_bf16_i8_bf16 gemm_xdl_multiply_bf16_i8.cpp)
|
||||
target_link_libraries(client_gemm_multiply_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
|
||||
endif()
|
||||
@@ -38,19 +38,19 @@ using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using AsLayout = ck::Tuple<A0Layout>;
|
||||
using B0Layout = Col;
|
||||
using B0Layout = Row;
|
||||
using B1Layout = B0Layout;
|
||||
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<D0Layout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using Scales = ck::tensor_operation::element_wise::Scales;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = Scales;
|
||||
using BElementOp = Multiply;
|
||||
using CDEElementOp = AddFastGelu;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
@@ -36,7 +36,7 @@ using D0DataType = BF16;
|
||||
using DsDataType = ck::Tuple<D0DataType>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using A0Layout = Col;
|
||||
using A0Layout = Row;
|
||||
using AsLayout = ck::Tuple<A0Layout>;
|
||||
using B0Layout = Row;
|
||||
using B1Layout = B0Layout;
|
||||
@@ -45,12 +45,12 @@ using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<D0Layout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using Scales = ck::tensor_operation::element_wise::Scales;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = Scales;
|
||||
using BElementOp = Multiply;
|
||||
using CDEElementOp = Add;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
@@ -37,19 +37,19 @@ using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using AsLayout = ck::Tuple<A0Layout>;
|
||||
using B0Layout = Col;
|
||||
using B0Layout = Row;
|
||||
using B1Layout = B0Layout;
|
||||
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using Scales = ck::tensor_operation::element_wise::Scales;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = Scales;
|
||||
using BElementOp = Multiply;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
@@ -74,12 +74,12 @@ struct SimpleDeviceMem
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// GEMM shape
|
||||
ck::index_t M = 64;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 512;
|
||||
ck::index_t M = 4096;
|
||||
ck::index_t N = 768;
|
||||
ck::index_t K = 6144;
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = N;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
|
||||
if(argc == 1)
|
||||
@@ -37,19 +37,19 @@ using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using AsLayout = ck::Tuple<A0Layout>;
|
||||
using B0Layout = Col;
|
||||
using B0Layout = Row;
|
||||
using B1Layout = B0Layout;
|
||||
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using Scales = ck::tensor_operation::element_wise::Scales;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = Scales;
|
||||
using BElementOp = Multiply;
|
||||
using CDEElementOp = FastGelu;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
220
client_example/30_gemm_bf16Aint8B/gemm_xdl_multiply_bf16_i8.cpp
Normal file
220
client_example/30_gemm_bf16Aint8B/gemm_xdl_multiply_bf16_i8.cpp
Normal file
@@ -0,0 +1,220 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = BF16;
|
||||
using AsDataType = ck::Tuple<A0DataType>;
|
||||
using B0DataType = I8;
|
||||
using B1DataType = BF16;
|
||||
using BsDataType = ck::Tuple<B0DataType>;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = BF16;
|
||||
using DsDataType = ck::Tuple<B1DataType>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using AsLayout = ck::Tuple<A0Layout>;
|
||||
using B0Layout = Row;
|
||||
using B1Layout = B0Layout;
|
||||
using BsLayout = ck::Tuple<B0Layout>;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<B1Layout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = Multiply;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
// clang-format on
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// GEMM shape
|
||||
ck::index_t M = 4096;
|
||||
ck::index_t N = 768;
|
||||
ck::index_t K = 6144;
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
M = std::stoi(argv[1]);
|
||||
N = std::stoi(argv[2]);
|
||||
K = std::stoi(argv[3]);
|
||||
|
||||
StrideA = std::stoi(argv[4]);
|
||||
StrideB = std::stoi(argv[5]);
|
||||
StrideE = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideE\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_matrix_space_size =
|
||||
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
|
||||
using Layout = decltype(layout);
|
||||
|
||||
if constexpr(std::is_same<Layout, Row>::value)
|
||||
{
|
||||
return (nRow - 1) * stride + nCol;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (nCol - 1) * stride + nRow;
|
||||
}
|
||||
};
|
||||
|
||||
SimpleDeviceMem a0_device_buf(sizeof(A0DataType) *
|
||||
f_matrix_space_size(M, K, StrideA, A0Layout{}));
|
||||
SimpleDeviceMem b0_device_buf(sizeof(B0DataType) *
|
||||
f_matrix_space_size(K, N, StrideB, B0Layout{}));
|
||||
SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{}));
|
||||
SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{}));
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
constexpr ck::index_t NumATensor = 1;
|
||||
constexpr ck::index_t NumBTensor = 1;
|
||||
constexpr ck::index_t NumDTensor = 1;
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
Row,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
bool found = false;
|
||||
int best_op_id = -1;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, NumDTensor>{b1_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<ck::index_t, NumATensor>{StrideA},
|
||||
std::array<ck::index_t, NumBTensor>{StrideB},
|
||||
std::array<ck::index_t, NumDTensor>{0},
|
||||
StrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
found = true;
|
||||
best_op_id = i;
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
16
client_example/31_grouped_gemm_bf16Aint8B/CMakeLists.txt
Normal file
16
client_example/31_grouped_gemm_bf16Aint8B/CMakeLists.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
# Grouped GEMM client examples with bf16 A and int8 B.
# Built only when a gfx9 target is requested and DTYPES is either unset or
# lists both int8 and bf16.
if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf16") OR NOT DEFINED DTYPES))
    # One executable per source file; each links the prebuilt GEMM instance library.
    foreach(example_source
            grouped_gemm_bias_fastgelu_xdl_bf16_i8
            grouped_gemm_fastgelu_xdl_bf16_i8
            grouped_gemm_multiply_xdl_bf16_i8
            grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8
            grouped_gemm_xdl_bf16_i8)
        # grouped_gemm_<variant>_xdl_bf16_i8 -> client_grouped_gemm_<variant>_bf16_i8_bf16
        string(REPLACE "_xdl_bf16_i8" "_bf16_i8_bf16" example_target "client_${example_source}")
        add_executable(${example_target} ${example_source}.cpp)
        target_link_libraries(${example_target} PRIVATE composable_kernel::device_gemm_operations)
    endforeach()
    unset(example_target)
endif()
|
||||
@@ -38,19 +38,19 @@ using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using AsLayout = ck::Tuple<A0Layout>;
|
||||
using B0Layout = Col;
|
||||
using B0Layout = Row;
|
||||
using B1Layout = B0Layout;
|
||||
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<D0Layout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using Scales = ck::tensor_operation::element_wise::Scales;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = Scales;
|
||||
using BElementOp = Multiply;
|
||||
using CDEElementOp = AddFastGelu;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
@@ -15,6 +15,8 @@
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp"
|
||||
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
@@ -36,7 +38,7 @@ using D0DataType = BF16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using A0Layout = Col;
|
||||
using A0Layout = Row;
|
||||
using AsLayout = ck::Tuple<A0Layout>;
|
||||
using B0Layout = Row;
|
||||
using B1Layout = B0Layout;
|
||||
@@ -45,12 +47,12 @@ using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using Scales = ck::tensor_operation::element_wise::Scales;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = Scales;
|
||||
using BElementOp = Multiply;
|
||||
using CDEElementOp = FastGelu;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
@@ -0,0 +1,286 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp"
|
||||
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = BF16;
|
||||
using B0DataType = I8;
|
||||
using B1DataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = BF16;
|
||||
using DsDataType = ck::Tuple<B1DataType, D0DataType>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Row;
|
||||
using B1Layout = B0Layout;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<B0Layout, D0Layout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MultiplyAddFastGelu;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
// Run-time switches handed to run_grouped_gemm.
// NOTE(review): run_grouped_gemm in this example does not read any of these fields;
// the meanings below are inferred from the names only - confirm before relying on them.
struct ExecutionConfig final
{
    bool do_verification{true}; // presumably: compare device output with a reference
    int init_method{1};         // presumably: selects how input tensors are initialized
    bool time_kernel{false};    // presumably: enables kernel timing
    int k_batch{1};             // presumably: split-K batch count
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
int sum_of_m = 0;
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<SimpleDeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a0_tensors_device, b0_tensors_device, b1_tensors_device,
|
||||
d0_tensors_device, c_tensors_device;
|
||||
|
||||
a0_tensors_device.reserve(group_count);
|
||||
b0_tensors_device.reserve(group_count);
|
||||
b1_tensors_device.reserve(group_count);
|
||||
d0_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
sum_of_m += problem_size.Ms[i];
|
||||
}
|
||||
|
||||
constexpr ck::index_t NumDTensor = 2;
|
||||
|
||||
using GroupedGemmKernelArgument =
|
||||
ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<NumDTensor>;
|
||||
|
||||
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
|
||||
grouped_gemm_kernel_args_.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a0_tensors_device.emplace_back(std::make_unique<SimpleDeviceMem>(
|
||||
sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i]));
|
||||
|
||||
b0_tensors_device.emplace_back(std::make_unique<SimpleDeviceMem>(
|
||||
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
|
||||
|
||||
b1_tensors_device.emplace_back(
|
||||
std::make_unique<SimpleDeviceMem>(sizeof(B1DataType) * problem_size.Ns[i]));
|
||||
|
||||
c_tensors_device.emplace_back(std::make_unique<SimpleDeviceMem>(
|
||||
sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i]));
|
||||
|
||||
d0_tensors_device.emplace_back(
|
||||
std::make_unique<SimpleDeviceMem>(sizeof(D0DataType) * problem_size.Ns[i]));
|
||||
|
||||
gemm_descs.push_back({problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
problem_size.stride_Cs[i],
|
||||
{0, 0}});
|
||||
|
||||
grouped_gemm_kernel_args_.push_back(
|
||||
{a0_tensors_device[i]->GetDeviceBuffer(),
|
||||
b0_tensors_device[i]->GetDeviceBuffer(),
|
||||
{b1_tensors_device[i]->GetDeviceBuffer(), d0_tensors_device[i]->GetDeviceBuffer()},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
{0, 0},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmTileLoop<A0Layout,
|
||||
B0Layout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
A0DataType,
|
||||
B0DataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
bool found = false;
|
||||
int best_op_id = -1;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
|
||||
std::vector<const void*> p_As = {};
|
||||
std::vector<const void*> p_Bs = {};
|
||||
std::vector<std::array<const void*, NumDTensor>> p_Ds = {};
|
||||
std::vector<void*> p_Cs = {};
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
SimpleDeviceMem gemm_kernel_args_dev(
|
||||
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
|
||||
hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
|
||||
grouped_gemm_kernel_args_.data(),
|
||||
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_kernel_args_dev.GetDeviceBuffer());
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, 0, 20, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * sum_of_m * problem_size.Ns[0] * problem_size.Ks[0];
|
||||
|
||||
std::size_t num_btype = sizeof(A0DataType) * sum_of_m * problem_size.Ks[0] +
|
||||
sizeof(B0DataType) * problem_size.Ks[0] * problem_size.Ns[0] +
|
||||
sizeof(EDataType) * sum_of_m * problem_size.Ns[0];
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
found = true;
|
||||
best_op_id = i;
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(1 + rand() % 1024);
|
||||
problem_size.Ns.push_back(6144);
|
||||
problem_size.Ks.push_back(4096);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ns[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
|
||||
std::cout << " M = " << problem_size.Ms[i] << " N = " << problem_size.Ns[i] << " K "
|
||||
<< problem_size.Ks[i] << std::endl;
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
@@ -0,0 +1,281 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multply.hpp"
|
||||
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = BF16;
|
||||
using B0DataType = I8;
|
||||
using B1DataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = BF16;
|
||||
using D0DataType = BF16;
|
||||
using DsDataType = ck::Tuple<B1DataType>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Row;
|
||||
using B1Layout = B0Layout;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<B1Layout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = Multiply;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
// Run-time switches handed to run_grouped_gemm.
// NOTE(review): run_grouped_gemm in this example does not read any of these fields;
// the meanings below are inferred from the names only - confirm before relying on them.
struct ExecutionConfig final
{
    bool do_verification{true}; // presumably: compare device output with a reference
    int init_method{1};         // presumably: selects how input tensors are initialized
    bool time_kernel{false};    // presumably: enables kernel timing
    int k_batch{1};             // presumably: split-K batch count
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
int sum_of_m = 0;
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<SimpleDeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a0_tensors_device, b0_tensors_device, b1_tensors_device,
|
||||
c_tensors_device;
|
||||
|
||||
a0_tensors_device.reserve(group_count);
|
||||
b0_tensors_device.reserve(group_count);
|
||||
b1_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
sum_of_m += problem_size.Ms[i];
|
||||
}
|
||||
|
||||
constexpr ck::index_t NumDTensor = 1;
|
||||
|
||||
using GroupedGemmKernelArgument =
|
||||
ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<NumDTensor>;
|
||||
|
||||
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
|
||||
grouped_gemm_kernel_args_.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a0_tensors_device.emplace_back(std::make_unique<SimpleDeviceMem>(
|
||||
sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i]));
|
||||
|
||||
b0_tensors_device.emplace_back(std::make_unique<SimpleDeviceMem>(
|
||||
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
|
||||
|
||||
b1_tensors_device.emplace_back(
|
||||
std::make_unique<SimpleDeviceMem>(sizeof(B1DataType) * problem_size.Ns[i]));
|
||||
|
||||
c_tensors_device.emplace_back(std::make_unique<SimpleDeviceMem>(
|
||||
sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i]));
|
||||
|
||||
gemm_descs.push_back({problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
problem_size.stride_Cs[i],
|
||||
{0}});
|
||||
|
||||
grouped_gemm_kernel_args_.push_back({a0_tensors_device[i]->GetDeviceBuffer(),
|
||||
b0_tensors_device[i]->GetDeviceBuffer(),
|
||||
{b1_tensors_device[i]->GetDeviceBuffer()},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
{0},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmTileLoop<A0Layout,
|
||||
B0Layout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
A0DataType,
|
||||
B0DataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
bool found = false;
|
||||
int best_op_id = -1;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
|
||||
std::vector<const void*> p_As = {};
|
||||
std::vector<const void*> p_Bs = {};
|
||||
std::vector<std::array<const void*, NumDTensor>> p_Ds = {};
|
||||
std::vector<void*> p_Cs = {};
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
SimpleDeviceMem gemm_kernel_args_dev(
|
||||
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
|
||||
hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
|
||||
grouped_gemm_kernel_args_.data(),
|
||||
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_kernel_args_dev.GetDeviceBuffer());
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, 0, 20, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * sum_of_m * problem_size.Ns[0] * problem_size.Ks[0];
|
||||
|
||||
std::size_t num_btype = sizeof(A0DataType) * sum_of_m * problem_size.Ks[0] +
|
||||
sizeof(B0DataType) * problem_size.Ks[0] * problem_size.Ns[0] +
|
||||
sizeof(EDataType) * sum_of_m * problem_size.Ns[0];
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
found = true;
|
||||
best_op_id = i;
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(1 + rand() % 1024);
|
||||
problem_size.Ns.push_back(4096);
|
||||
problem_size.Ks.push_back(4096);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ns[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
|
||||
std::cout << " M = " << problem_size.Ms[i] << " N = " << problem_size.Ns[i] << " K "
|
||||
<< problem_size.Ks[i] << std::endl;
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
@@ -0,0 +1,287 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp"
|
||||
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = BF16;
|
||||
using AsDataType = ck::Tuple<A0DataType>;
|
||||
using B0DataType = I8;
|
||||
using B1DataType = BF16;
|
||||
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = BF16;
|
||||
using D0DataType = BF16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using AsLayout = ck::Tuple<A0Layout>;
|
||||
using B0Layout = Row;
|
||||
using B1Layout = B0Layout;
|
||||
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = Multiply;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
// Run-time switches handed to run_grouped_gemm.
// NOTE(review): run_grouped_gemm in this example does not read any of these fields;
// the meanings below are inferred from the names only - confirm before relying on them.
struct ExecutionConfig final
{
    bool do_verification{true}; // presumably: compare device output with a reference
    int init_method{1};         // presumably: selects how input tensors are initialized
    bool time_kernel{false};    // presumably: enables kernel timing
    int k_batch{1};             // presumably: split-K batch count
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<ck::tensor_operation::device::GemmMultiABDDesc> gemm_descs;
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
int sum_of_m = 0;
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<SimpleDeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a0_tensors_device, b0_tensors_device, b1_tensors_device,
|
||||
c_tensors_device;
|
||||
|
||||
a0_tensors_device.reserve(group_count);
|
||||
b0_tensors_device.reserve(group_count);
|
||||
b1_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
sum_of_m += problem_size.Ms[i];
|
||||
}
|
||||
|
||||
constexpr ck::index_t NumATensor = 1;
|
||||
constexpr ck::index_t NumBTensor = 2;
|
||||
constexpr ck::index_t NumDTensor = 0;
|
||||
|
||||
using GroupedGemmKernelArgument = ck::tensor_operation::device::
|
||||
GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>;
|
||||
|
||||
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
|
||||
grouped_gemm_kernel_args_.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a0_tensors_device.emplace_back(
|
||||
std::make_unique<SimpleDeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
|
||||
|
||||
b0_tensors_device.emplace_back(std::make_unique<SimpleDeviceMem>(
|
||||
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
|
||||
|
||||
b1_tensors_device.emplace_back(
|
||||
std::make_unique<SimpleDeviceMem>(sizeof(B1DataType) * problem_size.Ns[i]));
|
||||
|
||||
c_tensors_device.emplace_back(
|
||||
std::make_unique<SimpleDeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
|
||||
|
||||
gemm_descs.push_back(
|
||||
{sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {}, 1});
|
||||
|
||||
grouped_gemm_kernel_args_.push_back(
|
||||
{std::array<const void*, NumATensor>{a0_tensors_device[i]->GetDeviceBuffer()},
|
||||
std::array<const void*, NumBTensor>{b0_tensors_device[i]->GetDeviceBuffer(),
|
||||
b1_tensors_device[i]->GetDeviceBuffer()},
|
||||
std::array<const void*, NumDTensor>{},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
std::array<ck::index_t, NumATensor>{problem_size.stride_As[i]},
|
||||
std::array<ck::index_t, NumBTensor>{problem_size.stride_Bs[i], 0},
|
||||
std::array<ck::index_t, NumDTensor>{},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
Row,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
bool found = false;
|
||||
int best_op_id = -1;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
|
||||
std::vector<std::array<const void*, NumATensor>> p_As = {};
|
||||
std::vector<std::array<const void*, NumBTensor>> p_Bs = {};
|
||||
std::vector<std::array<const void*, NumDTensor>> p_Ds = {};
|
||||
std::vector<void*> p_Cs = {};
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(p_As, p_Bs, p_Ds, p_Cs, gemm_descs);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
|
||||
SimpleDeviceMem gemm_kernel_args_dev(
|
||||
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
|
||||
hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
|
||||
grouped_gemm_kernel_args_.data(),
|
||||
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_kernel_args_dev.GetDeviceBuffer());
|
||||
|
||||
op_ptr->SetElementwiseOps(
|
||||
argument_ptr.get(), a_element_op, b_element_op, cde_element_op);
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, 0, 20, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * sum_of_m * problem_size.Ns[0] * problem_size.Ks[0];
|
||||
|
||||
std::size_t num_btype = sizeof(A0DataType) * sum_of_m * problem_size.Ks[0] +
|
||||
sizeof(B0DataType) * problem_size.Ks[0] * problem_size.Ns[0] +
|
||||
sizeof(EDataType) * sum_of_m * problem_size.Ns[0];
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
found = true;
|
||||
best_op_id = i;
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(1 + rand() % 1024);
|
||||
problem_size.Ns.push_back(4096);
|
||||
problem_size.Ks.push_back(4096);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ns[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
|
||||
std::cout << " M = " << problem_size.Ms[i] << " N = " << problem_size.Ns[i] << " K "
|
||||
<< problem_size.Ks[i] << std::endl;
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf16") OR NOT DEFINED DTYPES))
|
||||
add_executable(client_grouped_gemm_bias_fastgelu_bf16_i8_bf16 grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp)
|
||||
target_link_libraries(client_grouped_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
add_executable(client_grouped_gemm_fastgelu_bf16_i8_bf16 grouped_gemm_fastgelu_xdl_bf16_i8.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
|
||||
endif()
|
||||
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core==0.38.1
|
||||
rocm-docs-core==1.1.1
|
||||
sphinxcontrib-bibtex==2.6.2
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile with Python 3.8
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# by the following command:
|
||||
#
|
||||
# pip-compile requirements.in
|
||||
@@ -48,12 +48,6 @@ idna==3.4
|
||||
# via requests
|
||||
imagesize==1.4.1
|
||||
# via sphinx
|
||||
importlib-metadata==6.8.0
|
||||
# via
|
||||
# sphinx
|
||||
# sphinxcontrib-bibtex
|
||||
importlib-resources==6.1.0
|
||||
# via rocm-docs-core
|
||||
jinja2==3.1.2
|
||||
# via
|
||||
# myst-parser
|
||||
@@ -99,8 +93,6 @@ pyjwt[crypto]==2.6.0
|
||||
# via pygithub
|
||||
pynacl==1.5.0
|
||||
# via pygithub
|
||||
pytz==2023.3.post1
|
||||
# via babel
|
||||
pyyaml==6.0
|
||||
# via
|
||||
# myst-parser
|
||||
@@ -111,7 +103,7 @@ requests==2.31.0
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core==0.38.1
|
||||
rocm-docs-core==1.1.1
|
||||
# via -r requirements.in
|
||||
six==1.16.0
|
||||
# via
|
||||
@@ -165,7 +157,3 @@ urllib3==1.26.18
|
||||
# via requests
|
||||
wrapt==1.15.0
|
||||
# via deprecated
|
||||
zipp==3.17.0
|
||||
# via
|
||||
# importlib-metadata
|
||||
# importlib-resources
|
||||
|
||||
@@ -28,6 +28,8 @@ add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
|
||||
add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)
|
||||
add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)
|
||||
|
||||
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
|
||||
|
||||
@@ -7,17 +7,3 @@
|
||||
#arg3: run kernel # of times (>1)
|
||||
./bin/example_gemm_xdl 0 1 5
|
||||
```
|
||||
|
||||
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
|
||||
```
|
||||
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
|
||||
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
|
||||
arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
|
||||
arg.c_grid_desc_m_n_{ 3840, 4096}
|
||||
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 5 times...
|
||||
Perf: 1.19685 ms, 107.657 TFlops, 78.8501 GB/s
|
||||
```
|
||||
|
||||
48
example/01_gemm/gemm_xdl_bf16_v3.cpp
Normal file
48
example/01_gemm/gemm_xdl_bf16_v3.cpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = ck::bhalf_t;
|
||||
using BDataType = ck::bhalf_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::bhalf_t;
|
||||
using CDataType = ck::bhalf_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
256,
|
||||
128, 128,
|
||||
64, 8, 8,
|
||||
16, 16,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
1, 2, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
@@ -3,6 +3,88 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename DataType>
|
||||
inline __host__ __device__ constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 1e-1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 1.5e-1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
inline __host__ __device__ constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 16.1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 8192.1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
@@ -180,7 +262,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
#else
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -9,20 +9,3 @@
|
||||
#arg11 to 12: alpha, beta
|
||||
./bin/example_gemm_bilinear_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096 4096 0.5 0.5
|
||||
```
|
||||
Result (MI100 @ 1502Mhz, 184.6TFlops peak FP16)
|
||||
```
|
||||
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
|
||||
c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
|
||||
arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
|
||||
arg.c0_grid_desc_m_n_{ 3840, 4096}
|
||||
arg.c_grid_desc_m_n_{ 3840, 4096}
|
||||
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
Perf: 0.936965 ms, 137.517 TFlops, 102.959 GB/s
|
||||
error: 0
|
||||
max_diff: 0, 558.5, 558.5
|
||||
```
|
||||
|
||||
@@ -8,16 +8,3 @@
|
||||
#arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE"
|
||||
./bin/example_gemm_add_add_fastgelu_xdl_fp16 1 1 1
|
||||
```
|
||||
|
||||
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
|
||||
```
|
||||
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
|
||||
d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1}
|
||||
d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up 1 time
|
||||
Start running 10 times...
|
||||
Perf: 1.26914 ms, 101.525 TFlops, 100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8>
|
||||
```
|
||||
|
||||
@@ -3,8 +3,7 @@ add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
|
||||
# FIXME: re-enable this example as test when SWDEV-335738 is fixed
|
||||
add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp16_comp_fp8 convnd_fwd_xdl_fp16_comp_fp8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp)
|
||||
|
||||
@@ -16,17 +16,3 @@
|
||||
# <right padding>, (ie RightPy, RightPx for 2D)
|
||||
./bin/example_convnd_fwd_xdl 0 1 100
|
||||
```
|
||||
|
||||
Result (MI100 @ 1087Mhz, 33.4TFlops peak FP32)
|
||||
```
|
||||
input: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
|
||||
weights: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192}
|
||||
output: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
|
||||
arg.a_grid_desc_k0_m_k1_{432, 165888, 4}
|
||||
arg.b_grid_desc_k0_n_k1_{432, 256, 4}
|
||||
arg.c_grid_desc_m_n_{ 165888, 256}
|
||||
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 100 times...
|
||||
Perf: 4.43736 ms, 33.0753 TFlops, 150.357 GB/s
|
||||
```
|
||||
|
||||
@@ -26,6 +26,9 @@ add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
|
||||
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16_fp8 grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16_fp8)
|
||||
|
||||
add_example_executable(example_grouped_gemm_multiple_d_xdl_fp16 grouped_gemm_multiple_d_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_xdl_fp16)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
|
||||
|
||||
@@ -7,19 +7,3 @@
|
||||
#arg3: run kernel # of times (>1)
|
||||
./bin/example_grouped_gemm_xdl_fp16 0 1 5
|
||||
```
|
||||
|
||||
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
|
||||
```
|
||||
gemm[0] a_m_k: dim 2, lengths {256, 64}, strides {64, 1} b_k_n: dim 2, lengths {64, 128}, strides {1, 64} c_m_n: dim 2, lengths {256, 128}, strides {128, 1}
|
||||
gemm[1] a_m_k: dim 2, lengths {512, 128}, strides {128, 1} b_k_n: dim 2, lengths {128, 256}, strides {1, 128} c_m_n: dim 2, lengths {512, 256}, strides {256, 1}
|
||||
gemm[2] a_m_k: dim 2, lengths {768, 192}, strides {192, 1} b_k_n: dim 2, lengths {192, 384}, strides {1, 192} c_m_n: dim 2, lengths {768, 384}, strides {384, 1}
|
||||
gemm[3] a_m_k: dim 2, lengths {1024, 256}, strides {256, 1} b_k_n: dim 2, lengths {256, 512}, strides {1, 256} c_m_n: dim 2, lengths {1024, 512}, strides {512, 1}
|
||||
group: 0 arg.a_grid_desc_k0_m_k1_{8, 256, 8}, arg.b_grid_desc_k0_n_k1_{8, 128, 8}, arg.c_grid_desc_m_n_{ 256, 128}
|
||||
group: 1 arg.a_grid_desc_k0_m_k1_{16, 512, 8}, arg.b_grid_desc_k0_n_k1_{16, 256, 8}, arg.c_grid_desc_m_n_{ 512, 256}
|
||||
group: 2 arg.a_grid_desc_k0_m_k1_{24, 768, 8}, arg.b_grid_desc_k0_n_k1_{24, 384, 8}, arg.c_grid_desc_m_n_{ 768, 384}
|
||||
group: 3 arg.a_grid_desc_k0_m_k1_{32, 1024, 8}, arg.b_grid_desc_k0_n_k1_{32, 512, 8}, arg.c_grid_desc_m_n_{ 1024, 512}
|
||||
launch_and_time_kernel: grid_dim {30, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 5 times...
|
||||
Perf: 0.037887 ms, 11.0706 TFlops, 90.8132 GB/s, DeviceGroupedGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2>
|
||||
```
|
||||
|
||||
404
example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp
Normal file
404
example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp
Normal file
@@ -0,0 +1,404 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/tuple.hpp>
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddAdd = ck::tensor_operation::element_wise::AddAdd;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = F16;
|
||||
using DsDataType = ck::Tuple<DDataType, DDataType>;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DLayout = Row;
|
||||
using DsLayout = ck::Tuple<DLayout, DLayout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AddAdd;
|
||||
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
static constexpr int NumDs = 2;
|
||||
|
||||
using DeviceGemmInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
// Description of one grouped-GEMM problem: every vector is indexed by group
// (run_grouped_gemm reads element [i] of each for i in [0, group_count)).
struct ProblemSize final
|
||||
{
|
||||
// Per-group GEMM extents.
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
|
||||
// Per-group leading strides of the A, B and C/E tensors.
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
// Strides of the D tensors; filled in by main(). run_grouped_gemm does not
// read this member and uses stride_Cs for the D tensors instead.
std::vector<std::vector<ck::index_t>> stride_Ds;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
// Number of GEMMs in the group; main() sets it to Ms.size().
ck::index_t group_count;
|
||||
};
|
||||
|
||||
// Run-time switches of the example, overridable through argv[1..3] in main().
struct ExecutionConfig final
|
||||
{
|
||||
// When true, the device result is compared against a host reference GEMM.
bool do_verification = true;
|
||||
// Tensor initialization: 0 = none, 1 = integer values, 2 = decimal values,
// any other value = sequential generator.
int init_method = 1;
|
||||
// When true, the kernel is run again with timing enabled and perf is printed.
bool time_kernel = true;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<NumDs>;
|
||||
using GemmDesc = ck::tensor_operation::device::GemmDesc;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<GemmDesc> gemm_descs;
|
||||
std::vector<KernelArguments> ggemm_kargs;
|
||||
std::vector<void*> p_Cs;
|
||||
std::vector<const void*> p_As;
|
||||
std::vector<const void*> p_Bs;
|
||||
std::vector<std::array<const void*, NumDs>> p_Ds = {};
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
ggemm_kargs.reserve(group_count);
|
||||
p_As.reserve(group_count);
|
||||
p_Bs.reserve(group_count);
|
||||
p_Ds.reserve(group_count);
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Tensor<ADataType>> a_tensors;
|
||||
std::vector<Tensor<BDataType>> b_tensors;
|
||||
std::vector<std::array<Tensor<DDataType>, NumDs>> d_tensors;
|
||||
std::vector<Tensor<EDataType>> c_host_tensors;
|
||||
std::vector<Tensor<EDataType>> c_device_result_tensors;
|
||||
|
||||
a_tensors.reserve(group_count);
|
||||
b_tensors.reserve(group_count);
|
||||
d_tensors.reserve(group_count);
|
||||
c_host_tensors.reserve(group_count);
|
||||
c_device_result_tensors.reserve(group_count);
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
|
||||
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
|
||||
|
||||
a_tensors_device.reserve(group_count);
|
||||
b_tensors_device.reserve(group_count);
|
||||
d_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
|
||||
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
|
||||
|
||||
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
|
||||
std::array<Tensor<DDataType>, NumDs> d_tens = {d0_tensor, d1_tensor};
|
||||
d_tensors.push_back(d_tens);
|
||||
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
|
||||
|
||||
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
|
||||
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
|
||||
sizeof(BDataType) * b_tensors[i].GetElementSize() +
|
||||
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs +
|
||||
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
}
|
||||
break;
|
||||
case 2:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
|
||||
b_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
|
||||
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
|
||||
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
|
||||
}
|
||||
|
||||
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
|
||||
}
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Ds.push_back(
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
|
||||
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
|
||||
|
||||
// The device op does not have to know M problem size at lunch time.
|
||||
gemm_descs.push_back({0,
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
problem_size.stride_Cs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]}});
|
||||
ggemm_kargs.push_back(
|
||||
{a_tensors_device[i]->GetDeviceBuffer(),
|
||||
b_tensors_device[i]->GetDeviceBuffer(),
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
|
||||
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
|
||||
ggemm_kargs.data(),
|
||||
gemm.GetDeviceKernelArgSize(&argument),
|
||||
hipMemcpyHostToDevice));
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
auto karg = ggemm_kargs[i];
|
||||
auto dev_res_tensor =
|
||||
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{}));
|
||||
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data());
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
|
||||
b_tensors[i],
|
||||
d_tensors[i],
|
||||
c_host_tensors[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
|
||||
}
|
||||
|
||||
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
// Parses a comma-separated list such as "256,128,64" into a vector of ints.
// Each token is converted with std::stoi, so a token that is not a number
// (including an empty token between two commas) raises the std::stoi exception.
std::vector<int> argToIntArray(char* input)
{
    std::istringstream stream{std::string(input)};
    std::vector<int> values;

    for(std::string token; std::getline(stream, token, ',');)
    {
        values.push_back(std::stoi(token));
    }

    return values;
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
if(argc < 10)
|
||||
{
|
||||
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
|
||||
problem_size.group_count = Ms.size();
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(Ms[i]);
|
||||
problem_size.Ns.push_back(252);
|
||||
problem_size.Ks.push_back(4608);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
|
||||
problem_size.stride_Ds.push_back({});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "Usage:\n"
|
||||
<< "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
|
||||
"64,64 64,64 128,128)\n"
|
||||
<< "... setting default values." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.Ms = argToIntArray(argv[4]);
|
||||
problem_size.Ns = argToIntArray(argv[5]);
|
||||
problem_size.Ks = argToIntArray(argv[6]);
|
||||
|
||||
problem_size.stride_As = argToIntArray(argv[7]);
|
||||
problem_size.stride_Bs = argToIntArray(argv[8]);
|
||||
problem_size.stride_Cs = argToIntArray(argv[9]);
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
|
||||
}
|
||||
|
||||
problem_size.group_count = problem_size.Ms.size();
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
@@ -7,14 +7,3 @@
|
||||
#arg3: time kernel (0=no, 1=yes)
|
||||
./bin/example_contraction_bilinear_xdl_fp32 1 1 1
|
||||
```
|
||||
|
||||
Result (MI100 @ dynamic freq, 46TFlops peak FP32)
|
||||
```
|
||||
a_ms_ks: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
|
||||
b_ks_ns: dim 4, lengths {32, 64, 32, 64}, strides {128, 1, 524288, 4096}
|
||||
c_ms_ns: dim 4, lengths {30, 128, 32, 64}, strides {524288, 4096, 128, 1}
|
||||
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up 1 time
|
||||
Start running 10 times...
|
||||
Perf: 0.843286 ms, 38.1985 TFlops, 94.5014 GB/s, DeviceContractionMultipleD_Xdl_CShuffle<256, 256, 128, 16, 4, 4>
|
||||
```
|
||||
|
||||
@@ -16,15 +16,3 @@ Following arguments (depending on number of spatial dims):
|
||||
./bin/example_grouped_conv_fwd_bias_relu_add_xdl_fp16 1 1 1
|
||||
```
|
||||
|
||||
Result (MI100)
|
||||
```
|
||||
in: dim 5, lengths {1, 128, 192, 71, 71}, strides {192, 967872, 1, 13632, 192}
|
||||
wei: dim 5, lengths {1, 256, 192, 3, 3}, strides {442368, 1728, 1, 576, 192}
|
||||
bias: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0}
|
||||
residual: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 0, 1, 0, 0}
|
||||
out: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 331776, 1, 9216, 256}
|
||||
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up 1 time
|
||||
Start running 10 times...
|
||||
Perf: 1.55981 ms, 94.0927 TFlops, 213.868 GB/s, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 16, Default>
|
||||
```
|
||||
|
||||
@@ -8,19 +8,3 @@
|
||||
#arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE"
|
||||
./bin/example_gemm_add_multiply_dl_fp16 1 1 1
|
||||
```
|
||||
|
||||
Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16)
|
||||
```
|
||||
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
|
||||
d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1}
|
||||
d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2}
|
||||
arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2}
|
||||
arg.e_grid_desc_m_n_{ 3840, 4096}
|
||||
launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up 1 time
|
||||
Start running 10 times...
|
||||
Perf: 3.99904 ms, 32.22 TFlops, 31.9913 GB/s, DeviceGemmMultipleD_Dl<256, 128, 128, 16, 2, 4, 4, 1>
|
||||
```
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
add_custom_target(example_grouped_gemm_xdl_multi_abd)
|
||||
|
||||
add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16 grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16)
|
||||
add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16)
|
||||
|
||||
add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8)
|
||||
add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8)
|
||||
|
||||
@@ -52,12 +52,12 @@ using BsLayout = ck::Tuple<B0Layout, B1Layout>;
|
||||
using DsLayout = ck::Tuple<Row>;
|
||||
using ELayout = Row;
|
||||
|
||||
using Scales = ck::tensor_operation::element_wise::Scales;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = Scales;
|
||||
using BElementOp = Multiply;
|
||||
using CDEElementOp = AddFastGelu;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_multi_ABD_xdl_bf16_i8 gemm_multi_ABD_xdl_bf16_i8.cpp)
|
||||
add_example_executable(example_gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8 gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp)
|
||||
add_example_executable(example_gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8 gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp)
|
||||
add_example_executable(example_gemm_multi_ABD_xdl_fastgelu_bf16_i8 gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp)
|
||||
|
||||
@@ -18,9 +18,12 @@
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using F32 = float;
|
||||
@@ -41,22 +44,22 @@ using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using AsLayout = ck::Tuple<A0Layout>;
|
||||
using B0Layout = Col;
|
||||
using B0Layout = Row;
|
||||
using B1Layout = B0Layout;
|
||||
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<D0Layout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using Scales = ck::tensor_operation::element_wise::Scales;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = Scales;
|
||||
using BElementOp = Multiply;
|
||||
using CDEElementOp = AddFastGelu;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle
|
||||
// clang-format off
|
||||
@@ -64,9 +67,9 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl
|
||||
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>;
|
||||
|
||||
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -74,13 +77,13 @@ int main(int argc, char* argv[])
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 64;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 512;
|
||||
ck::index_t M = 4096;
|
||||
ck::index_t N = 768;
|
||||
ck::index_t K = 6144;
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideD = N;
|
||||
ck::index_t StrideB = N;
|
||||
ck::index_t StrideD = 0;
|
||||
ck::index_t StrideE = N;
|
||||
|
||||
if(argc == 1)
|
||||
@@ -0,0 +1,273 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = BF16;
|
||||
using AsDataType = ck::Tuple<A0DataType>;
|
||||
using B0DataType = I8;
|
||||
using B1DataType = BF16;
|
||||
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = BF16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using AsLayout = ck::Tuple<A0Layout>;
|
||||
using B0Layout = Row;
|
||||
using B1Layout = B0Layout;
|
||||
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = Multiply;
|
||||
using CDEElementOp = FastGelu;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle
|
||||
// clang-format off
|
||||
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 4096;
|
||||
ck::index_t N = 768;
|
||||
ck::index_t K = 6144;
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = N;
|
||||
ck::index_t StrideD = 0;
|
||||
ck::index_t StrideE = N;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 11)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideD = std::stoi(argv[9]);
|
||||
StrideE = std::stoi(argv[10]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{}));
|
||||
Tensor<D0DataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
|
||||
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
|
||||
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a0_device_buf.ToDevice(a0_m_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_k_n.mData.data());
|
||||
b1_device_buf.ToDevice(b1_k_n.mData.data());
|
||||
d_device_buf.ToDevice(d_m_n.mData.data());
|
||||
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
constexpr ck::index_t NumATensor = 1;
|
||||
constexpr ck::index_t NumBTensor = 2;
|
||||
constexpr ck::index_t NumDTensor = 0;
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, NumDTensor>{},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<ck::index_t, NumATensor>{StrideA},
|
||||
std::array<ck::index_t, NumBTensor>{StrideB, 0},
|
||||
std::array<ck::index_t, NumDTensor>{},
|
||||
StrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_m_n({M, N});
|
||||
|
||||
Tensor<A0DataType> a_m_k({M, K});
|
||||
|
||||
Tensor<B1DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n));
|
||||
}
|
||||
}
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
|
||||
B1DataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,274 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = BF16;
|
||||
using AsDataType = ck::Tuple<A0DataType>;
|
||||
using B0DataType = I8;
|
||||
using B1DataType = BF16;
|
||||
using BsDataType = ck::Tuple<B0DataType>;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = BF16;
|
||||
using DsDataType = ck::Tuple<B1DataType, D0DataType>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using AsLayout = ck::Tuple<A0Layout>;
|
||||
using B0Layout = Row;
|
||||
using B1Layout = B0Layout;
|
||||
using BsLayout = ck::Tuple<B0Layout>;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<B1Layout, D0Layout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MultiplyAddFastGelu;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle
|
||||
// clang-format off
|
||||
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 4096;
|
||||
ck::index_t N = 768;
|
||||
ck::index_t K = 6144;
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = N;
|
||||
ck::index_t StrideD = 0;
|
||||
ck::index_t StrideE = N;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 11)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideD = std::stoi(argv[9]);
|
||||
StrideE = std::stoi(argv[10]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{}));
|
||||
Tensor<D0DataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
|
||||
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
|
||||
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a0_device_buf.ToDevice(a0_m_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_k_n.mData.data());
|
||||
b1_device_buf.ToDevice(b1_k_n.mData.data());
|
||||
d_device_buf.ToDevice(d_m_n.mData.data());
|
||||
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
constexpr ck::index_t NumATensor = 1;
|
||||
constexpr ck::index_t NumBTensor = 1;
|
||||
constexpr ck::index_t NumDTensor = 2;
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, NumDTensor>{b1_device_buf.GetDeviceBuffer(),
|
||||
d_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<ck::index_t, NumATensor>{StrideA},
|
||||
std::array<ck::index_t, NumBTensor>{StrideB},
|
||||
std::array<ck::index_t, NumDTensor>{0, StrideD},
|
||||
StrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_m_n({M, N});
|
||||
|
||||
Tensor<A0DataType> a_m_k({M, K});
|
||||
|
||||
Tensor<B1DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
|
||||
#if 0
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
|
||||
B0DataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), b1_k_n(0, n), d_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -44,9 +44,9 @@ args:
|
||||
-range_v per-tensor quantization range of v. used if squant=1. (default:16)
|
||||
-range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
|
||||
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16)
|
||||
-squant if using static quantization fusion or not. 0: original flow(not prefered) (default:0)
|
||||
1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,
|
||||
scale_o according to range_q, range_k, range_v, range_p, range_o
|
||||
-squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto)
|
||||
0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O.
|
||||
calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o
|
||||
-iperm permute input (default:1)
|
||||
if true, will be b*h*s*d, else b*s*h*d
|
||||
-operm permute output (default:1)
|
||||
@@ -64,8 +64,11 @@ args:
|
||||
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
|
||||
-lse 0 not store lse, 1 store lse (default:0)
|
||||
-kname if set to 1 will print kernel name (default:0)
|
||||
-init init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1)
|
||||
-init init method. ui, uniform random int, ni, normalized random int (default:uf)
|
||||
uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization
|
||||
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
```
|
||||
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
|
||||
|
||||
@@ -60,12 +60,14 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.")
|
||||
.insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.")
|
||||
.insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.")
|
||||
.insert(
|
||||
"squant",
|
||||
"0",
|
||||
"if using static quantization fusion or not. 0: original flow(not prefered)\n"
|
||||
"1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,\n"
|
||||
"scale_o according to range_q, range_k, range_v, range_p, range_o")
|
||||
.insert("squant",
|
||||
"auto",
|
||||
"if using static quantization fusion or not. auto: fp8 will default use squant, "
|
||||
"other will not\n"
|
||||
"0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to "
|
||||
"P and O.\n"
|
||||
"calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, "
|
||||
"range_p, range_o")
|
||||
.insert("iperm",
|
||||
"1",
|
||||
"permute input\n"
|
||||
@@ -92,8 +94,11 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
|
||||
.insert("lse", "0", "0 not store lse, 1 store lse")
|
||||
.insert("kname", "0", "if set to 1 will print kernel name")
|
||||
.insert(
|
||||
"init", "1", "init method. 0:random int, 1:random float, 2:trig float, 3:quantization")
|
||||
.insert("init",
|
||||
"uf",
|
||||
"init method. ui, uniform random int, ni, normalized random int\n"
|
||||
"uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, "
|
||||
"quantization")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
@@ -110,7 +115,7 @@ auto create_args(int argc, char* argv[])
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(int /*init_method*/)
|
||||
auto get_elimit(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
@@ -118,17 +123,32 @@ auto get_elimit(int /*init_method*/)
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(int /*init_method*/)
|
||||
auto get_elimit<ck_tile::bf16_t>(std::string init_method)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
else if(init_method == "nf")
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
double rtol = 3e-3;
|
||||
double atol = 3e-3;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::fp8_t>(int init_method)
|
||||
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
|
||||
{
|
||||
if(init_method == 0)
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
unsigned max_rounding_point_distance = 0;
|
||||
double atol = 2e-3;
|
||||
@@ -176,15 +196,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(scale_s == .0f)
|
||||
scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?
|
||||
|
||||
bool squant = arg_parser.get_bool("squant");
|
||||
if constexpr(!std::is_same_v<DataType, ck_tile::fp8_t>)
|
||||
{
|
||||
if(squant)
|
||||
std::string squant_str = arg_parser.get_str("squant");
|
||||
bool squant = [&]() {
|
||||
if(squant_str == "auto")
|
||||
{
|
||||
std::cerr << "static quantization only support fp8 for now" << std::endl;
|
||||
return false;
|
||||
if(data_type == "fp8")
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
return atoi(squant_str.c_str()) != 0 ? true : false;
|
||||
}();
|
||||
|
||||
float range_q = arg_parser.get_float("range_q");
|
||||
float range_k = arg_parser.get_float("range_k");
|
||||
@@ -226,7 +249,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
s_randval = true;
|
||||
}
|
||||
|
||||
int init_method = arg_parser.get_int("init");
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
|
||||
if(*seed == 0)
|
||||
{
|
||||
@@ -339,28 +362,43 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
|
||||
if(init_method == 0)
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
{
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
else if(init_method == "ni")
|
||||
{
|
||||
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "uf" || init_method == "1")
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
else if(init_method == "nf")
|
||||
{
|
||||
ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host);
|
||||
ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "tf" || init_method == "2")
|
||||
{
|
||||
ck_tile::FillTrigValue<QDataType>{}(q_host);
|
||||
ck_tile::FillTrigValue<KDataType>{}(k_host);
|
||||
ck_tile::FillTrigValue<VDataType>{}(v_host);
|
||||
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
}
|
||||
else if(init_method == 3) // suitable for fp8 quantization
|
||||
else if(init_method == "ufq" || init_method == "uf:q" ||
|
||||
init_method == "3") // suitable for fp8 quantization
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
|
||||
|
||||
@@ -4,12 +4,19 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/config.h"
|
||||
#include "ck/utility/env.hpp"
|
||||
|
||||
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
#endif
|
||||
|
||||
// environment variable to enable logging:
|
||||
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
|
||||
CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
|
||||
// to do: add various levels of logging with CK_LOG_LEVEL
|
||||
|
||||
#define CK_TIME_KERNEL 1
|
||||
|
||||
// constant address space for kernel parameter
|
||||
@@ -225,17 +232,17 @@
|
||||
// workaround: compiler issue on gfx908
|
||||
#define CK_WORKAROUND_SWDEV_388832 1
|
||||
|
||||
// flag to enable (1) or disable (0) the debugging output in some kernels
|
||||
#define DEBUG_LOG 0
|
||||
|
||||
// denorm test fix, required to work around the denorm issue
|
||||
#ifndef CK_WORKAROUND_DENORM_FIX
|
||||
#define CK_WORKAROUND_DENORM_FIX 0
|
||||
#else
|
||||
// enable only on MI200
|
||||
// enable only for gfx90a
|
||||
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
|
||||
#endif // CK_WORKAROUND_DENORM_FIX
|
||||
|
||||
// set flag to 1 to build deprecated instances
|
||||
#define CK_BUILD_DEPRECATED 1
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum struct InMemoryDataOperationEnum
|
||||
|
||||
@@ -65,20 +65,20 @@ inline bool is_lds_direct_load_supported()
|
||||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
|
||||
}
|
||||
|
||||
inline bool is_navi1_supported()
|
||||
inline bool is_gfx101_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" ||
|
||||
ck::get_device_name() == "gfx1012";
|
||||
}
|
||||
|
||||
inline bool is_navi2_supported()
|
||||
inline bool is_gfx103_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" ||
|
||||
ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" ||
|
||||
ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
|
||||
}
|
||||
|
||||
inline bool is_navi3_supported()
|
||||
inline bool is_gfx11_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103";
|
||||
|
||||
233
include/ck/host_utility/flush_cache.hpp
Normal file
233
include/ck/host_utility/flush_cache.hpp
Normal file
@@ -0,0 +1,233 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/stream_config.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/utility/flush_icache.hpp"
|
||||
namespace ck {
|
||||
namespace utility {
|
||||
|
||||
template <typename Argument>
|
||||
struct RotatingMemWrapper
|
||||
{
|
||||
using ADataType = decltype(Argument::p_a_grid);
|
||||
using BDataType = decltype(Argument::p_b_grid);
|
||||
|
||||
RotatingMemWrapper() = delete;
|
||||
RotatingMemWrapper(Argument& arg_,
|
||||
std::size_t rotating_count_,
|
||||
std::size_t size_a_,
|
||||
std::size_t size_b_)
|
||||
: arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_)
|
||||
{
|
||||
p_a_grids.push_back(arg.p_a_grid);
|
||||
p_b_grids.push_back(arg.p_b_grid);
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
{
|
||||
void* pADeviceBuf;
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
|
||||
const_cast<void*>(p_a_grids[0]),
|
||||
size_a_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_a_grids.push_back(pADeviceBuf);
|
||||
}
|
||||
|
||||
{
|
||||
void* pBDeviceBuf;
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
|
||||
const_cast<void*>(p_b_grids[0]),
|
||||
size_b_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_b_grids.push_back(pBDeviceBuf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Next()
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
std::size_t idx = iter++ % rotating_count;
|
||||
arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
|
||||
arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
|
||||
}
|
||||
}
|
||||
void Print()
|
||||
{
|
||||
std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
|
||||
<< ", rotating_count: " << rotating_count << "}" << std::endl;
|
||||
}
|
||||
~RotatingMemWrapper()
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
// restore ptr
|
||||
arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
|
||||
arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
|
||||
|
||||
// free device mem
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
|
||||
hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Argument& arg;
|
||||
std::size_t iter = 0;
|
||||
std::size_t rotating_count = 1;
|
||||
std::size_t size_a = 0;
|
||||
std::size_t size_b = 0;
|
||||
std::vector<const void*> p_a_grids;
|
||||
std::vector<const void*> p_b_grids;
|
||||
};
|
||||
|
||||
inline void flush_icache()
|
||||
{
|
||||
hipDeviceProp_t deviceProps;
|
||||
hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
|
||||
int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
|
||||
|
||||
ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
|
||||
hip_check_error(hipGetLastError());
|
||||
}
|
||||
// if TimePrePress == false, return time does not include preprocess's time
|
||||
template <bool TimePreprocess, typename Args, typename F, typename PreProcessFunc>
|
||||
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
PreProcessFunc preprocess,
|
||||
F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
Args& args)
|
||||
{
|
||||
#if CK_TIME_KERNEL
|
||||
#define MEDIAN 1
|
||||
if(stream_config.time_kernel_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
|
||||
__func__,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z);
|
||||
|
||||
printf("Warm up %d times\n", stream_config.cold_niters_);
|
||||
}
|
||||
// warm up
|
||||
for(int i = 0; i < stream_config.cold_niters_; ++i)
|
||||
{
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
|
||||
hip_check_error(hipGetLastError());
|
||||
}
|
||||
|
||||
const int nrepeat = stream_config.nrepeat_;
|
||||
if(nrepeat == 0)
|
||||
{
|
||||
return 0.0;
|
||||
}
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
}
|
||||
|
||||
#if MEDIAN
|
||||
std::set<float> times;
|
||||
#else
|
||||
float total_time = 0;
|
||||
#endif
|
||||
for(int i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
if constexpr(!TimePreprocess)
|
||||
{
|
||||
preprocess();
|
||||
}
|
||||
|
||||
hipEvent_t start, stop;
|
||||
|
||||
hip_check_error(hipEventCreate(&start));
|
||||
hip_check_error(hipEventCreate(&stop));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
|
||||
// calculate preprocess time
|
||||
if constexpr(TimePreprocess)
|
||||
{
|
||||
preprocess();
|
||||
}
|
||||
// run real kernel
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
|
||||
hip_check_error(hipGetLastError());
|
||||
// end real kernel
|
||||
|
||||
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
|
||||
hip_check_error(hipEventSynchronize(stop));
|
||||
float cur_time = 0;
|
||||
hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
|
||||
#if MEDIAN
|
||||
times.insert(cur_time);
|
||||
#else
|
||||
total_time += cur_time;
|
||||
#endif
|
||||
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
|
||||
|
||||
printf("args.p_a_grid: %p, args.p_b_grid:%p\n",
|
||||
static_cast<const void*>(args.p_a_grid),
|
||||
static_cast<const void*>(args.p_b_grid));
|
||||
}
|
||||
}
|
||||
|
||||
#if MEDIAN
|
||||
auto mid = times.begin();
|
||||
std::advance(mid, (nrepeat - 1) / 2);
|
||||
if(nrepeat % 2 == 1)
|
||||
{
|
||||
return *mid;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto mid_next = mid;
|
||||
std::advance(mid_next, 1);
|
||||
return (*mid + *mid_next) / 2;
|
||||
}
|
||||
#else
|
||||
return total_time / nrepeat;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess();
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
|
||||
hip_check_error(hipGetLastError());
|
||||
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
|
||||
hip_check_error(hipGetLastError());
|
||||
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace utility
|
||||
} // namespace ck
|
||||
@@ -20,18 +20,19 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
|
||||
#if CK_TIME_KERNEL
|
||||
if(stream_config.time_kernel_)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
|
||||
__func__,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z);
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
|
||||
__func__,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z);
|
||||
|
||||
printf("Warm up %d times\n", stream_config.cold_niters_);
|
||||
#endif
|
||||
printf("Warm up %d times\n", stream_config.cold_niters_);
|
||||
}
|
||||
// warm up
|
||||
for(int i = 0; i < stream_config.cold_niters_; ++i)
|
||||
{
|
||||
@@ -40,9 +41,10 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
|
||||
}
|
||||
|
||||
const int nrepeat = stream_config.nrepeat_;
|
||||
#if DEBUG_LOG
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
#endif
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
}
|
||||
hipEvent_t start, stop;
|
||||
|
||||
hip_check_error(hipEventCreate(&start));
|
||||
@@ -93,18 +95,19 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
#if CK_TIME_KERNEL
|
||||
if(stream_config.time_kernel_)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
|
||||
__func__,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z);
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
|
||||
__func__,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z);
|
||||
|
||||
printf("Warm up %d times\n", stream_config.cold_niters_);
|
||||
#endif
|
||||
printf("Warm up %d times\n", stream_config.cold_niters_);
|
||||
}
|
||||
// warm up
|
||||
preprocess();
|
||||
for(int i = 0; i < stream_config.cold_niters_; ++i)
|
||||
@@ -114,9 +117,10 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
}
|
||||
|
||||
const int nrepeat = stream_config.nrepeat_;
|
||||
#if DEBUG_LOG
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
#endif
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
}
|
||||
hipEvent_t start, stop;
|
||||
|
||||
hip_check_error(hipEventCreate(&start));
|
||||
|
||||
@@ -13,4 +13,7 @@ struct StreamConfig
|
||||
int log_level_ = 0;
|
||||
int cold_niters_ = 5;
|
||||
int nrepeat_ = 50;
|
||||
|
||||
bool flush_cache = false;
|
||||
int rotating_count = 1;
|
||||
};
|
||||
|
||||
@@ -140,8 +140,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / (4 * warpSize / BlockSize),
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
static constexpr index_t PrefetchStages =
|
||||
FullMemBandPrefetchStages >= 2
|
||||
@@ -631,8 +633,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
|
||||
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / (4 * warpSize / BlockSize),
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
static constexpr index_t PrefetchStages =
|
||||
FullMemBandPrefetchStages >= 2
|
||||
|
||||
@@ -184,19 +184,22 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 8 + ds_read_a_issue_cycle - 1) / ds_read_a_issue_cycle;
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 8 + ds_read_b_issue_cycle - 1) / ds_read_b_issue_cycle;
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
constexpr auto num_dsread_a_mfma =
|
||||
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
constexpr auto num_dsread_b_mfma =
|
||||
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
// stage 1
|
||||
// Separate this part?
|
||||
constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
|
||||
sizeof(ComputeDataType) / sizeof(BDataType)
|
||||
? sizeof(ComputeDataType) / sizeof(ADataType)
|
||||
: sizeof(ComputeDataType) / sizeof(BDataType);
|
||||
constexpr auto num_mfma_stage1 =
|
||||
num_mfma_inst - num_mfma_per_ds_read * (num_ds_read_inst_a / ds_read_a_mfma_rate +
|
||||
num_ds_read_inst_b / ds_read_b_mfma_rate);
|
||||
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
|
||||
// sizeof(ComputeDataType) / sizeof(BDataType)
|
||||
// ? sizeof(ComputeDataType) / sizeof(ADataType)
|
||||
// : sizeof(ComputeDataType) / sizeof(BDataType);
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue =
|
||||
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
|
||||
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
|
||||
@@ -226,16 +229,36 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_ds_read_inst_a / ds_read_a_mfma_rate, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_ds_read, 0); // MFMA
|
||||
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100,
|
||||
num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
|
||||
ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
|
||||
static_for<0, num_ds_read_inst_b / ds_read_b_mfma_rate, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_ds_read, 0); // MFMA
|
||||
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100,
|
||||
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
|
||||
ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -194,9 +194,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 8 + ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 8 + ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
|
||||
constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
|
||||
|
||||
@@ -41,7 +41,8 @@ template <typename ThreadGroup,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector,
|
||||
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
typename ThreadTransferDstResetCoordinateAfterRunFlags>
|
||||
typename ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadGroupTensorSliceTransfer_v7r2
|
||||
{
|
||||
static constexpr index_t nDim =
|
||||
@@ -100,7 +101,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
|
||||
@@ -117,29 +118,33 @@ struct ThreadGroupTensorSliceTransfer_v7r2
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffers>
|
||||
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
|
||||
template <typename SrcBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_descs, src_bufs);
|
||||
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
template <typename DstBuffers>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
|
||||
template <typename DstBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
|
||||
threadwise_transfer_.RunWrite(dst_descs, dst_bufs);
|
||||
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
|
||||
else
|
||||
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs));
|
||||
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,7 +211,8 @@ struct ThreadGroupTensorSliceTransfer_v7r2
|
||||
SrcScalarPerVector,
|
||||
DstScalarPerVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
ThreadTransferDstResetCoordinateAfterRunFlags>;
|
||||
ThreadTransferDstResetCoordinateAfterRunFlags,
|
||||
NumThreadScratch>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
|
||||
#include "device_grouped_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
///
|
||||
/// @brief Structure representing single GEMM problem arguments.
|
||||
///
|
||||
/// The pointer to the vector of those structures is passed to the GroupedGEMM entry
|
||||
/// point kernel.
|
||||
///
|
||||
/// @tparam NumDTensor The number of D input tensors.
|
||||
///
|
||||
template <index_t NumDTensor = 0>
|
||||
struct GroupedGemmTileLoopKernelArguments
|
||||
{
|
||||
__host__ __device__
|
||||
GroupedGemmTileLoopKernelArguments(const void* p_a_grid_,
|
||||
const void* p_b_grid_,
|
||||
std::array<const void*, NumDTensor> p_ds_grid_,
|
||||
void* p_e_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t StrideA_,
|
||||
index_t StrideB_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_)
|
||||
: p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_ds_grid{p_ds_grid_},
|
||||
p_e_grid{p_e_grid_},
|
||||
M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
StrideA{StrideA_},
|
||||
StrideB{StrideB_},
|
||||
StrideDs{StrideDs_},
|
||||
StrideE{StrideE_}
|
||||
{
|
||||
}
|
||||
|
||||
const void* p_a_grid;
|
||||
const void* p_b_grid;
|
||||
std::array<const void*, NumDTensor> p_ds_grid;
|
||||
void* p_e_grid;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t StrideA;
|
||||
index_t StrideB;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
index_t StrideE;
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::stringstream str;
|
||||
for(auto sd : StrideDs)
|
||||
str << sd << ",";
|
||||
|
||||
std::cout << "arg {"
|
||||
<< "M:" << M << ", "
|
||||
<< "N:" << N << ", "
|
||||
<< "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", "
|
||||
<< "SB:" << StrideB << ", "
|
||||
<< "SE:" << StrideE << ", "
|
||||
<< "SDs: {" << str.str() << "}"
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
//----------------------------------------------------------------------------------------------
|
||||
/// @brief Sets the device kernel arguments pointer.
|
||||
///
|
||||
/// @param p_arg The pointer to the Argument we're going to update.
|
||||
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
|
||||
/// arguments.
|
||||
///
|
||||
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0;
|
||||
|
||||
//----------------------------------------------------------------------------------------------
|
||||
/// @brief Gets the device kernel argument size.
|
||||
///
|
||||
/// @param[in] p_arg The pointer to the Device op Argument.
|
||||
///
|
||||
/// @return The device kernel argument size.
|
||||
///
|
||||
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -829,7 +829,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_navi3_supported())
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -648,7 +648,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_navi2_supported() || ck::is_navi3_supported())
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
||||
{
|
||||
bool pass = true;
|
||||
pass = pass && arg.K_ % K1 == 0;
|
||||
|
||||
@@ -587,30 +587,31 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1}
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", "
|
||||
<< a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", "
|
||||
<< b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0) << ", "
|
||||
<< d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", "
|
||||
<< b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{"
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}"
|
||||
<< std::endl;
|
||||
std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
#endif
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", "
|
||||
<< a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", "
|
||||
<< b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0)
|
||||
<< ", " << d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", "
|
||||
<< b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{"
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}"
|
||||
<< std::endl;
|
||||
std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
using D0Layout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
|
||||
|
||||
@@ -658,27 +658,28 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
|
||||
{
|
||||
std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
|
||||
|
||||
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
|
||||
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
|
||||
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
|
||||
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
|
||||
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
|
||||
<< ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0)
|
||||
<< "}" << std::endl;
|
||||
std::cout << "arg.reduce_grid_desc_m_{ "
|
||||
<< arg.reduce_grid_desc_m_.GetLength(I0) << "}" << std::endl;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
|
||||
@@ -858,7 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const RawArg& arg)
|
||||
{
|
||||
if(ck::is_navi3_supported())
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
@@ -1435,7 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
|
||||
#if 0
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_navi3_supported())
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -719,9 +719,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
arg.Print();
|
||||
#endif
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
|
||||
@@ -516,26 +516,27 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
float ave_time = 0;
|
||||
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
|
||||
<< std::endl;
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
|
||||
<< std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
|
||||
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
|
||||
<< std::endl;
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
|
||||
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
|
||||
<< std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_container_{ "
|
||||
<< arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
|
||||
<< std::endl;
|
||||
std::cout << "arg.c_grid_desc_m_n_container_{ "
|
||||
<< arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
|
||||
@@ -644,7 +644,7 @@ struct
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
@@ -664,9 +664,7 @@ struct
|
||||
<< arg.input_left_pads_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
|
||||
<< arg.input_right_pads_[1] << ", " << std::endl;
|
||||
}
|
||||
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
@@ -684,7 +682,6 @@ struct
|
||||
std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
|
||||
<< ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
|
||||
@@ -614,7 +614,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
@@ -634,9 +634,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
<< arg.input_left_pads_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
|
||||
<< arg.input_right_pads_[1] << ", " << std::endl;
|
||||
}
|
||||
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
@@ -651,7 +649,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
|
||||
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
|
||||
@@ -579,7 +579,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
@@ -599,9 +599,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
|
||||
<< arg.input_left_pads_[1] << ", " << std::endl;
|
||||
std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", "
|
||||
<< arg.input_right_pads_[1] << ", " << std::endl;
|
||||
}
|
||||
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
@@ -635,7 +633,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
|
||||
.GetLength(I5)
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
|
||||
@@ -431,7 +431,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
@@ -444,7 +444,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
|
||||
{
|
||||
|
||||
@@ -401,7 +401,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl;
|
||||
std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
@@ -415,7 +415,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
|
||||
std::cout << "c_grid_desc_m_n{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
|
||||
@@ -1272,7 +1272,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
|
||||
float ave_time = 0;
|
||||
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
|
||||
@@ -1305,7 +1305,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
|
||||
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I5)
|
||||
<< " ) " << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
@@ -1393,8 +1392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
|
||||
ck::is_navi3_supported()))
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
||||
ck::is_gfx11_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1220,7 +1220,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
|
||||
float ave_time = 0;
|
||||
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1{"
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
|
||||
@@ -1239,7 +1239,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
|
||||
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
|
||||
arg.b_grid_desc_k0_n_k1_container_[i],
|
||||
|
||||
@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_navi3_supported())
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
|
||||
is_same_v<AccDataType, int32_t>))
|
||||
|
||||
@@ -334,7 +334,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
|
||||
@@ -349,7 +349,6 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
|
||||
@@ -536,8 +535,8 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
}
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
|
||||
ck::is_navi3_supported())
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
||||
ck::is_gfx11_supported())
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
|
||||
|
||||
@@ -168,7 +168,7 @@ struct DeviceGemmDpp : public DeviceGemm<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& karg)
|
||||
{
|
||||
if(ck::is_navi2_supported() || ck::is_navi3_supported())
|
||||
if(ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(karg);
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -552,7 +552,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_navi2_supported() || ck::is_navi3_supported())
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
|
||||
|
||||
@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_navi3_supported())
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -510,7 +510,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
|
||||
@@ -528,7 +528,6 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
|
||||
std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0)
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
|
||||
@@ -443,7 +443,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_navi3_supported())
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
|
||||
is_same_v<AccDataType, int32_t>))
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -151,14 +152,56 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
|
||||
arg_,
|
||||
stream_config.rotating_count,
|
||||
arg_.M * arg_.K * sizeof(ADataType),
|
||||
arg_.K * arg_.N * sizeof(BDataType));
|
||||
rotating_mem.Print();
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
|
||||
{
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
@@ -172,12 +215,15 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -194,113 +240,118 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::One>;
|
||||
Run(kernel);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
Run(kernel);
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Two>;
|
||||
TailNumber::One>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Three)
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Full)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Three>;
|
||||
TailNumber::Full>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Four)
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Four>;
|
||||
Run(kernel);
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Two)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Two>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Five)
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Five>;
|
||||
Run(kernel);
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Three)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Three>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Six>;
|
||||
Run(kernel);
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Four)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Four>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Seven)
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Seven>;
|
||||
Run(kernel);
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Five)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Five>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Six)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Six>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Seven)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Seven>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -422,25 +473,28 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -471,25 +525,28 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -522,14 +579,18 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -514,7 +514,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
|
||||
@@ -529,7 +529,6 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
|
||||
@@ -299,7 +299,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
@@ -312,7 +312,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
|
||||
@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(ck::is_navi3_supported())
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -197,6 +197,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
K0PerBlock,
|
||||
ConvBackwardWeightSpecialization>{};
|
||||
|
||||
static constexpr index_t MaxScalarPerVectorFP32 = 4;
|
||||
static constexpr index_t WorkspaceInOutScalarPerVector =
|
||||
is_same_v<AccDataType, float>
|
||||
? math::min(CBlockTransferScalarPerVector_NWaveNPerXdl, MaxScalarPerVectorFP32)
|
||||
: CBlockTransferScalarPerVector_NWaveNPerXdl;
|
||||
|
||||
// Bytes per 32 lds bank: 32 * 4 bytes
|
||||
static constexpr auto BankLength = 128;
|
||||
static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
|
||||
@@ -297,7 +303,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
@@ -337,7 +343,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
BBlockLdsN1Padding,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
WorkspaceInOutScalarPerVector,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
true,
|
||||
@@ -349,7 +355,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
static constexpr auto MakeElementwiseInputSequence()
|
||||
{
|
||||
return generate_sequence_v2(
|
||||
[&](auto) constexpr { return Number<CBlockTransferScalarPerVector_NWaveNPerXdl>{}; },
|
||||
[&](auto) constexpr { return Number<WorkspaceInOutScalarPerVector>{}; },
|
||||
Number<NumDTensor + 1>{});
|
||||
}
|
||||
|
||||
@@ -499,7 +505,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N<NDimSpatial>({}, {}));
|
||||
using CDGridDesc_M_N = decltype(concat_tuple(Tuple<CGridDesc_M_N>{}, DsGridDesc_M_N{}));
|
||||
using DsGridPointerTuple = decltype(GetDsGridPointerTuple());
|
||||
using CDDataTypes = decltype(concat_tuple(Tuple<const EDataType*>{}, DsGridPointerTuple{}));
|
||||
using CDDataTypes = decltype(concat_tuple(Tuple<const AccDataType*>{}, DsGridPointerTuple{}));
|
||||
using EGridDesc_M_N = CGridDesc_M_N;
|
||||
static constexpr index_t ClusterLengthMPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
|
||||
@@ -659,7 +665,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
return sizeof(EDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
|
||||
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
@@ -738,7 +744,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
auto launch_gemm_kernel = [&](auto has_main_k_block_loop) {
|
||||
EDataType* p_c_grid = type_convert<EDataType*>(arg.p_workspace_);
|
||||
AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_;
|
||||
|
||||
@@ -753,7 +759,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
element_wise::PassThrough,
|
||||
@@ -786,7 +792,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
};
|
||||
|
||||
auto launch_elementwise_kernel = [&]() {
|
||||
const EDataType* p_c_grid = type_convert<const EDataType*>(arg.p_workspace_);
|
||||
const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) *
|
||||
arg.Conv_G_;
|
||||
@@ -907,7 +913,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
|
||||
if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0 &&
|
||||
arg.Conv_C_ % WorkspaceInOutScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,898 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename AGridDesc_B_K0_M_K1,
|
||||
typename BGridDesc_B_K0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2CTileMap,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_xdlops_bwd_weight(
|
||||
const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const index_t batch_count,
|
||||
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
|
||||
|
||||
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = a_b_k0_m_k1_grid_desc;
|
||||
ignore = b_b_k0_n_k1_grid_desc;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = batch_count;
|
||||
ignore = block_2_ctile_map;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
|
||||
compute_ptr_offset_of_batch.GetAPtrOffset(0);
|
||||
compute_ptr_offset_of_batch.GetBPtrOffset(0);
|
||||
compute_ptr_offset_of_batch.GetCPtrOffset(0);
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXdl,
|
||||
ck::index_t NPerXdl,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
typename ComputeTypeA = InDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
: public DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle;
|
||||
|
||||
using ADataType = OutDataType;
|
||||
using BDataType = InDataType;
|
||||
using EDataType = WeiDataType;
|
||||
|
||||
using AElementwiseOperation = OutElementwiseOperation;
|
||||
using BElementwiseOperation = InElementwiseOperation;
|
||||
using CDEElementwiseOperation = WeiElementwiseOperation;
|
||||
|
||||
// TODO make A/B datatype different
|
||||
using ABDataType = InDataType;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static constexpr auto conv_to_gemm_transformer =
|
||||
TransformConvBwdWeightToGemm<NDimSpatial,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K1Number,
|
||||
K0PerBlock,
|
||||
ConvBackwardWeightSpecialization>{};
|
||||
|
||||
// Bytes per 32 lds bank: 32 * 4 bytes
|
||||
static constexpr auto BankLength = 128;
|
||||
static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
|
||||
|
||||
// M1 & M0
|
||||
static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1;
|
||||
static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock;
|
||||
static constexpr auto ABlockLdsM1Padding = 4;
|
||||
|
||||
// N1 & N0
|
||||
static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1;
|
||||
static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
|
||||
static constexpr auto BBlockLdsN1Padding = 4;
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
const ck::index_t dim = 1;
|
||||
const ck::index_t batch = 1;
|
||||
const std::array<ck::index_t, NDimSpatial> lengths{1};
|
||||
const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial> params{1};
|
||||
return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
|
||||
dim,
|
||||
dim,
|
||||
dim,
|
||||
lengths,
|
||||
lengths,
|
||||
lengths,
|
||||
strides,
|
||||
strides,
|
||||
strides,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
batch);
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
const ck::index_t dim = 1;
|
||||
const ck::index_t batch = 1;
|
||||
const std::array<ck::index_t, NDimSpatial> lengths{1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial> params{1, 1};
|
||||
return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
|
||||
dim,
|
||||
dim,
|
||||
dim,
|
||||
lengths,
|
||||
lengths,
|
||||
lengths,
|
||||
strides,
|
||||
strides,
|
||||
strides,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
batch);
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
const ck::index_t dim = 1;
|
||||
const ck::index_t batch = 1;
|
||||
const std::array<ck::index_t, NDimSpatial> lengths{1, 1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
|
||||
const std::array<ck::index_t, NDimSpatial> params{1, 1, 1};
|
||||
return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(
|
||||
dim,
|
||||
dim,
|
||||
dim,
|
||||
lengths,
|
||||
lengths,
|
||||
lengths,
|
||||
strides,
|
||||
strides,
|
||||
strides,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
batch);
|
||||
}
|
||||
|
||||
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
element_wise::PassThrough,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
ABlockLdsM1PerBlock,
|
||||
ABlockLdsM0PerBlock,
|
||||
ABlockLdsM1Padding,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
BBlockLdsN1PerBlock,
|
||||
BBlockLdsN0PerBlock,
|
||||
BBlockLdsN1Padding,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
true,
|
||||
1,
|
||||
PipelineVersion::v1,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
static constexpr index_t ClusterLengthMPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
|
||||
|
||||
using GridwiseElementwise =
|
||||
GridwiseElementwise<Tuple<CGridDesc_M_N>,
|
||||
Tuple<CGridDesc_M_N>,
|
||||
Tuple<const AccDataType*>,
|
||||
Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
CDEElementwiseOperation,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<0, 1>,
|
||||
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
|
||||
Sequence<CBlockTransferScalarPerVector_NWaveNPerXdl>,
|
||||
I1,
|
||||
I1>;
|
||||
|
||||
// Argument
|
||||
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
|
||||
|
||||
using Block2CTileMap =
|
||||
decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
const ck::index_t M01,
|
||||
const ck::index_t N01,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
ck::index_t split_k)
|
||||
: p_a_grid_{p_out_grid},
|
||||
p_b_grid_{p_in_grid},
|
||||
p_e_grid_{p_wei_grid},
|
||||
a_grid_desc_kbatch_k0_m_k1_{},
|
||||
b_grid_desc_kbatch_k0_n_k1_{},
|
||||
ce_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_ctile_map_{},
|
||||
compute_ptr_offset_of_batch_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
a_element_op_{out_element_op},
|
||||
b_element_op_{in_element_op},
|
||||
cde_element_op_{wei_element_op},
|
||||
Conv_G_{b_g_n_c_wis_lengths[0]},
|
||||
Conv_N_{b_g_n_c_wis_lengths[1]},
|
||||
Conv_K_{e_g_k_c_xs_lengths[1]},
|
||||
Conv_C_{b_g_n_c_wis_lengths[2]},
|
||||
input_spatial_lengths_{},
|
||||
filter_spatial_lengths_{},
|
||||
output_spatial_lengths_{},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(b_g_n_c_wis_lengths),
|
||||
begin(input_spatial_lengths_));
|
||||
std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset,
|
||||
end(e_g_k_c_xs_lengths),
|
||||
begin(filter_spatial_lengths_));
|
||||
std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset,
|
||||
end(a_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
Conv_N_,
|
||||
Conv_K_,
|
||||
Conv_C_,
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
b_g_n_c_wis_strides,
|
||||
e_g_k_c_xs_strides,
|
||||
a_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
k_batch_);
|
||||
|
||||
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
|
||||
ce_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(ce_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
elementwise_block_2_ctile_map_ = Block2TileMapElementwise{
|
||||
ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)};
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ =
|
||||
Conv_K_ * Conv_C_ *
|
||||
std::accumulate(begin(filter_spatial_lengths_),
|
||||
end(filter_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
|
||||
b_grid_desc_kbatch_k0_n_k1_,
|
||||
ce_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ce_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
|
||||
CGridDesc_M_N ce_grid_desc_m_n_;
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
|
||||
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
|
||||
OutElementwiseOperation a_element_op_;
|
||||
InElementwiseOperation b_element_op_;
|
||||
WeiElementwiseOperation cde_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
const index_t Conv_G_;
|
||||
const index_t Conv_N_;
|
||||
const index_t Conv_K_;
|
||||
const index_t Conv_C_;
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
void ShowInfo(const Argument& arg)
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.ce_grid_desc_m_n_{" << arg.ce_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.ce_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
|
||||
}
|
||||
|
||||
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
auto launch_gemm_kernel = [&](auto has_main_k_block_loop) {
|
||||
AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_;
|
||||
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
auto preprocess = [&]() {
|
||||
hip_check_error(hipMemsetAsync(
|
||||
p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
|
||||
};
|
||||
|
||||
const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
element_wise::PassThrough,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
preprocess,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
p_c_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
element_wise::PassThrough{},
|
||||
arg.Conv_G_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
};
|
||||
|
||||
auto launch_elementwise_kernel = [&]() {
|
||||
const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) *
|
||||
arg.Conv_G_;
|
||||
|
||||
std::array<index_t, I1> in_out_batch_strides = {
|
||||
arg.compute_ptr_offset_of_batch_.BatchStrideC_};
|
||||
|
||||
const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
|
||||
ck::Tuple<CGridDesc_M_N>,
|
||||
ck::Tuple<CGridDesc_M_N>,
|
||||
ck::Tuple<const AccDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
CDEElementwiseOperation,
|
||||
I1,
|
||||
I1>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
make_tuple(arg.ce_grid_desc_m_n_),
|
||||
make_tuple(arg.ce_grid_desc_m_n_),
|
||||
make_tuple(p_c_grid),
|
||||
make_tuple(arg.p_e_grid_),
|
||||
arg.elementwise_block_2_ctile_map_,
|
||||
arg.cde_element_op_,
|
||||
arg.Conv_G_,
|
||||
in_out_batch_strides,
|
||||
in_out_batch_strides);
|
||||
};
|
||||
|
||||
float avg_time = 0;
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
avg_time = launch_gemm_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
avg_time = launch_gemm_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
|
||||
avg_time += launch_elementwise_kernel();
|
||||
return avg_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
// Compile-time validity hook for this instance's template parameters.
// Currently a placeholder that accepts every configuration.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accept_all = true;
    return accept_all;
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// Check this here, it allows to use other instances from factory even
|
||||
// if workspace is not allocated
|
||||
if(!arg.p_workspace_)
|
||||
{
|
||||
std::cerr << "Warning: Workspace for "
|
||||
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
|
||||
"allocated, use SetWorkSpacePointer."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 pad = 0 conv
|
||||
for(int i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
|
||||
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// vector load A/B matrix from global memory
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
|
||||
arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.ce_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
const ck::index_t split_k)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
b_g_n_c_wis_lengths, // input
|
||||
b_g_n_c_wis_strides,
|
||||
e_g_k_c_xs_lengths, // weight
|
||||
e_g_k_c_xs_strides,
|
||||
a_g_n_k_wos_lengths, // output
|
||||
a_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
split_k};
|
||||
}
|
||||
|
||||
/// @brief Creates an Invoker for this operation, returned by value.
static auto MakeInvoker()
{
    return Invoker{};
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
void* p_wei_grid,
|
||||
const void* p_out_grid,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
const ck::index_t split_k) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<WeiDataType*>(p_wei_grid),
|
||||
static_cast<const OutDataType*>(p_out_grid),
|
||||
b_g_n_c_wis_lengths, // input
|
||||
b_g_n_c_wis_strides,
|
||||
e_g_k_c_xs_lengths, // weight
|
||||
e_g_k_c_xs_strides,
|
||||
a_g_n_k_wos_lengths, // output
|
||||
a_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
split_k);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
|
||||
<< K1 << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< ABlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = dynamic_cast<const Argument*>(p_arg);
|
||||
if(arg)
|
||||
{
|
||||
return arg->GetWorkspaceSizeBytes();
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!");
|
||||
}
|
||||
|
||||
/// @brief Stores the workspace pointer in the given argument.
///
/// @param[in,out] p_arg        Type-erased argument; must be this operation's
///                             Argument. Its p_workspace_ member is overwritten.
/// @param[in]     p_workspace  Workspace pointer to store.
///
/// The StreamConfig parameter is unnamed and unused in this override.
///
/// @throws std::runtime_error when @p p_arg is not an Argument of this operation.
void SetWorkSpacePointer(BaseArgument* p_arg,
                         void* p_workspace,
                         const StreamConfig& = StreamConfig{}) const override
{
    auto* p_typed_arg = dynamic_cast<Argument*>(p_arg);

    // Reject null or foreign arguments up front.
    if(!p_typed_arg)
    {
        throw std::runtime_error(
            "The argument pointer is not an object of "
            "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!");
    }

    p_typed_arg->p_workspace_ = p_workspace;
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(ck::is_navi3_supported())
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -666,7 +666,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_navi2_supported() || ck::is_navi3_supported()))
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -601,8 +601,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
|
||||
ck::is_navi3_supported()))
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
|
||||
ck::is_gfx11_supported()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -581,7 +581,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
// check device
|
||||
if(ck::is_navi3_supported())
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -553,24 +553,29 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
|
||||
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"
|
||||
<< std::endl;
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1)
|
||||
<< ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2)
|
||||
<< "}" << std::endl;
|
||||
|
||||
std::cout << ", arg.b_grid_desc_k0_n_k1_{"
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"
|
||||
<< std::endl;
|
||||
std::cout << ", arg.b_grid_desc_k0_n_k1_{"
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1)
|
||||
<< ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2)
|
||||
<< "}" << std::endl;
|
||||
|
||||
std::cout << ", arg.e_grid_desc_m_n_{ "
|
||||
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
|
||||
<< std::endl;
|
||||
#endif
|
||||
std::cout << ", arg.e_grid_desc_m_n_{ "
|
||||
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
|
||||
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
|
||||
@@ -668,7 +673,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
|
||||
ck::is_navi2_supported() || ck::is_navi3_supported())
|
||||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
|
||||
{
|
||||
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
||||
{
|
||||
|
||||
@@ -467,18 +467,19 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
gemm_kernel_args_[i].block_start_ = block_start;
|
||||
gemm_kernel_args_[i].block_end_ = block_end;
|
||||
|
||||
#if DEBUG_LOG
|
||||
index_t tiles = (block_end - block_start) / K_BATCH;
|
||||
std::cout << "block_start: " << block_start << "\n"
|
||||
<< "block_end: " << block_end << "\n"
|
||||
<< "tiles: " << tiles << std::endl
|
||||
<< std::endl;
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
index_t tiles = (block_end - block_start) / K_BATCH;
|
||||
std::cout << "block_start: " << block_start << "\n"
|
||||
<< "block_end: " << block_end << "\n"
|
||||
<< "tiles: " << tiles << std::endl
|
||||
<< std::endl;
|
||||
|
||||
std::cout << "KPadded: " << karg.KPadded << std::endl
|
||||
<< "K0Padded: " << karg.K0Padded << std::endl
|
||||
<< "KBatch: " << karg.k_batch << std::endl
|
||||
<< "grid_size_: " << karg.KPadded << std::endl;
|
||||
#endif
|
||||
std::cout << "KPadded: " << karg.KPadded << std::endl
|
||||
<< "K0Padded: " << karg.K0Padded << std::endl
|
||||
<< "KBatch: " << karg.k_batch << std::endl
|
||||
<< "grid_size_: " << karg.KPadded << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -493,12 +494,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
arg.karg_.p_c_grid = p_workspace + offset;
|
||||
index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
|
||||
offset += tiles * MPerBlock * NPerBlock;
|
||||
#if DEBUG_LOG
|
||||
std::cout << "block_start: " << arg.block_start_ << "\n"
|
||||
<< "block_end: " << arg.block_end_ << "\n"
|
||||
<< "tiles: " << tiles << "\n"
|
||||
<< "offset: " << offset << std::endl;
|
||||
#endif
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "block_start: " << arg.block_start_ << "\n"
|
||||
<< "block_end: " << arg.block_end_ << "\n"
|
||||
<< "tiles: " << tiles << "\n"
|
||||
<< "offset: " << offset << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -816,11 +818,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
|
||||
arg.skipped_group_count_) != arg.group_count_)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "The group count is not equal to sum of skipped groups "
|
||||
"and kernel args size!"
|
||||
<< std::endl;
|
||||
#endif // DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "The group count is not equal to sum of skipped groups "
|
||||
"and kernel args size!"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -832,11 +835,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg);
|
||||
if(not group_arg_valid)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " has invalid GridwiseGemm settings!" << std::endl;
|
||||
gemm_arg.Print();
|
||||
#endif // DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " has invalid GridwiseGemm settings!" << std::endl;
|
||||
gemm_arg.Print();
|
||||
}
|
||||
}
|
||||
supported = supported && group_arg_valid;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,789 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/host_utility/stream_utility.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/loop_scheduler.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
///
|
||||
/// @brief Entry point kernel for device-wide Grouped GEMM operation.
|
||||
///
|
||||
/// @param[in] gemm_descs_const The pointer to the array of GEMM descriptor structures.
|
||||
/// @param[in] group_count The number of together processed GEMMs.
|
||||
///
|
||||
/// @tparam GridwiseGemm The specific GridwiseGEMM algorithm implementation.
|
||||
/// @tparam GemmDesc The structure holding all necessary descriptors and
|
||||
/// other data needed for grouped gemm calculation and work
|
||||
/// distribution.
|
||||
/// @tparam LocalBlock2ETileMap The structure providing mapping between workgroup ids,
|
||||
/// the data tiles to process and the output tiles.
|
||||
///
|
||||
template <typename GridwiseGemm,
|
||||
typename GemmDesc,
|
||||
GemmSpecialization GemmSpec,
|
||||
typename DsDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename OffsettedBlockToCTileMap,
|
||||
typename LocalBlock2ETileMap,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
const index_t group_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
|
||||
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
|
||||
__shared__ uint8_t p_shared[shared_size];
|
||||
|
||||
const auto gemm_desc_ptr =
|
||||
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
|
||||
constexpr auto NumDTensor = DsDataType::Size();
|
||||
index_t tile_id = get_block_1d_id();
|
||||
index_t tile_offset = 0;
|
||||
index_t group_id = -1;
|
||||
index_t group_offset = 0;
|
||||
index_t grid_size_grp = 0;
|
||||
|
||||
index_t gemm_tile_id_start = 0;
|
||||
index_t gemm_tile_id_end = 0;
|
||||
|
||||
using AGridDescMK =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(
|
||||
1, 1, 1))>;
|
||||
using BGridDescNK =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(
|
||||
1, 1, 1))>;
|
||||
using EGridDescMN =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
|
||||
1, 1, 1))>;
|
||||
using DsGridDescMN =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
|
||||
{}, {}, {}))>;
|
||||
|
||||
index_t M = 0, N = 0, K = 0;
|
||||
index_t StrideA, StrideB, StrideE;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
|
||||
AGridDescMK a_grid_desc_mk;
|
||||
BGridDescNK b_grid_desc_nk;
|
||||
EGridDescMN e_grid_desc_mn;
|
||||
DsGridDescMN ds_grid_desc_mn;
|
||||
auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1);
|
||||
|
||||
do
|
||||
{
|
||||
// Find corresponding GEMM group for our tile
|
||||
while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) &&
|
||||
group_id < group_count)
|
||||
{
|
||||
group_offset += grid_size_grp;
|
||||
group_id++;
|
||||
|
||||
if(group_id >= group_count)
|
||||
return;
|
||||
|
||||
M = gemm_desc_ptr[group_id].M;
|
||||
N = gemm_desc_ptr[group_id].N;
|
||||
K = gemm_desc_ptr[group_id].K;
|
||||
|
||||
if(M * N * K == 0)
|
||||
{
|
||||
grid_size_grp = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
b2c_tile_map =
|
||||
OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N), group_offset, tile_offset);
|
||||
grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
|
||||
|
||||
gemm_tile_id_start = group_offset;
|
||||
gemm_tile_id_end = group_offset + grid_size_grp;
|
||||
}
|
||||
|
||||
StrideA = gemm_desc_ptr[group_id].StrideA;
|
||||
StrideB = gemm_desc_ptr[group_id].StrideB;
|
||||
StrideDs = gemm_desc_ptr[group_id].StrideDs;
|
||||
StrideE = gemm_desc_ptr[group_id].StrideE;
|
||||
|
||||
a_grid_desc_mk =
|
||||
GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
|
||||
b_grid_desc_nk =
|
||||
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
|
||||
e_grid_desc_mn =
|
||||
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
|
||||
ds_grid_desc_mn(j) = GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
|
||||
M, N, StrideDs[j]);
|
||||
});
|
||||
|
||||
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
|
||||
DsGridPointer p_ds_grid;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
|
||||
});
|
||||
|
||||
bool has_main_kblock_loop =
|
||||
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_mk.GetLength(Number<1>{}));
|
||||
// Update tile offset if we have moved within group
|
||||
b2c_tile_map.UpdateTileOffset(tile_offset);
|
||||
|
||||
if(has_main_kblock_loop)
|
||||
{
|
||||
GridwiseGemm::template Run<true>(gemm_desc_ptr[group_id].p_a_grid,
|
||||
gemm_desc_ptr[group_id].p_b_grid,
|
||||
p_ds_grid,
|
||||
gemm_desc_ptr[group_id].p_e_grid,
|
||||
static_cast<void*>(p_shared),
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_mk,
|
||||
b_grid_desc_nk,
|
||||
ds_grid_desc_mn,
|
||||
e_grid_desc_mn,
|
||||
b2c_tile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<false>(gemm_desc_ptr[group_id].p_a_grid,
|
||||
gemm_desc_ptr[group_id].p_b_grid,
|
||||
p_ds_grid,
|
||||
gemm_desc_ptr[group_id].p_e_grid,
|
||||
static_cast<void*>(p_shared),
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_mk,
|
||||
b_grid_desc_nk,
|
||||
ds_grid_desc_mn,
|
||||
e_grid_desc_mn,
|
||||
b2c_tile_map);
|
||||
}
|
||||
|
||||
tile_id += get_grid_size();
|
||||
tile_offset += get_grid_size();
|
||||
|
||||
} while(group_id < group_count);
|
||||
#else
|
||||
ignore = gemm_descs_const;
|
||||
ignore = group_count;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename ComputeDataType = EDataType>
|
||||
struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
: public DeviceGroupedGemmTileLoop<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop;
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
|
||||
template <typename UnderlyingBlockToCTileMap>
|
||||
struct OffsettedBlockToCTileMap
|
||||
{
|
||||
using underlying_type = UnderlyingBlockToCTileMap;
|
||||
|
||||
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
|
||||
index_t group_offset,
|
||||
index_t tile_offset)
|
||||
: block_to_ctile_map_{block_to_ctile_map},
|
||||
group_offset_{group_offset},
|
||||
tile_offset_{tile_offset}
|
||||
{
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
return block_to_ctile_map_.CalculateBottomIndex(
|
||||
make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
|
||||
const CTileDim& c_tile_dim) const
|
||||
{
|
||||
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
|
||||
{
|
||||
return block_to_ctile_map_.CalculateGridSize(M, N);
|
||||
}
|
||||
|
||||
__device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
|
||||
UnderlyingBlockToCTileMap block_to_ctile_map_;
|
||||
index_t group_offset_;
|
||||
index_t tile_offset_;
|
||||
};
|
||||
|
||||
using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
|
||||
using Block2ETileMap = BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock>;
|
||||
using OffsetedLocalBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(std::vector<const void*>& /* p_As */,
|
||||
std::vector<const void*>& /* p_Bs */,
|
||||
std::vector<std::array<const void*, NumDTensor>>& /* p_Ds */,
|
||||
std::vector<void*>& /* p_Es */,
|
||||
const std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
int occupancy_num_blocks,
|
||||
int gpu_cu_count)
|
||||
: group_count_{static_cast<index_t>(gemm_descs.size())},
|
||||
occupancy_num_blocks_{occupancy_num_blocks},
|
||||
gpu_cu_count_{gpu_cu_count},
|
||||
gemm_descs_{gemm_descs},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
tile_count_{0}
|
||||
{
|
||||
for(const auto& desc : gemm_descs)
|
||||
{
|
||||
const auto M = desc.M_;
|
||||
const auto N = desc.N_;
|
||||
const auto b2c_tile_map = Block2ETileMap(M, N);
|
||||
tile_count_ += b2c_tile_map.CalculateGridSize(M, N);
|
||||
}
|
||||
}
|
||||
|
||||
index_t group_count_;
|
||||
const void* p_dev_gemm_args_;
|
||||
int occupancy_num_blocks_;
|
||||
int gpu_cu_count_;
|
||||
|
||||
const std::vector<GemmDesc>& gemm_descs_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
index_t tile_count_;
|
||||
};
|
||||
|
||||
struct KernelConfig
|
||||
{
|
||||
// The oversubscription factor for the number of blocks that can simultaneously reside on
|
||||
// GPU.
|
||||
static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
|
||||
static constexpr int BLOCK_WAVES = BlockSize / get_warp_size();
|
||||
static constexpr int CU_SIMDS = 4;
|
||||
// Assume we want to have at most 2 waves per SIMD
|
||||
static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
///
|
||||
/// @brief Launch Grouped Gemm kernel.
|
||||
///
|
||||
/// @note This function overload is using user provided device buffer for kernel
|
||||
/// arguments.
|
||||
///
|
||||
/// @param[in] arg The structure containing kernel arguments (in host
|
||||
/// memory).
|
||||
/// @param[in] dev_gemm_args The pointer to device memory with kernel arguments.
|
||||
/// @param[in] stream_config The device stream configuration.
|
||||
///
|
||||
/// @return The average kernel execution time (if time measurement is enabled.)
|
||||
///
|
||||
float Run(const Argument& arg,
|
||||
const void* dev_gemm_args,
|
||||
const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(dev_gemm_args == nullptr)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "The gemm arguments device buffer is not allocated!"
|
||||
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
ave_time = DispatchKernel(arg, dev_gemm_args, stream_config);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
///
|
||||
/// @brief Launch Grouped Gemm kernel.
|
||||
///
|
||||
/// @note This function overload is using device buffers (for kernel arguments and
|
||||
/// for kernel auxiliary workspace) provided with an argument. The user should
|
||||
/// call @see GetDeviceKernelArgSize, and @see SetDeviceKernelArgs, on arg
|
||||
/// parameter to properly allocate those buffers.
|
||||
///
|
||||
/// @param[in] arg The structure containing kernel arguments (in host memory).
|
||||
/// @param[in] stream_config The device stream configuration.
|
||||
///
|
||||
/// @return The average kernel execution time (if time measurement is enabled.)
|
||||
///
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(arg.p_dev_gemm_args_ == nullptr)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "The gemm arguments device buffer is not allocated!"
|
||||
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
return Run(arg, arg.p_dev_gemm_args_, stream_config);
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
|
||||
private:
|
||||
float DispatchKernel(const Argument& arg,
|
||||
const void* dev_gemm_args,
|
||||
const StreamConfig& stream_config) const
|
||||
{
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
DsDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
OffsetedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
|
||||
}
|
||||
|
||||
template <typename KernelFunction>
|
||||
int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
|
||||
const StreamConfig& stream_config) const
|
||||
{
|
||||
// Calculate max number of workgroups that can simultaneously reside on the CU.
|
||||
int occ_num_blocks = 0;
|
||||
size_t dyn_shared_mem_per_blk = 0;
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));
|
||||
|
||||
int cu_count = getAvailableComputeUnitCount(stream_config);
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks
|
||||
<< ", available CUs count: " << cu_count << ", occup. grid size: "
|
||||
<< ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS) * cu_count
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return cu_count * ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS);
|
||||
}
|
||||
|
||||
template <typename KernelFunction>
|
||||
float LaunchKernel(const KernelFunction& kernel,
|
||||
const Argument& arg,
|
||||
const void* dev_gemm_args,
|
||||
const StreamConfig& stream_config) const
|
||||
{
|
||||
int grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config);
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
std::cout << "grid_size: " << grid_size << " tile_count: " << arg.tile_count_
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(dev_gemm_args),
|
||||
arg.group_count_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_);
|
||||
}
|
||||
};
|
||||
|
||||
// Placeholder compile-time validity check: currently accepts every template
// configuration unconditionally (see the TODO below).
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
using DsGridDescMN = remove_cvref_t<
|
||||
decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
|
||||
{}, {}, {}))>;
|
||||
|
||||
bool supported = true;
|
||||
|
||||
for(const auto& gdesc : arg.gemm_descs_)
|
||||
{
|
||||
const auto M = gdesc.M_;
|
||||
const auto N = gdesc.N_;
|
||||
const auto K = gdesc.K_;
|
||||
|
||||
const auto StrideA = gdesc.stride_A_;
|
||||
const auto StrideB = gdesc.stride_B_;
|
||||
const auto StrideE = gdesc.stride_C_;
|
||||
const auto& StrideDs = gdesc.stride_Ds_;
|
||||
|
||||
// If M dimension is unknown at launch time then validate just NK.
|
||||
// If N or K dim is zero (or unknown) then the vector loads responsibility lies on
|
||||
// the user.
|
||||
if(N * K == 0)
|
||||
continue;
|
||||
|
||||
const auto a_grid_desc_mk =
|
||||
GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
|
||||
const auto b_grid_desc_nk =
|
||||
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
|
||||
const auto e_grid_desc_mn =
|
||||
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
|
||||
|
||||
DsGridDescMN ds_grid_desc_mn;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
|
||||
ds_grid_desc_mn(j) =
|
||||
GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
|
||||
M, N, StrideDs[j]);
|
||||
});
|
||||
|
||||
const auto b2c_tile_map = Block2ETileMap(M, N);
|
||||
|
||||
if(!(GridwiseGemm::template CheckValidity(a_grid_desc_mk,
|
||||
b_grid_desc_nk,
|
||||
ds_grid_desc_mn,
|
||||
e_grid_desc_mn,
|
||||
b2c_tile_map) &&
|
||||
GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(
|
||||
M, N, K)))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
|
||||
<< K << "] are not supported by current template parameters!"
|
||||
<< " In " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__;
|
||||
}
|
||||
supported = false;
|
||||
}
|
||||
}
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_elementwise_op,
|
||||
BElementwiseOperation b_elementwise_op,
|
||||
CDEElementwiseOperation cde_elementwise_op)
|
||||
{
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
DsDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
OffsetedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
int occupancy, num_cu;
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
|
||||
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
hip_check_error(hipGetDevice(&dev));
|
||||
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
num_cu = dev_prop.multiProcessorCount;
|
||||
|
||||
return Argument{p_As,
|
||||
p_Bs,
|
||||
p_Ds,
|
||||
p_Es,
|
||||
gemm_descs,
|
||||
a_elementwise_op,
|
||||
b_elementwise_op,
|
||||
cde_elementwise_op,
|
||||
occupancy,
|
||||
num_cu};
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_elementwise_op,
|
||||
BElementwiseOperation b_elementwise_op,
|
||||
CDEElementwiseOperation cde_elementwise_op) override
|
||||
{
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
DsDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
OffsetedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
int occupancy, num_cu;
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
|
||||
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
hip_check_error(hipGetDevice(&dev));
|
||||
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
num_cu = dev_prop.multiProcessorCount;
|
||||
|
||||
return std::make_unique<Argument>(p_As,
|
||||
p_Bs,
|
||||
p_Ds,
|
||||
p_Es,
|
||||
gemm_descs,
|
||||
a_elementwise_op,
|
||||
b_elementwise_op,
|
||||
cde_elementwise_op,
|
||||
occupancy,
|
||||
num_cu);
|
||||
}
|
||||
|
||||
// Returns a default-constructed Invoker by value.
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::ostringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop"
|
||||
<< "<"
|
||||
<< std::string(ALayout::name)[0] << ","
|
||||
<< std::string(BLayout::name)[0] << ","
|
||||
<< std::string(ELayout::name)[0] << ","
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< PipelineVer << ", "
|
||||
<< LoopSched
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
// Stores the caller-provided pointer in arg.p_dev_gemm_args_. The pointer is only
// recorded here; nothing is allocated or copied by this function.
void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
|
||||
{
|
||||
arg.p_dev_gemm_args_ = p_dev_kernel_args;
|
||||
}
|
||||
|
||||
/// @brief Type-erased variant of SetDeviceKernelArgs.
///
/// @param[in,out] p_arg              Pointer that must refer to this operation's Argument.
/// @param[in]     p_dev_kernel_args  Pointer stored as the argument's kernel-args buffer.
///
/// @throws std::runtime_error if @p p_arg is null or does not point to an Argument.
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
    // Guard against a failed cast instead of dereferencing a null pointer.
    auto* typed_arg = dynamic_cast<Argument*>(p_arg);
    if(typed_arg == nullptr)
    {
        std::ostringstream err;
        err << "The argument pointer is not an object of "
            << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop::Argument structure!"
            << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
        throw std::runtime_error(err.str());
    }

    SetDeviceKernelArgs(*typed_arg, p_dev_kernel_args);
}
|
||||
|
||||
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(KernelArguments);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -514,28 +514,29 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
|
||||
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
|
||||
<< ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1)
|
||||
<< ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2)
|
||||
<< "}";
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
|
||||
<< ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1)
|
||||
<< ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2)
|
||||
<< "}";
|
||||
|
||||
std::cout << ", arg.b_grid_desc_bk0_n_bk1_{"
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0)
|
||||
<< ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1)
|
||||
<< ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2)
|
||||
<< "}";
|
||||
std::cout << ", arg.b_grid_desc_bk0_n_bk1_{"
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0)
|
||||
<< ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1)
|
||||
<< ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2)
|
||||
<< "}";
|
||||
|
||||
std::cout << ", arg.e_grid_desc_m_n_{ "
|
||||
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
|
||||
<< std::endl;
|
||||
#endif
|
||||
std::cout << ", arg.e_grid_desc_m_n_{ "
|
||||
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
|
||||
arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
|
||||
|
||||
@@ -529,11 +529,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
|
||||
arg.skipped_group_count_) != arg.group_count_)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "The group count is not equal to sum of skipped groups "
|
||||
"and kernel args size!"
|
||||
<< std::endl;
|
||||
#endif // DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "The group count is not equal to sum of skipped groups "
|
||||
"and kernel args size!"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -544,11 +545,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
|
||||
if(not group_arg_valid)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " has invalid GridwiseGemm settings!" << std::endl;
|
||||
a.Print();
|
||||
#endif // DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " has invalid GridwiseGemm settings!" << std::endl;
|
||||
a.Print();
|
||||
}
|
||||
}
|
||||
supported = supported && group_arg_valid;
|
||||
}
|
||||
|
||||
@@ -596,7 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
|
||||
|
||||
static bool IsSupportedArgument(const RawArg& arg)
|
||||
{
|
||||
if(ck::is_navi3_supported())
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
@@ -958,7 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
|
||||
#if 0
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_navi3_supported())
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -594,7 +594,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
|
||||
|
||||
static bool IsSupportedArgument(const RawArg& arg)
|
||||
{
|
||||
if(ck::is_navi3_supported())
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
@@ -950,7 +950,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
|
||||
#if 0
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::is_navi3_supported())
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
|
||||
{
|
||||
|
||||
@@ -92,15 +92,6 @@ struct Add
|
||||
};
|
||||
};
|
||||
|
||||
struct Scales
|
||||
{
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
|
||||
{
|
||||
y = ck::type_convert<Y>(ck::type_convert<float>(x0) * ck::type_convert<float>(x1));
|
||||
}
|
||||
};
|
||||
|
||||
struct Max
|
||||
{
|
||||
template <typename Y, typename X0, typename X1>
|
||||
@@ -188,6 +179,16 @@ struct Multiply
|
||||
y = ck::type_convert<bhalf_t>(y_tmp);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t>(bhalf_t& y, const int8_t& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float x1_tmp = ck::type_convert<float>(x0);
|
||||
const float x2_tmp = ck::type_convert<float>(x1);
|
||||
const float y_tmp = x1_tmp * x2_tmp;
|
||||
y = ck::type_convert<bhalf_t>(y_tmp);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
|
||||
@@ -521,6 +522,71 @@ struct AddFastGelu
|
||||
}
|
||||
};
|
||||
|
||||
// E = MultiplyFastGelu(C + D)
|
||||
struct MultiplyFastGelu
|
||||
{
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float, float, float>(float& e, const float& c, const float& d) const
|
||||
{
|
||||
const float x = c * d;
|
||||
|
||||
FastGelu{}.template operator()<float, float>(e, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
|
||||
{
|
||||
const half_t x = c * d;
|
||||
|
||||
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
|
||||
{
|
||||
const float x0_f = c * d;
|
||||
|
||||
float x1_f = 0;
|
||||
|
||||
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
|
||||
x0_f);
|
||||
|
||||
e = type_convert<half_t>(x1_f);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
|
||||
{
|
||||
const float x0_f = type_convert<float>(c) * type_convert<float>(d);
|
||||
|
||||
float x1_f = 0;
|
||||
|
||||
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
|
||||
|
||||
e = type_convert<bhalf_t>(x1_f);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
|
||||
{
|
||||
const float x0_f = c * type_convert<float>(d);
|
||||
|
||||
float x1_f = 0;
|
||||
|
||||
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
|
||||
|
||||
e = type_convert<bhalf_t>(x1_f);
|
||||
}
|
||||
};
|
||||
|
||||
// E = Silu(C + D)
|
||||
struct AddSilu
|
||||
{
|
||||
|
||||
@@ -221,6 +221,15 @@ struct MultiplyAdd
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
|
||||
const float& c,
|
||||
const bhalf_t& d0,
|
||||
const bhalf_t& d1) const
|
||||
{
|
||||
const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
|
||||
const float& c,
|
||||
const half_t& d0,
|
||||
@@ -240,6 +249,26 @@ struct MultiplyAdd
|
||||
}
|
||||
};
|
||||
|
||||
struct MultiplyAddFastGelu
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::bhalf_t, float, ck::bhalf_t, ck::bhalf_t>(
|
||||
ck::bhalf_t& e, const float& c, const ck::bhalf_t& d0, const ck::bhalf_t& d1) const
|
||||
{
|
||||
const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
|
||||
|
||||
float x1_f = 0;
|
||||
|
||||
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
|
||||
|
||||
e = ck::type_convert<ck::bhalf_t>(x1_f);
|
||||
}
|
||||
};
|
||||
|
||||
// E = FastGelu(C + D0 + D1)
|
||||
struct AddAddFastGelu
|
||||
{
|
||||
@@ -499,6 +528,26 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvInvscale
|
||||
{
|
||||
/// @brief Op to multiply convolution results by inverted scale factors
|
||||
/// @param e Output after scaling
|
||||
/// @param c Convolution result
|
||||
/// @param d0 Input scale factor
|
||||
/// @param d1 Weights scale factor
|
||||
/// @param d2 Output scale factor
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<f8_t, float, float, float, float>(
|
||||
f8_t& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
e = type_convert<f8_t>(c / d0 / d1 / d2);
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -504,6 +504,16 @@ struct FastGelu
|
||||
y = type_convert<half_t>(y_f);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
float y_f;
|
||||
|
||||
this->operator()<float, float>(y_f, x);
|
||||
|
||||
y = type_convert<bhalf_t>(y_f);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -151,7 +151,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
|
||||
{
|
||||
}
|
||||
|
||||
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
|
||||
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
|
||||
@@ -260,7 +260,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
|
||||
};
|
||||
|
||||
// Grouped Rows of column-vectors WGP mapping
|
||||
// Optimized for MI300-like multipe-die chip
|
||||
// Optimized for gfx94x-like multiple-die chip
|
||||
|
||||
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
|
||||
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
|
||||
@@ -275,7 +275,7 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
|
||||
{
|
||||
}
|
||||
|
||||
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
|
||||
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
|
||||
@@ -428,7 +428,7 @@ struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
|
||||
{
|
||||
}
|
||||
|
||||
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
|
||||
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
|
||||
@@ -900,6 +900,11 @@ struct OffsettedBlockToCTileMap
|
||||
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
// Overload taking raw M and N extents: forwards unchanged to the wrapped
// block_to_ctile_map_ and returns its result.
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
|
||||
{
|
||||
return block_to_ctile_map_.CalculateGridSize(M, N);
|
||||
}
|
||||
|
||||
UnderlyingBlockToCTileMap block_to_ctile_map_;
|
||||
index_t block_start_;
|
||||
};
|
||||
|
||||
@@ -594,11 +594,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
|
||||
Number<NumATensor>{});
|
||||
|
||||
#if 0
|
||||
static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
|
||||
"Src and Dst ScalarPerVector must be the same");
|
||||
#endif
|
||||
|
||||
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
|
||||
ThisThreadBlock,
|
||||
AsDataType,
|
||||
@@ -616,7 +611,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
uniform_sequence_gen_t<NumATensor, false>,
|
||||
uniform_sequence_gen_t<NumATensor, AThreadTransferSrcResetCoordinateAfterRun>,
|
||||
Sequence<true>>{as_grid_desc_ak0_m_ak1,
|
||||
idx_as_block_begin,
|
||||
tie(a_block_desc_ak0_m_ak1),
|
||||
@@ -627,11 +622,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
|
||||
Number<NumBTensor>{});
|
||||
|
||||
#if 0
|
||||
static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
|
||||
"Src and Dst ScalarPerVector must be the same");
|
||||
#endif
|
||||
|
||||
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
|
||||
ThisThreadBlock,
|
||||
BsDataType,
|
||||
@@ -649,7 +639,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
uniform_sequence_gen_t<NumBTensor, false>,
|
||||
uniform_sequence_gen_t<NumBTensor, BThreadTransferSrcResetCoordinateAfterRun>,
|
||||
Sequence<true>>{bs_grid_desc_bk0_n_bk1,
|
||||
idx_bs_block_begin,
|
||||
tie(b_block_desc_bk0_n_bk1),
|
||||
|
||||
@@ -257,7 +257,70 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
e_grid_desc_m_n);
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
template <typename ALayout, typename BLayout, typename ELayout>
|
||||
__host__ __device__ static bool
|
||||
CheckTensorTransfersValidity(index_t MRaw, index_t NRaw, index_t KRaw)
|
||||
{
|
||||
// Check if the vector dim is K1 or M|N
|
||||
const auto A_vector_dim_size = ABlockTransferSrcVectorDim == 2 ? KRaw : MRaw;
|
||||
const auto B_vector_dim_size = BBlockTransferSrcVectorDim == 2 ? KRaw : NRaw;
|
||||
const auto E_vector_dim_size = NRaw;
|
||||
|
||||
// check vector load for A tensor
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
if(!(A_vector_dim_size == KRaw &&
|
||||
A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0))
|
||||
return false;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
if(!(A_vector_dim_size == MRaw &&
|
||||
A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0))
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
if(!(B_vector_dim_size == NRaw &&
|
||||
B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0))
|
||||
return false;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
if(!(B_vector_dim_size == KRaw &&
|
||||
B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0))
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ELayout>)
|
||||
{
|
||||
if(!(E_vector_dim_size == NRaw &&
|
||||
E_vector_dim_size % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
return false;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ELayout>)
|
||||
{
|
||||
if(!(E_vector_dim_size == NRaw &&
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename AGridDesc_M_K,
|
||||
typename BGridDesc_N_K,
|
||||
typename DsGridDesc_M_N,
|
||||
@@ -267,7 +330,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
const BGridDesc_N_K& b_grid_desc_n_k,
|
||||
const DsGridDesc_M_N& ds_grid_desc_m_n,
|
||||
const EGridDesc_M_N& e_grid_desc_m_n,
|
||||
const Block2ETileMap&)
|
||||
[[maybe_unused]] const Block2ETileMap&)
|
||||
{
|
||||
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
|
||||
@@ -285,7 +348,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
bool valid = true;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
@@ -306,7 +368,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
|
||||
// check gridwise gemm pipeline
|
||||
const auto num_k_loop = AK / KPerBlock;
|
||||
|
||||
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
|
||||
{
|
||||
return false;
|
||||
@@ -938,6 +999,63 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
typename AGridDesc_MK,
|
||||
typename BGridDesc_NK,
|
||||
typename DsGridDesc_MN,
|
||||
typename EGridDesc_MN,
|
||||
typename Block2ETileMap>
|
||||
__device__ static void Run(const void* __restrict__ p_a_grid_,
|
||||
const void* __restrict__ p_b_grid_,
|
||||
DsGridPointer p_ds_grid,
|
||||
void* __restrict__ p_e_grid_,
|
||||
void* __restrict__ p_shared,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
const AGridDesc_MK& a_grid_desc_m_k,
|
||||
const BGridDesc_NK& b_grid_desc_n_k,
|
||||
const DsGridDesc_MN& ds_grid_desc_m_n,
|
||||
const EGridDesc_MN& e_grid_desc_m_n,
|
||||
const Block2ETileMap& block_2_etile_map)
|
||||
{
|
||||
const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
|
||||
const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
|
||||
const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
|
||||
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_MN{}))>;
|
||||
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
|
||||
});
|
||||
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
|
||||
|
||||
Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
|
||||
@@ -57,3 +58,16 @@ constexpr auto GridwiseGemmPipeline_Selector()
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
// Streams the name of a ck::PipelineVersion enumerator. Values other than v1, v2, v4 and
// weight_only are written as an empty string.
inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
{
    const char* label = "";

    switch(p)
    {
    case ck::PipelineVersion::v1: label = "PipelineVersion::v1"; break;
    case ck::PipelineVersion::v2: label = "PipelineVersion::v2"; break;
    case ck::PipelineVersion::v4: label = "PipelineVersion::v4"; break;
    case ck::PipelineVersion::weight_only: label = "PipelineVersion::weight_only"; break;
    default: break;
    }

    return os << label;
}
|
||||
|
||||
@@ -935,12 +935,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(!(karg.M % MPerBlock == 0))
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -952,12 +952,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(!(karg.N % NPerBlock == 0))
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -971,12 +971,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
auto K_t = karg.KBatch * KPerBlock;
|
||||
if(!(karg.K % K_t == 0))
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
|
||||
<< karg.K << " " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
|
||||
<< karg.K << " " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -995,13 +995,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1009,13 +1009,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1024,13 +1024,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1038,13 +1038,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1053,14 +1053,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of "
|
||||
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
|
||||
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of "
|
||||
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
|
||||
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1068,14 +1069,28 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
#if DEBUG_LOG
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of "
|
||||
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
|
||||
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of "
|
||||
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
|
||||
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // DEBUG_LOG
|
||||
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
if(karg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user