mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Merge branch 'develop' into improve_pipeline_v3
This commit is contained in:
@@ -72,8 +72,9 @@ message(STATUS "Build with HIP ${HIP_VERSION}")
|
||||
|
||||
|
||||
rocm_create_package(
|
||||
NAME CK-${CK_BACKEND}
|
||||
NAME composablekernel
|
||||
DESCRIPTION "High Performance Composable Kernel for AMD GPUs"
|
||||
MAINTAINER "MIOpen Kernels Dev Team <dl.MIOpen@amd.com>"
|
||||
LDCONFIG
|
||||
)
|
||||
|
||||
@@ -226,15 +227,12 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin)
|
||||
|
||||
configure_file("${PROJECT_SOURCE_DIR}/include/ck/hip_version.hpp.in" "${PROJECT_BINARY_DIR}/include/ck/hip_version.hpp")
|
||||
|
||||
include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_BINARY_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/library/include
|
||||
)
|
||||
|
||||
include(googletest)
|
||||
|
||||
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
|
||||
if(BUILD_DEV)
|
||||
@@ -243,7 +241,31 @@ if(BUILD_DEV)
|
||||
endif()
|
||||
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
||||
|
||||
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
|
||||
|
||||
add_subdirectory(library)
|
||||
add_subdirectory(example)
|
||||
add_subdirectory(test)
|
||||
add_subdirectory(profiler)
|
||||
|
||||
#Create an interface target for the include only files and call it "composablekernels"
|
||||
include(CMakePackageConfigHelpers)
|
||||
|
||||
set(version 1.0.0)
|
||||
write_basic_package_version_file(
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake"
|
||||
VERSION "${version}"
|
||||
COMPATIBILITY AnyNewerVersion
|
||||
)
|
||||
|
||||
configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake"
|
||||
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
|
||||
NO_CHECK_REQUIRED_COMPONENTS_MACRO
|
||||
)
|
||||
|
||||
install(FILES
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake"
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake"
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
|
||||
)
|
||||
|
||||
11
Config.cmake.in
Normal file
11
Config.cmake.in
Normal file
@@ -0,0 +1,11 @@
|
||||
@PACKAGE_INIT@
|
||||
|
||||
set(_composable_kernel_supported_components device_operations host_tensor)
|
||||
|
||||
# Load the exported targets of each component requested through
# find_package(composable_kernel COMPONENTS ...).
#
# A component that is not in _composable_kernel_supported_components marks the
# package as not found and records why. Its targets file is NOT included:
# no such file is installed for an unknown component, and include() on a
# missing file is a hard error that would hide the not-found message from
# the caller.
foreach(_comp ${composable_kernel_FIND_COMPONENTS})
    if(_comp IN_LIST _composable_kernel_supported_components)
        include("${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake")
    else()
        set(composable_kernel_FOUND False)
        set(composable_kernel_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}")
    endif()
endforeach()
|
||||
24
Dockerfile
24
Dockerfile
@@ -11,13 +11,7 @@ ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
|
||||
RUN apt-get update
|
||||
RUN apt-get install -y wget gnupg
|
||||
RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
|
||||
RUN if ! [ -z $OSDB_BKC_VERSION ]; then \
|
||||
echo "Using BKC VERISION: $OSDB_BKC_VERSION";\
|
||||
sh -c "echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-osdb-deb/ compute-rocm-dkms-no-npi-hipclang ${OSDB_BKC_VERSION} > /etc/apt/sources.list.d/rocm.list" ;\
|
||||
cat /etc/apt/sources.list.d/rocm.list;\
|
||||
else \
|
||||
sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list" ;\
|
||||
fi
|
||||
RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list"
|
||||
RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add -
|
||||
RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/apt/sources.list"
|
||||
|
||||
@@ -25,18 +19,15 @@ RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/ap
|
||||
# Install dependencies
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
|
||||
apt-utils \
|
||||
sshpass \
|
||||
build-essential \
|
||||
cmake-data=3.15.1-0kitware1 \
|
||||
cmake=3.15.1-0kitware1 \
|
||||
curl \
|
||||
doxygen \
|
||||
g++ \
|
||||
gdb \
|
||||
git \
|
||||
hip-rocclr \
|
||||
jq \
|
||||
lcov \
|
||||
libelf-dev \
|
||||
libncurses5-dev \
|
||||
libnuma-dev \
|
||||
@@ -44,7 +35,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
|
||||
llvm-amdgpu \
|
||||
pkg-config \
|
||||
python \
|
||||
python3 \
|
||||
python3.8 \
|
||||
python-dev \
|
||||
python3-dev \
|
||||
python-pip \
|
||||
@@ -62,8 +53,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# RUN pip3 install --default-timeout=100000 -r requirements.txt
|
||||
|
||||
# Setup ubsan environment to printstacktrace
|
||||
RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer
|
||||
ENV UBSAN_OPTIONS=print_stacktrace=1
|
||||
@@ -83,6 +72,13 @@ ARG PREFIX=/opt/rocm
|
||||
RUN cget install pfultz2/rocm-recipes
|
||||
# Install rbuild
|
||||
RUN pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/6d78a0553babdaea8d2da5de15cbda7e869594b8.tar.gz
|
||||
# Install packages for processing the performance results
|
||||
RUN pip3 install --upgrade pip
|
||||
RUN pip3 install sqlalchemy
|
||||
RUN pip3 install pymysql
|
||||
RUN pip3 install pandas
|
||||
RUN pip3 install setuptools-rust
|
||||
RUN pip3 install sshtunnel
|
||||
# Setup ubsan environment to printstacktrace
|
||||
ENV UBSAN_OPTIONS=print_stacktrace=1
|
||||
|
||||
@@ -92,5 +88,3 @@ ADD rbuild.ini /rbuild.ini
|
||||
ADD dev-requirements.txt dev-requirements.txt
|
||||
RUN rbuild prepare -s develop -d $PREFIX
|
||||
RUN groupadd -f render
|
||||
# RUN cget install -f min-requirements.txt
|
||||
# RUN CXXFLAGS='-isystem $PREFIX/include' cget install -f ./mlir-requirements.txt
|
||||
|
||||
308
Jenkinsfile
vendored
308
Jenkinsfile
vendored
@@ -7,7 +7,6 @@ def show_node_info() {
|
||||
echo "NODE_NAME = \$NODE_NAME"
|
||||
lsb_release -sd
|
||||
uname -r
|
||||
cat /sys/module/amdgpu/version
|
||||
ls /opt/ -la
|
||||
"""
|
||||
}
|
||||
@@ -100,35 +99,45 @@ def buildHipClangJob(Map conf=[:]){
|
||||
|
||||
def variant = env.STAGE_NAME
|
||||
|
||||
|
||||
def retimage
|
||||
gitStatusWrapper(credentialsId: '7126e5fe-eb51-4576-b52b-9aaf1de8f0fd', gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
|
||||
try {
|
||||
retimage = docker.build("${image}", dockerArgs + '.')
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
timeout(time: 5, unit: 'MINUTES')
|
||||
{
|
||||
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
|
||||
|
||||
gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
|
||||
if (params.USE_DOCKERFILE){
|
||||
try {
|
||||
retimage = docker.build("${image}", dockerArgs + '.')
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
timeout(time: 5, unit: 'MINUTES')
|
||||
{
|
||||
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
|
||||
echo "The job was cancelled or aborted"
|
||||
throw e
|
||||
}
|
||||
catch(Exception ex) {
|
||||
retimage = docker.build("${image}", dockerArgs + "--no-cache .")
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
timeout(time: 5, unit: 'MINUTES')
|
||||
{
|
||||
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
|
||||
echo "The job was cancelled or aborted"
|
||||
throw e
|
||||
}
|
||||
catch(Exception ex) {
|
||||
retimage = docker.build("${image}", dockerArgs + "--no-cache .")
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
timeout(time: 5, unit: 'MINUTES')
|
||||
{
|
||||
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
|
||||
}
|
||||
else{
|
||||
timeout(time: 3, unit: 'HOURS'){
|
||||
retimage = docker.image('compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-dkms-no-npi-hipclang:9110_ubuntu18.04_py3.6_pytorch_rocm5.0_internal_testing_7ff5b54').pull()
|
||||
image="b56f8ac0d6ea"
|
||||
sh "docker images"
|
||||
}
|
||||
}
|
||||
|
||||
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
|
||||
timeout(time: 5, unit: 'HOURS')
|
||||
{
|
||||
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
|
||||
cmake_build(conf)
|
||||
}
|
||||
}
|
||||
@@ -140,6 +149,10 @@ def reboot(){
|
||||
build job: 'reboot-slaves', propagate: false , parameters: [string(name: 'server', value: "${env.NODE_NAME}"),]
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def buildHipClangJobAndReboot(Map conf=[:]){
|
||||
try{
|
||||
buildHipClangJob(conf)
|
||||
@@ -156,14 +169,157 @@ def buildHipClangJobAndReboot(Map conf=[:]){
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def runCKProfiler(Map conf=[:]){
|
||||
show_node_info()
|
||||
|
||||
env.HSA_ENABLE_SDMA=0
|
||||
checkout scm
|
||||
|
||||
def image = "composable_kernels"
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
def gpu_arch = conf.get("gpu_arch", "gfx908")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
// def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
if (conf.get("enforce_xnack_on", false)) {
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1"
|
||||
}
|
||||
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg GPU_ARCH='${gpu_arch}' "
|
||||
|
||||
def variant = env.STAGE_NAME
|
||||
|
||||
def retimage
|
||||
|
||||
gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') {
|
||||
if (params.USE_DOCKERFILE){
|
||||
try {
|
||||
retimage = docker.build("${image}", dockerArgs + '.')
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
timeout(time: 5, unit: 'MINUTES')
|
||||
{
|
||||
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
|
||||
echo "The job was cancelled or aborted"
|
||||
throw e
|
||||
}
|
||||
catch(Exception ex) {
|
||||
retimage = docker.build("${image}", dockerArgs + "--no-cache .")
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
timeout(time: 5, unit: 'MINUTES')
|
||||
{
|
||||
sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else{
|
||||
timeout(time: 3, unit: 'HOURS'){
|
||||
retimage = docker.image('compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-dkms-no-npi-hipclang:9110_ubuntu18.04_py3.6_pytorch_rocm5.0_internal_testing_7ff5b54').pull()
|
||||
image="b56f8ac0d6ea"
|
||||
sh "docker images"
|
||||
}
|
||||
}
|
||||
|
||||
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
|
||||
timeout(time: 5, unit: 'HOURS')
|
||||
{
|
||||
cmake_build(conf)
|
||||
dir("script"){
|
||||
//run gemm performance tests
|
||||
def gemm_log = "perf_gemm_${gpu_arch}.log"
|
||||
sh "rm -f ${gemm_log}"
|
||||
sh "echo Branch name: ${env.BRANCH_NAME} > ${gemm_log}"
|
||||
sh "echo Node name: ${NODE_NAME} >> ${gemm_log}"
|
||||
sh "echo GPU_arch name: ${gpu_arch} >> ${gemm_log}"
|
||||
sh "rocminfo | grep 'Compute Unit:' >> ${gemm_log} "
|
||||
sh "hipcc --version | grep -e 'HIP version' >> ${gemm_log}"
|
||||
sh "/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a ${gemm_log}"
|
||||
sh "./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a ${gemm_log}"
|
||||
//results will be parsed, stored, and analyzed within the python script
|
||||
//the script will return 0 if the performance criteria are met
|
||||
//or return 1 if the criteria are not met
|
||||
archiveArtifacts "${gemm_log}"
|
||||
sh "python3 parse_perf_data.py ${gemm_log} "
|
||||
//run resnet50 test
|
||||
def resnet_log = "perf_resnet50_${gpu_arch}.log"
|
||||
sh "rm -f ${resnet_log}"
|
||||
sh "echo Branch name: ${env.BRANCH_NAME} > ${resnet_log}"
|
||||
sh "echo Node name: ${NODE_NAME} >> ${resnet_log}"
|
||||
sh "echo GPU_arch name: ${gpu_arch} >> ${resnet_log}"
|
||||
sh "rocminfo | grep 'Compute Unit:' >> ${resnet_log} "
|
||||
sh "hipcc --version | grep -e 'HIP version' >> ${resnet_log}"
|
||||
sh "/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet_log}"
|
||||
//first run tests with N=256
|
||||
sh "./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 256 | tee -a ${resnet_log}"
|
||||
//then run with N=4
|
||||
sh "./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 4 | tee -a ${resnet_log}"
|
||||
archiveArtifacts "${resnet_log}"
|
||||
//the script will put the results from N=256 and N=4 runs into separate tables
|
||||
sh "python3 parse_perf_data.py ${resnet_log} "
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return retimage
|
||||
}
|
||||
|
||||
|
||||
def runPerfTest(Map conf=[:]){
|
||||
try{
|
||||
runCKProfiler(conf)
|
||||
}
|
||||
catch(e){
|
||||
echo "throwing error exception in performance tests"
|
||||
echo 'Exception occurred: ' + e.toString()
|
||||
throw e
|
||||
}
|
||||
finally{
|
||||
if (!conf.get("no_reboot", false)) {
|
||||
reboot()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pipeline {
|
||||
agent none
|
||||
options {
|
||||
parallelsAlwaysFailFast()
|
||||
}
|
||||
// environment{
|
||||
// variable = value
|
||||
// }
|
||||
parameters {
|
||||
booleanParam(
|
||||
name: "USE_DOCKERFILE",
|
||||
defaultValue: true,
|
||||
description: "")
|
||||
}
|
||||
environment{
|
||||
dbuser = "${dbuser}"
|
||||
dbpassword = "${dbpassword}"
|
||||
dbsship = "${dbsship}"
|
||||
dbsshport = "${dbsshport}"
|
||||
dbsshuser = "${dbsshuser}"
|
||||
dbsshpassword = "${dbsshpassword}"
|
||||
status_wrapper_creds = "${status_wrapper_creds}"
|
||||
}
|
||||
stages{
|
||||
stage("Static checks") {
|
||||
parallel{
|
||||
@@ -178,29 +334,6 @@ pipeline {
|
||||
// buildHipClangJobAndReboot(build_cmd: build_cmd, no_reboot:true, prefixpath: '/opt/rocm', build_type: 'debug')
|
||||
// }
|
||||
// }
|
||||
stage('Build Profiler: Release, gfx908')
|
||||
{
|
||||
agent { label rocmnode("nogpu")}
|
||||
environment{
|
||||
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
|
||||
}
|
||||
}
|
||||
stage('Build Profiler: Debug, gfx908')
|
||||
{
|
||||
agent { label rocmnode("nogpu")}
|
||||
environment{
|
||||
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
|
||||
}
|
||||
steps{
|
||||
// until we stabilize debug build due to compiler crashes
|
||||
catchError(buildResult: 'SUCCESS', stageResult: 'FAILURE') {
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Debug')
|
||||
}
|
||||
}
|
||||
}
|
||||
stage('Clang Format') {
|
||||
agent{ label rocmnode("nogpu") }
|
||||
environment{
|
||||
@@ -220,7 +353,7 @@ pipeline {
|
||||
}
|
||||
}
|
||||
}
|
||||
stage("Tests")
|
||||
stage("Tests")
|
||||
{
|
||||
parallel
|
||||
{
|
||||
@@ -228,12 +361,11 @@ pipeline {
|
||||
{
|
||||
agent{ label rocmnode("gfx908")}
|
||||
environment{
|
||||
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
|
||||
setup_args = """ -D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release')
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release', gpu_arch: "gfx908")
|
||||
}
|
||||
|
||||
}
|
||||
stage("Run Tests: gfx90a")
|
||||
{
|
||||
@@ -242,26 +374,68 @@ pipeline {
|
||||
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release')
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release', gpu_arch: "gfx90a")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
// enable after the cmake file supports packaging
|
||||
// stage("Packages") {
|
||||
// when {
|
||||
// expression { params.BUILD_PACKAGES && params.TARGET_NOGPU && params.DATATYPE_NA }
|
||||
// }
|
||||
// parallel {
|
||||
// stage("Package /opt/rocm") {
|
||||
// agent{ label rocmnode("nogpu") }
|
||||
// steps{
|
||||
// buildHipClangJobAndReboot( package_build: "true", prefixpath: '/opt/rocm', gpu_arch: "gfx906;gfx908;gfx90a")
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
stage("Client App")
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Run Client App")
{
    agent{ label rocmnode("gfx908")}
    environment{
        // CMake configure arguments handed to buildHipClangJobAndReboot as setup_args.
        // Every cache definition carries its own -D: a bare "-D" followed by
        // "-DBUILD_DEV=Off" and an unprefixed CMAKE_CXX_FLAGS=... would not define
        // BUILD_DEV / CMAKE_CXX_FLAGS as intended.
        setup_args = """ -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " """
        // Shell command handed over as execute_cmd: configures and builds
        // test/client_app in a fresh build directory, resolving packages from the
        // workspace install tree and /opt/rocm.
        execute_args = """ cd ../test/client_app && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" .. && make """
    }
    steps{
        // Build the "install" target in Release mode, then run execute_args.
        buildHipClangJobAndReboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
    }
}
|
||||
}
|
||||
}
|
||||
stage("Performance Tests")
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Run ckProfiler: gfx908")
|
||||
{
|
||||
agent{ label rocmnode("gfx908")}
|
||||
environment{
|
||||
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """
|
||||
}
|
||||
steps{
|
||||
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release', gpu_arch: "gfx908")
|
||||
}
|
||||
}
|
||||
stage("Run ckProfiler: gfx90a")
|
||||
{
|
||||
agent{ label rocmnode("gfx90a")}
|
||||
environment{
|
||||
setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """
|
||||
}
|
||||
steps{
|
||||
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release', gpu_arch: "gfx90a")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
/* enable after the cmake file supports packaging
|
||||
stage("Packages") {
|
||||
when {
|
||||
expression { params.BUILD_PACKAGES && params.TARGET_NOGPU && params.DATATYPE_NA }
|
||||
}
|
||||
parallel {
|
||||
stage("Package /opt/rocm") {
|
||||
agent{ label rocmnode("nogpu") }
|
||||
steps{
|
||||
buildHipClangJobAndReboot( package_build: "true", prefixpath: '/opt/rocm', gpu_arch: "gfx906;gfx908;gfx90a")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
||||
}
|
||||
|
||||
28
LICENSE
Normal file
28
LICENSE
Normal file
@@ -0,0 +1,28 @@
|
||||
Copyright (c) 2018- , Advanced Micro Devices, Inc. (Chao Liu, Jing Zhang)
|
||||
Copyright (c) 2019- , Advanced Micro Devices, Inc. (Letao Qin, Qianfeng Zhang, Liang Huang, Shaojie Wang)
|
||||
Copyright (c) 2022- , Advanced Micro Devices, Inc. (Anthony Chang, Chunyu Lai, Illia Silin, Adam Osewski, Poyen Chen, Jehandad Khan)
|
||||
Copyright (c) 2019-2021, Advanced Micro Devices, Inc. (Hanwen Chang)
|
||||
Copyright (c) 2019-2020, Advanced Micro Devices, Inc. (Tejash Shah)
|
||||
Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou)
|
||||
Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan)
|
||||
|
||||
SPDX-License-Identifier: MIT
|
||||
Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
14
README.md
14
README.md
@@ -6,7 +6,7 @@ docker run \
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
|
||||
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
|
||||
rocm/tensorflow:rocm5.1-tf2.6-dev \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
@@ -20,7 +20,7 @@ mkdir build && cd build
|
||||
cmake \
|
||||
-D BUILD_DEV=OFF \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3 \
|
||||
-D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3" \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
..
|
||||
@@ -43,3 +43,13 @@ Instructions for running each individual examples are under ```example/```
|
||||
make -j ckProfiler
|
||||
```
|
||||
Instructions for running ckProfiler are under ```profiler/```
|
||||
|
||||
|
||||
## Caveat
|
||||
### Kernel Timing and Verification
|
||||
CK's own kernel timer will warm up kernel once, and then run it multiple times
|
||||
to get average kernel time. For some kernels that use atomic add, this will cause
|
||||
output buffer to be accumulated multiple times, causing verification failure.
|
||||
To work around it, do not use CK's own timer and do verification at the same time.
|
||||
CK's own timer and verification in each example and ckProfiler can be enabled or
|
||||
disabled from command line.
|
||||
|
||||
@@ -66,7 +66,7 @@ else()
|
||||
-Wunreachable-code
|
||||
-Wunused
|
||||
|
||||
-Wno-sign-compare
|
||||
-Wsign-compare
|
||||
-Wno-extra-semi-stmt
|
||||
)
|
||||
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
|
||||
|
||||
@@ -18,6 +18,8 @@ list(APPEND GTEST_CMAKE_CXX_FLAGS
|
||||
-Wno-switch-enum
|
||||
-Wno-zero-as-null-pointer-constant
|
||||
-Wno-unused-member-function
|
||||
-Wno-comma
|
||||
-Wno-old-style-cast
|
||||
)
|
||||
message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}")
|
||||
|
||||
@@ -33,4 +35,5 @@ FetchContent_MakeAvailable(googletest)
|
||||
|
||||
target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
|
||||
target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
|
||||
|
||||
target_compile_options(gmock PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
|
||||
target_compile_options(gmock_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS})
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
|
||||
add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
|
||||
add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
|
||||
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
|
||||
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
|
||||
# FIXME: re-enable this example as test when SWDEV-335738 is fixed
|
||||
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
|
||||
|
||||
209
example/01_gemm/gemm_dl_fp16.cpp
Normal file
209
example/01_gemm/gemm_dl_fp16.cpp
Normal file
@@ -0,0 +1,209 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_gemm_dl.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
|
||||
using ALayout = Col;
|
||||
using BLayout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::
|
||||
// ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
|
||||
// ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
|
||||
// ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
|
||||
// ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// do nothing
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
}
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
208
example/01_gemm/gemm_dl_fp32.cpp
Normal file
208
example/01_gemm/gemm_dl_fp32.cpp
Normal file
@@ -0,0 +1,208 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_gemm_dl.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = float;
|
||||
using BDataType = float;
|
||||
using CDataType = float;
|
||||
using AccDataType = float;
|
||||
|
||||
using ALayout = Col;
|
||||
using BLayout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
// Device-side GEMM instance exercised by this example.
// The data types, layouts and element-wise operations are taken from the aliases
// declared above, so this instance and ReferenceGemmInstance cannot drift apart.
// Each argument is annotated with the template-parameter name it fills.
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl<
    ADataType,           // ADataType
    BDataType,           // BDataType
    CDataType,           // CDataType
    AccDataType,         // AccDataType
    ALayout,             // ALayout
    BLayout,             // BLayout
    CLayout,             // CLayout
    AElementOp,          // AElementwiseOperation
    BElementOp,          // BElementwiseOperation
    CElementOp,          // CElementwiseOperation
    GemmDefault,         // GEMM Specialization
    256,                 // BlockSize
    128,                 // MPerBlock
    128,                 // NPerBlock
    16,                  // K0PerBlock
    1,                   // K1
    4,                   // M1PerThreadM111
    4,                   // N1PerThreadN111
    1,                   // KPerThread
    S<8, 2>,             // M11N11ThreadClusterM110Xs
    S<8, 2>,             // M11N11ThreadClusterN110Xs
    S<2, 1, 4, 1>,       // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
    S<8, 1, 32, 1>,      // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
    S<0, 3, 1, 2>,       // ABlockTransferThreadClusterArrangeOrder
    S<0, 3, 1, 2>,       // ABlockTransferSrcAccessOrder
    S<1, 1, 4, 1>,       // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
    S<0, 3, 1, 2>,       // ABlockTransferSrcVectorTensorContiguousDimOrder
    S<1, 1, 4, 1>,       // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
    S<2, 1, 4, 1>,       // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
    S<8, 1, 32, 1>,      // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
    S<0, 3, 1, 2>,       // BBlockTransferThreadClusterArrangeOrder
    S<0, 3, 1, 2>,       // BBlockTransferSrcAccessOrder
    S<1, 1, 4, 1>,       // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
    S<0, 3, 1, 2>,       // BBlockTransferSrcVectorTensorContiguousDimOrder
    S<1, 1, 4, 1>,       // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
    S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
    5,                   // CThreadTransferSrcDstVectorDim
    4>;                  // CThreadTransferDstScalarPerVector
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// do nothing
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
}
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
206
example/01_gemm/gemm_dl_int8.cpp
Normal file
206
example/01_gemm/gemm_dl_int8.cpp
Normal file
@@ -0,0 +1,206 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_gemm_dl.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using CDataType = int8_t;
|
||||
using AccDataType = int32_t;
|
||||
|
||||
using ALayout = Col;
|
||||
using BLayout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
// Device-side GEMM instance exercised by this example.
// The data types, layouts and element-wise operations are taken from the aliases
// declared above, so this instance and ReferenceGemmInstance cannot drift apart.
// Each argument is annotated with the template-parameter name it fills.
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl<
    ADataType,           // ADataType
    BDataType,           // BDataType
    CDataType,           // CDataType
    AccDataType,         // AccDataType
    ALayout,             // ALayout
    BLayout,             // BLayout
    CLayout,             // CLayout
    AElementOp,          // AElementwiseOperation
    BElementOp,          // BElementwiseOperation
    CElementOp,          // CElementwiseOperation
    GemmDefault,         // GEMM Specialization
    256,                 // BlockSize
    128,                 // MPerBlock
    128,                 // NPerBlock
    16,                  // K0PerBlock
    4,                   // K1
    4,                   // M1PerThreadM111
    4,                   // N1PerThreadN111
    1,                   // KPerThread
    S<8, 2>,             // M11N11ThreadClusterM110Xs
    S<8, 2>,             // M11N11ThreadClusterN110Xs
    S<2, 1, 4, 4>,       // ABlockTransferThreadSliceLengths_K0_M0_M1_K1
    S<8, 1, 32, 1>,      // ABlockTransferThreadClusterLengths_K0_M0_M1_K1
    S<0, 3, 1, 2>,       // ABlockTransferThreadClusterArrangeOrder
    S<0, 3, 1, 2>,       // ABlockTransferSrcAccessOrder
    S<1, 1, 4, 1>,       // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1
    S<0, 3, 1, 2>,       // ABlockTransferSrcVectorTensorContiguousDimOrder
    S<1, 1, 4, 4>,       // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1
    S<2, 1, 4, 4>,       // BBlockTransferThreadSliceLengths_K0_N0_N1_K1
    S<8, 1, 32, 1>,      // BBlockTransferThreadClusterLengths_K0_N0_N1_K1
    S<0, 3, 1, 2>,       // BBlockTransferThreadClusterArrangeOrder
    S<0, 3, 1, 2>,       // BBlockTransferSrcAccessOrder
    S<1, 1, 4, 1>,       // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1
    S<0, 3, 1, 2>,       // BBlockTransferSrcVectorTensorContiguousDimOrder
    S<1, 1, 4, 4>,       // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
    S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
    5,                   // CThreadTransferSrcDstVectorDim
    4>;                  // CThreadTransferDstScalarPerVector
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// do nothing
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
}
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
@@ -84,13 +84,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<float, float, float, PassThrough, PassThrough, PassThrough>;
|
||||
ReferenceGemm<float, float, float, float, PassThrough, PassThrough, PassThrough>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
@@ -105,13 +105,13 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
@@ -125,7 +125,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
@@ -193,12 +193,12 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, nrepeat);
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
@@ -232,7 +232,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData);
|
||||
return ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
@@ -29,29 +28,30 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CDataType = F16;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
#if 1
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
#elif 0
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_ProducerConsumer_CShuffle
|
||||
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
@@ -70,13 +70,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
@@ -87,17 +87,21 @@ int main(int argc, char* argv[])
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
@@ -111,7 +115,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
@@ -184,12 +188,12 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, nrepeat);
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
@@ -214,7 +218,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
|
||||
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
238
example/01_gemm/gemm_xdl_fp64.cpp
Normal file
238
example/01_gemm/gemm_xdl_fp64.cpp
Normal file
@@ -0,0 +1,238 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_gemm_xdl.hpp"
|
||||
#include "device_gemm_xdl_cshuffle.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F64 = double;
|
||||
|
||||
using ADataType = double;
|
||||
using BDataType = double;
|
||||
using CDataType = double;
|
||||
using AccDataType = double;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
|
||||
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if 0
|
||||
< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 1, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 7, 1>;
|
||||
#else
|
||||
< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
template <typename DataType>
|
||||
std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
|
||||
{
|
||||
os << "[" << std::endl;
|
||||
for(int x = 0; x < matrix.mDesc.GetLengths()[0]; x++)
|
||||
{
|
||||
os << "[";
|
||||
for(int y = 0; y < matrix.mDesc.GetLengths()[1]; y++)
|
||||
{
|
||||
os << std::setw(4) << static_cast<float>(matrix(x, y));
|
||||
}
|
||||
os << "]" << std::endl;
|
||||
}
|
||||
os << "]";
|
||||
return os;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "data type: " << typeid(ADataType{}).name() << std::endl;
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
}
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
#if 0
|
||||
{
|
||||
show_2d_matrix(std::cout << "a : ", a_m_k) << std::endl;
|
||||
show_2d_matrix(std::cout << "b: ", b_k_n) << std::endl;
|
||||
show_2d_matrix(std::cout << "c_device: ", c_m_n_device_result) << std::endl;
|
||||
show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl;
|
||||
}
|
||||
#endif
|
||||
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -78,14 +78,19 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
@@ -100,13 +105,13 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
@@ -120,7 +125,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
@@ -189,12 +194,12 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, nrepeat);
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
@@ -219,7 +224,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
|
||||
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -86,9 +86,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D<AD
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
@@ -106,13 +106,13 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 6)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
alpha = std::stof(argv[4]);
|
||||
beta = std::stof(argv[5]);
|
||||
@@ -121,7 +121,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
@@ -138,7 +138,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, alpha, beta\n");
|
||||
exit(0);
|
||||
}
|
||||
@@ -216,7 +216,7 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, nrepeat);
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
@@ -246,6 +246,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
|
||||
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -3,89 +3,109 @@
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "host_gemm.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_gemm_xdl_c_shuffle_bias_activation.hpp"
|
||||
#include "reference_gemm_bias_activation.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
#include "device_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::AddRelu;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_Activation<
|
||||
ADataType, // ADataType
|
||||
BDataType, // BDataType
|
||||
CDataType, // CDataType
|
||||
AccDataType, // AccDataType
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
CLayout, // CLayout
|
||||
AElementOp, // AElementwiseOperation
|
||||
BElementOp, // BElementwiseOperation
|
||||
CElementOp, // CElementwiseOperation
|
||||
256, // BlockSize
|
||||
256, // MPerBlock
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
// clang-format on
|
||||
// C = A * B
|
||||
// E = Relu(C + D);
|
||||
struct AddRelu
|
||||
{
|
||||
__host__ __device__ void
|
||||
operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const
|
||||
{
|
||||
const ck::half_t x = c + d;
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBiasActivation<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
e = x > 0 ? x : 0;
|
||||
}
|
||||
};
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using DDataType = F16;
|
||||
using DsDataType = ck::Tuple<DDataType>;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AddRelu;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceOpInstance =
|
||||
ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp,
|
||||
GemmDefault,
|
||||
1,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
@@ -94,19 +114,23 @@ int main(int argc, char* argv[])
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
ck::index_t StrideE = 4096;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
@@ -114,14 +138,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
StrideE = std::stoi(argv[9]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
@@ -141,17 +165,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
// c0_n[n]
|
||||
Tensor<CDataType> c0_n(HostTensorDescriptor(
|
||||
std::vector<std::size_t>({static_cast<std::size_t>(N)}), std::vector<std::size_t>({1})));
|
||||
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, 0, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
std::cout << "c0_n: " << c0_n.mDesc << std::endl;
|
||||
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
@@ -159,59 +180,59 @@ int main(int argc, char* argv[])
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
c0_n.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
c0_n.GenerateTensorValue(GeneratorTensor_3<CDataType>{0.0, 1.0});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
|
||||
}
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace());
|
||||
DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace());
|
||||
DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
|
||||
c0_n_device_buf.ToDevice(c0_n.mData.data());
|
||||
d_m_n_device_buf.ToDevice(d_m_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c0_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
auto argument =
|
||||
device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
|
||||
b_k_n_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_m_n_device_buf.GetDeviceBuffer()},
|
||||
e_m_n_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, 1>{0},
|
||||
StrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
throw std::runtime_error("wrong! this device_op instance does not support this problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, nrepeat);
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M +
|
||||
sizeof(CDataType) * M * N + sizeof(CDataType) * N;
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(EDataType) * M * N + sizeof(EDataType) * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
@@ -220,18 +241,38 @@ int main(int argc, char* argv[])
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
Tensor<AccDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, c0_n, a_element_op, b_element_op, c_element_op);
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
1
example/04_gemm_add_add_fastgelu/CMakeLists.txt
Normal file
1
example/04_gemm_add_add_fastgelu/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
|
||||
23
example/04_gemm_add_add_fastgelu/README.md
Normal file
23
example/04_gemm_add_add_fastgelu/README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# Instructions for ```example_gemm_add_add_fastgelu_xdl_fp16```
|
||||
|
||||
## Run ```example_gemm_add_add_fastgelu_xdl_fp16```
|
||||
```bash
|
||||
#arg1: verification (0=no, 1=yes)
|
||||
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
#arg3: time kernel (0=no, 1=yes)
|
||||
#arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE
|
||||
./bin/example_gemm_add_add_fastgelu_xdl_fp16 1 1 1
|
||||
```
|
||||
|
||||
Result (MI100 @ 1087MHz, 133.5TFlops peak FP16)
|
||||
```
|
||||
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
|
||||
d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1}
|
||||
d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up 1 time
|
||||
Start running 10 times...
|
||||
Perf: 1.26914 ms, 101.525 TFlops, 100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8>
|
||||
```
|
||||
@@ -0,0 +1,245 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
#include "device_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = F16;
|
||||
using D1DataType = F16;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Row;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AddAddFastGelu;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideD0 = 0;
|
||||
ck::index_t StrideD1 = 4096;
|
||||
ck::index_t StrideE = 4096;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 12)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideD0 = std::stoi(argv[9]);
|
||||
StrideD1 = std::stoi(argv[10]);
|
||||
StrideE = std::stoi(argv[11]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, "
|
||||
"StrideE\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
|
||||
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
|
||||
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
}
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
|
||||
DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace());
|
||||
DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
d0_m_n_device_buf.ToDevice(d0_m_n.mData.data());
|
||||
d1_m_n_device_buf.ToDevice(d1_m_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(),
|
||||
b_k_n_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
|
||||
d1_m_n_device_buf.GetDeviceBuffer()},
|
||||
e_m_n_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, 2>{StrideD0, StrideD1},
|
||||
StrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error("wrong! this device_op instance does not support this problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(D0DataType) * N + sizeof(D1DataType) * M * N +
|
||||
sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< device_op.GetTypeString() << std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<AccDataType> c_m_n(HostTensorDescriptor(
|
||||
std::vector<std::size_t>{static_cast<std::size_t>(M), static_cast<std::size_t>(N)}));
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
add_example_executable(example_gemm_xdl_bias_relu_add gemm_xdl_bias_relu_add.cpp)
|
||||
@@ -1,28 +0,0 @@
|
||||
# Instructions for ```example_gemm_xdl_bias_relu_add```
|
||||
|
||||
## Run ```example_gemm_xdl_bias_relu_add```
|
||||
```bash
|
||||
#arg1: verification (0=no, 1=yes)
|
||||
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
#arg3: run kernel # of times (>1)
|
||||
#arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
|
||||
./bin/example_gemm_xdl_bias_relu_add 0 1 5 3840 4096 4096 4096 4096 4096
|
||||
```
|
||||
|
||||
Result (MI100 @ 1087 MHz, 133.5 TFlops peak FP16)
|
||||
```
|
||||
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
|
||||
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
c1_m_n: dim 2, lengths {3840, 4096}, strides {1, 0}
|
||||
arg.a_grid_desc_k0_m_k1_{512, 3840, 8}
|
||||
arg.b_grid_desc_k0_n_k1_{512, 4096, 8}
|
||||
arg.c_grid_desc_m_n_{ 3840, 4096}
|
||||
arg.c0_grid_desc_m_n_{ 3840, 4096}
|
||||
arg.c1_grid_desc_m_n_{ 3840, 4096}
|
||||
launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 5 times...
|
||||
Perf: 1.27583 ms, 100.992 TFlops, 73.9688 GB/s
|
||||
```
|
||||
@@ -1,255 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "host_gemm.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp"
|
||||
#include "reference_gemm_bias_activation_add.hpp"
|
||||
|
||||
// Shorthand for a compile-time integer sequence; used for the S<...> template arguments of
// DeviceGemmInstance below.
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// Element types: A, B and C are half precision, accumulation is done in float.
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
|
||||
// Memory layouts: A and C are row-major, B is column-major.
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
// Element-wise operations: A and B pass through unchanged; the C operation is AddReluAdd.
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
|
||||
// clang-format off
|
||||
// Device GEMM instance run by main(): DeviceGemmXdl_C_Shuffle_Bias_Activation_Add instantiated
// with the data types, layouts and element-wise ops declared above plus the tuning values
// below. The trailing comment on each line names the template parameter it fills.
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_Activation_Add<
|
||||
ADataType, // ADataType
|
||||
BDataType, // BDataType
|
||||
CDataType, // CDataType
|
||||
AccDataType, // AccDataType
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
CLayout, // CLayout
|
||||
AElementOp, // AElementwiseOperation
|
||||
BElementOp, // BElementwiseOperation
|
||||
CElementOp, // CElementwiseOperation
|
||||
256, // BlockSize
|
||||
256, // MPerBlock
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
// clang-format on
|
||||
|
||||
// Host-side reference (ReferenceGemmBiasActivationAdd) instantiated with the same data types
// and element-wise ops as the device instance; main() runs it when verification is requested.
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemmBiasActivationAdd<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
ck::index_t StrideC1 = 4096;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 11)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
StrideC1 = std::stoi(argv[10]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, StrideC1\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
// c0_n[n]
|
||||
Tensor<CDataType> c0_n(HostTensorDescriptor(
|
||||
std::vector<std::size_t>({static_cast<std::size_t>(N)}), std::vector<std::size_t>({1})));
|
||||
|
||||
// c1_m_n[m ,n]
|
||||
Tensor<CDataType> c1_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
std::cout << "c0_n: " << c0_n.mDesc << std::endl;
|
||||
std::cout << "c1_m_n: " << c1_m_n.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
c0_n.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5});
|
||||
c1_m_n.GenerateTensorValue(GeneratorTensor_2<CDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
c0_n.GenerateTensorValue(GeneratorTensor_3<CDataType>{0.0, 1.0});
|
||||
c1_m_n.GenerateTensorValue(GeneratorTensor_3<CDataType>{0.0, 1.0});
|
||||
}
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace());
|
||||
DeviceMem c1_m_n_device_buf(sizeof(CDataType) * c1_m_n.mDesc.GetElementSpace());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data());
|
||||
c0_n_device_buf.ToDevice(c0_n.mData.data());
|
||||
c1_m_n_device_buf.ToDevice(c1_m_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c0_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c1_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
StrideC1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, nrepeat);
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M +
|
||||
sizeof(CDataType) * M * N + sizeof(CDataType) * N +
|
||||
sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
|
||||
b_k_n,
|
||||
c_m_n_host_result,
|
||||
c0_n,
|
||||
c1_m_n,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
|
||||
}
|
||||
}
|
||||
@@ -1,2 +1,2 @@
|
||||
add_example_executable(example_conv2d_fwd_xdl_bias_relu conv2d_fwd_xdl_bias_relu.cpp)
|
||||
target_link_libraries(example_conv2d_fwd_xdl_bias_relu PRIVATE conv_fwd_util)
|
||||
target_link_libraries(example_conv2d_fwd_xdl_bias_relu PRIVATE conv_util)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "conv_fwd_util.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "device.hpp"
|
||||
#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
@@ -93,7 +93,7 @@ void PrintUseMsg()
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: run kernel # of times (>1)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "Following arguments:\n"
|
||||
<< " N, K, C, \n"
|
||||
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
|
||||
@@ -120,40 +120,40 @@ ck::utils::conv::ConvParams ParseConvParams(int argc, char* argv[])
|
||||
ck::utils::conv::ConvParams params;
|
||||
int arg_idx = 4;
|
||||
|
||||
params.num_dim_spatial = num_dim_spatial;
|
||||
params.N = std::stoi(argv[arg_idx++]);
|
||||
params.K = std::stoi(argv[arg_idx++]);
|
||||
params.C = std::stoi(argv[arg_idx++]);
|
||||
params.num_dim_spatial_ = num_dim_spatial;
|
||||
params.N_ = std::stoi(argv[arg_idx++]);
|
||||
params.K_ = std::stoi(argv[arg_idx++]);
|
||||
params.C_ = std::stoi(argv[arg_idx++]);
|
||||
|
||||
params.filter_spatial_lengths.resize(num_dim_spatial);
|
||||
params.filter_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_spatial_lengths.resize(num_dim_spatial);
|
||||
params.input_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_strides.resize(num_dim_spatial);
|
||||
params.conv_filter_strides_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_dilations.resize(num_dim_spatial);
|
||||
params.conv_filter_dilations_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_left_pads.resize(num_dim_spatial);
|
||||
params.input_left_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_right_pads.resize(num_dim_spatial);
|
||||
params.input_right_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return params;
|
||||
@@ -165,9 +165,9 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck::utils::conv;
|
||||
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
const int num_dim_spatial = 2;
|
||||
|
||||
ck::utils::conv::ConvParams params;
|
||||
@@ -176,7 +176,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
|
||||
if(argc >= 5)
|
||||
@@ -184,21 +184,21 @@ int main(int argc, char* argv[])
|
||||
params = ParseConvParams(argc, argv);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths),
|
||||
std::end(params.input_spatial_lengths));
|
||||
std::begin(params.input_spatial_lengths_),
|
||||
std::end(params.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths),
|
||||
std::end(params.filter_spatial_lengths));
|
||||
std::begin(params.filter_spatial_lengths_),
|
||||
std::end(params.filter_spatial_lengths_));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.K)};
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
@@ -211,7 +211,7 @@ int main(int argc, char* argv[])
|
||||
get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
|
||||
// bias: assume contiguous 1d vector
|
||||
Tensor<OutDataType> bias(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(params.K)})));
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(params.K_)})));
|
||||
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weights: " << weights.mDesc << std::endl;
|
||||
@@ -248,16 +248,16 @@ int main(int argc, char* argv[])
|
||||
static_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
static_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer()),
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
@@ -269,18 +269,18 @@ int main(int argc, char* argv[])
|
||||
"not support this problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, nrepeat);
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = get_flops(
|
||||
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
std::size_t num_btype =
|
||||
get_btype<InDataType, WeiDataType, OutDataType>(params.N,
|
||||
params.C,
|
||||
params.K,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
get_btype<InDataType, WeiDataType, OutDataType>(params.N_,
|
||||
params.C_,
|
||||
params.K_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths) +
|
||||
sizeof(OutDataType) * (params.K);
|
||||
sizeof(OutDataType) * (params.K_);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
@@ -296,16 +296,17 @@ int main(int argc, char* argv[])
|
||||
weights,
|
||||
host_output,
|
||||
bias,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
ref_invoker.Run(ref_argument);
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
ck::utils::check_err(
|
||||
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
return ck::utils::check_err(device_output.mData, host_output.mData) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
add_example_executable(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp)
|
||||
target_link_libraries(example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_fwd_util)
|
||||
# FIXME: should fix validation failure
|
||||
add_example_executable_no_testing(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp)
|
||||
target_link_libraries(example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_util)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "conv_fwd_util.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "device.hpp"
|
||||
#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
@@ -90,7 +90,7 @@ void PrintUseMsg()
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: run kernel # of times (>1)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "Following arguments:\n"
|
||||
<< " N, K, C, \n"
|
||||
<< " <filter spatial dimensions>, (ie Y, X for 2D)\n"
|
||||
@@ -117,40 +117,40 @@ ck::utils::conv::ConvParams ParseConvParams(int argc, char* argv[])
|
||||
ck::utils::conv::ConvParams params;
|
||||
int arg_idx = 4;
|
||||
|
||||
params.num_dim_spatial = num_dim_spatial;
|
||||
params.N = std::stoi(argv[arg_idx++]);
|
||||
params.K = std::stoi(argv[arg_idx++]);
|
||||
params.C = std::stoi(argv[arg_idx++]);
|
||||
params.num_dim_spatial_ = num_dim_spatial;
|
||||
params.N_ = std::stoi(argv[arg_idx++]);
|
||||
params.K_ = std::stoi(argv[arg_idx++]);
|
||||
params.C_ = std::stoi(argv[arg_idx++]);
|
||||
|
||||
params.filter_spatial_lengths.resize(num_dim_spatial);
|
||||
params.filter_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_spatial_lengths.resize(num_dim_spatial);
|
||||
params.input_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_strides.resize(num_dim_spatial);
|
||||
params.conv_filter_strides_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_dilations.resize(num_dim_spatial);
|
||||
params.conv_filter_dilations_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_left_pads.resize(num_dim_spatial);
|
||||
params.input_left_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_right_pads.resize(num_dim_spatial);
|
||||
params.input_right_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return params;
|
||||
@@ -162,9 +162,9 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck::utils::conv;
|
||||
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
const int num_dim_spatial = 2;
|
||||
|
||||
ck::utils::conv::ConvParams params;
|
||||
@@ -173,7 +173,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
|
||||
if(argc >= 5)
|
||||
@@ -181,21 +181,21 @@ int main(int argc, char* argv[])
|
||||
params = ParseConvParams(argc, argv);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths),
|
||||
std::end(params.input_spatial_lengths));
|
||||
std::begin(params.input_spatial_lengths_),
|
||||
std::end(params.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths),
|
||||
std::end(params.filter_spatial_lengths));
|
||||
std::begin(params.filter_spatial_lengths_),
|
||||
std::end(params.filter_spatial_lengths_));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.K)};
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
@@ -209,7 +209,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
// bias: assume contiguous 1d vector
|
||||
Tensor<OutDataType> bias(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(params.K)})));
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(params.K_)})));
|
||||
|
||||
// residual: assume same layout as output tensor
|
||||
Tensor<OutDataType> residual(get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
|
||||
@@ -224,10 +224,10 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
|
||||
weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||
bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
residual.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
|
||||
weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-2, 2});
|
||||
bias.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
|
||||
residual.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
|
||||
@@ -259,16 +259,16 @@ int main(int argc, char* argv[])
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
static_cast<const OutDataType*>(bias_device_buf.GetDeviceBuffer()),
|
||||
static_cast<const OutDataType*>(resi_device_buf.GetDeviceBuffer()),
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
@@ -280,20 +280,20 @@ int main(int argc, char* argv[])
|
||||
"not support this problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, nrepeat);
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = get_flops(
|
||||
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
std::size_t num_btype =
|
||||
get_btype<InDataType, WeiDataType, OutDataType>(params.N,
|
||||
params.C,
|
||||
params.K,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
get_btype<InDataType, WeiDataType, OutDataType>(params.N_,
|
||||
params.C_,
|
||||
params.K_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths) +
|
||||
sizeof(OutDataType) * (params.K) +
|
||||
sizeof(OutDataType) * (params.K_) +
|
||||
sizeof(OutDataType) *
|
||||
(params.N * params.K * output_spatial_lengths[0] * output_spatial_lengths[1]);
|
||||
(params.N_ * params.K_ * output_spatial_lengths[0] * output_spatial_lengths[1]);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
@@ -310,17 +310,18 @@ int main(int argc, char* argv[])
|
||||
host_output,
|
||||
bias,
|
||||
residual,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
ck::utils::check_err(
|
||||
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
return ck::utils::check_err(device_output.mData, host_output.mData) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
add_example_executable(example_convnd_fwd_xdl convnd_fwd_xdl.cpp)
|
||||
target_link_libraries(example_convnd_fwd_xdl PRIVATE conv_fwd_util)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
|
||||
target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_fwd_util)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
|
||||
target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_fwd_util)
|
||||
# FIXME: re-enable this example as test when SWDEV-335738 is fixed
|
||||
add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
|
||||
target_link_libraries(example_convnd_fwd_xdl_fp64 PRIVATE conv_util)
|
||||
target_link_libraries(example_convnd_fwd_xdl_fp32 PRIVATE conv_util)
|
||||
target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_util)
|
||||
target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_util)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "conv_fwd_util.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "device.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
@@ -43,10 +43,10 @@ template <ck::index_t NumDimSpatial>
|
||||
using DeviceConvNDFwdInstance = ck::tensor_operation::device::
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
|
||||
// clang-format off
|
||||
InDataType, //
|
||||
InDataType, //
|
||||
WeiDataType, //
|
||||
OutDataType, //
|
||||
AccDataType, //
|
||||
AccDataType, //
|
||||
InElementOp, // Input Elementwise Operation
|
||||
WeiElementOp, // Weights Elementwise Operation
|
||||
OutElementOp, // Output Elementwise Operation
|
||||
@@ -110,7 +110,7 @@ void print_use_msg()
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: run kernel # of times (>1)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg4: N spatial dimensions (default 2)\n"
|
||||
<< "Following arguments (depending on number of spatial dims):\n"
|
||||
<< " N, K, C, \n"
|
||||
@@ -137,40 +137,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, cha
|
||||
ck::utils::conv::ConvParams params;
|
||||
int arg_idx = 5;
|
||||
|
||||
params.num_dim_spatial = num_dim_spatial;
|
||||
params.N = std::stoi(argv[arg_idx++]);
|
||||
params.K = std::stoi(argv[arg_idx++]);
|
||||
params.C = std::stoi(argv[arg_idx++]);
|
||||
params.num_dim_spatial_ = num_dim_spatial;
|
||||
params.N_ = std::stoi(argv[arg_idx++]);
|
||||
params.K_ = std::stoi(argv[arg_idx++]);
|
||||
params.C_ = std::stoi(argv[arg_idx++]);
|
||||
|
||||
params.filter_spatial_lengths.resize(num_dim_spatial);
|
||||
params.filter_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_spatial_lengths.resize(num_dim_spatial);
|
||||
params.input_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_strides.resize(num_dim_spatial);
|
||||
params.conv_filter_strides_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_dilations.resize(num_dim_spatial);
|
||||
params.conv_filter_dilations_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_left_pads.resize(num_dim_spatial);
|
||||
params.input_left_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_right_pads.resize(num_dim_spatial);
|
||||
params.input_right_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return params;
|
||||
@@ -182,9 +182,9 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck::utils::conv;
|
||||
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
int num_dim_spatial = 2;
|
||||
|
||||
ck::utils::conv::ConvParams params;
|
||||
@@ -193,7 +193,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
num_dim_spatial = std::stoi(argv[4]);
|
||||
}
|
||||
|
||||
@@ -202,21 +202,21 @@ int main(int argc, char* argv[])
|
||||
params = parse_conv_params(num_dim_spatial, argc, argv);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths),
|
||||
std::end(params.input_spatial_lengths));
|
||||
std::begin(params.input_spatial_lengths_),
|
||||
std::end(params.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths),
|
||||
std::end(params.filter_spatial_lengths));
|
||||
std::begin(params.filter_spatial_lengths_),
|
||||
std::end(params.filter_spatial_lengths_));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.K)};
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
@@ -256,16 +256,16 @@ int main(int argc, char* argv[])
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
@@ -277,22 +277,22 @@ int main(int argc, char* argv[])
|
||||
"not support this Conv problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker->Run(argument.get(), nrepeat);
|
||||
float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = get_flops(
|
||||
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
std::size_t num_btype = get_btype<InDataType, WeiDataType, OutDataType>(
|
||||
params.N,
|
||||
params.C,
|
||||
params.K,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.C_,
|
||||
params.K_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< conv->GetTypeString() << std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
@@ -302,40 +302,38 @@ int main(int argc, char* argv[])
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weights,
|
||||
host_output,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
ck::utils::check_err(
|
||||
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
return ck::utils::check_err(
|
||||
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f) ? 0 : 1;
|
||||
};
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
auto ref_conv = ReferenceConvNDFwdInstance<3>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
case 2: {
|
||||
auto ref_conv = ReferenceConvNDFwdInstance<2>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
case 1: {
|
||||
auto ref_conv = ReferenceConvNDFwdInstance<1>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "conv_fwd_util.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "device.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
@@ -39,10 +39,10 @@ template <ck::index_t NumDimSpatial>
|
||||
using DeviceConvNDFwdInstance = ck::tensor_operation::device::
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
|
||||
// clang-format off
|
||||
InDataType, //
|
||||
InDataType, //
|
||||
WeiDataType, //
|
||||
OutDataType, //
|
||||
AccDataType, //
|
||||
AccDataType, //
|
||||
InElementOp, // Input Elementwise Operation
|
||||
WeiElementOp, // Weights Elementwise Operation
|
||||
OutElementOp, // Output Elementwise Operation
|
||||
@@ -107,7 +107,7 @@ void print_use_msg()
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: run kernel # of times (>1)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg4: N spatial dimensions (default 2)\n"
|
||||
<< "Following arguments (depending on number of spatial dims):\n"
|
||||
<< " N, K, C, \n"
|
||||
@@ -134,40 +134,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, cha
|
||||
ck::utils::conv::ConvParams params;
|
||||
int arg_idx = 5;
|
||||
|
||||
params.num_dim_spatial = num_dim_spatial;
|
||||
params.N = std::stoi(argv[arg_idx++]);
|
||||
params.K = std::stoi(argv[arg_idx++]);
|
||||
params.C = std::stoi(argv[arg_idx++]);
|
||||
params.num_dim_spatial_ = num_dim_spatial;
|
||||
params.N_ = std::stoi(argv[arg_idx++]);
|
||||
params.K_ = std::stoi(argv[arg_idx++]);
|
||||
params.C_ = std::stoi(argv[arg_idx++]);
|
||||
|
||||
params.filter_spatial_lengths.resize(num_dim_spatial);
|
||||
params.filter_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_spatial_lengths.resize(num_dim_spatial);
|
||||
params.input_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_strides.resize(num_dim_spatial);
|
||||
params.conv_filter_strides_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_dilations.resize(num_dim_spatial);
|
||||
params.conv_filter_dilations_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_left_pads.resize(num_dim_spatial);
|
||||
params.input_left_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_right_pads.resize(num_dim_spatial);
|
||||
params.input_right_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return params;
|
||||
@@ -179,9 +179,9 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck::utils::conv;
|
||||
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
int num_dim_spatial = 2;
|
||||
|
||||
ck::utils::conv::ConvParams params;
|
||||
@@ -190,7 +190,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
num_dim_spatial = std::stoi(argv[4]);
|
||||
}
|
||||
|
||||
@@ -199,21 +199,21 @@ int main(int argc, char* argv[])
|
||||
params = parse_conv_params(num_dim_spatial, argc, argv);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths),
|
||||
std::end(params.input_spatial_lengths));
|
||||
std::begin(params.input_spatial_lengths_),
|
||||
std::end(params.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths),
|
||||
std::end(params.filter_spatial_lengths));
|
||||
std::begin(params.filter_spatial_lengths_),
|
||||
std::end(params.filter_spatial_lengths_));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.K)};
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
@@ -255,16 +255,16 @@ int main(int argc, char* argv[])
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
@@ -276,16 +276,16 @@ int main(int argc, char* argv[])
|
||||
"not support this Conv problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker->Run(argument.get(), nrepeat);
|
||||
float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = get_flops(
|
||||
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
std::size_t num_btype =
|
||||
get_btype<InDataType, WeiDataType, OutDataType>(params.N,
|
||||
params.C,
|
||||
params.K,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
get_btype<InDataType, WeiDataType, OutDataType>(params.N_,
|
||||
params.C_,
|
||||
params.K_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
@@ -301,40 +301,43 @@ int main(int argc, char* argv[])
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weights,
|
||||
host_output,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
ck::utils::check_err(
|
||||
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
return ck::utils::check_err(device_output.mData,
|
||||
host_output.mData,
|
||||
"Error: incorrect results!",
|
||||
1e-5f,
|
||||
1e-4f)
|
||||
? 0
|
||||
: 1;
|
||||
};
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
auto ref_conv = ReferenceConvNDFwdInstance<3>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
case 2: {
|
||||
auto ref_conv = ReferenceConvNDFwdInstance<2>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
case 1: {
|
||||
auto ref_conv = ReferenceConvNDFwdInstance<1>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
344
example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp
Normal file
344
example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp
Normal file
@@ -0,0 +1,344 @@
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "device.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "reference_conv_fwd.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
using InDataType = double;
|
||||
using WeiDataType = double;
|
||||
using OutDataType = double;
|
||||
using AccDataType = double;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
using DeviceConvFwdBasePtr =
|
||||
ck::tensor_operation::device::DeviceConvFwdPtr<InElementOp, WeiElementOp, OutElementOp>;
|
||||
|
||||
template <ck::index_t NumDimSpatial>
|
||||
using DeviceConvNDFwdInstance = ck::tensor_operation::device::
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
|
||||
// clang-format off
|
||||
InDataType, //
|
||||
WeiDataType, //
|
||||
OutDataType, //
|
||||
AccDataType, //
|
||||
InElementOp, // Input Elementwise Operation
|
||||
WeiElementOp, // Weights Elementwise Operation
|
||||
OutElementOp, // Output Elementwise Operation
|
||||
ConvFwdDefault, // ConvForwardSpecialization
|
||||
NumDimSpatial, // NumDimSpatial
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
2, // K1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
2, // ABlockTransferSrcScalarPerVector
|
||||
2, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
2, // BBlockTransferSrcScalarPerVector
|
||||
2, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockTransferAddExtraN
|
||||
7, // CThreadTransferSrcDstVectorDim
|
||||
1>; // CThreadTransferDstScalarPerVector
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimSpatial>
|
||||
using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
NumDimSpatial>;
|
||||
|
||||
// Creates the device convolution instance for the requested number of spatial dimensions.
// Ranks 1, 2 and 3 are supported; any other value throws std::runtime_error.
DeviceConvFwdBasePtr get_conv_instance(int num_dim_spatial)
{
    if(num_dim_spatial == 3)
    {
        return std::make_unique<DeviceConvNDFwdInstance<3>>();
    }
    if(num_dim_spatial == 2)
    {
        return std::make_unique<DeviceConvNDFwdInstance<2>>();
    }
    if(num_dim_spatial == 1)
    {
        return std::make_unique<DeviceConvNDFwdInstance<1>>();
    }

    // No instance is instantiated for any other rank.
    throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
|
||||
|
||||
// Prints the command-line usage of this example to stdout.
//
// The description of arg3 matches what main() actually parses: argv[3] is read into the
// `time_kernel` flag, not into a repeat count.
void print_use_msg()
{
    std::cout << "arg1: verification (0=no, 1=yes)\n"
              << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
              << "arg3: time kernel (0=no, 1=yes)\n"
              << "arg4: N spatial dimensions (default 2)\n"
              << "Following arguments (depending on number of spatial dims):\n"
              << " N, K, C, \n"
              << " <filter spatial dimensions>, (ie Y, X for 2D)\n"
              << " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
              << " <strides>, (ie Sy, Sx for 2D)\n"
              << " <dilations>, (ie Dy, Dx for 2D)\n"
              << " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
              << " <right padding>, (ie RightPy, RightPx for 2D)\n"
              << std::endl;
}
|
||||
|
||||
// Parses the convolution problem description from the command line.
//
// argv[0..4] hold the program name and the four leading options; the convolution arguments
// start at argv[5] in this order: N, K, C, then `num_dim_spatial` values each for the filter
// lengths, input lengths, strides, dilations, left pads and right pads.
//
// If argc does not match the expected count exactly, the usage message is printed and the
// process exits with a failure status, so scripts can detect the malformed invocation.
// Values are converted with std::stoi, which throws on non-numeric input.
ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, char* argv[])
{
    // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right)
    const int conv_args     = 3 + num_dim_spatial * 6;
    const int cmdline_nargs = conv_args + 5;
    if(cmdline_nargs != argc)
    {
        print_use_msg();
        // A wrong argument count is an error and must not be reported as success.
        std::exit(EXIT_FAILURE);
    }

    ck::utils::conv::ConvParams params;
    int arg_idx = 5;

    params.num_dim_spatial_ = num_dim_spatial;
    params.N_               = std::stoi(argv[arg_idx++]);
    params.K_               = std::stoi(argv[arg_idx++]);
    params.C_               = std::stoi(argv[arg_idx++]);

    // Sizes `values` to num_dim_spatial and fills it with the next arguments, in order.
    auto read_spatial_values = [&](auto& values) {
        values.resize(num_dim_spatial);
        for(int i = 0; i < num_dim_spatial; ++i)
        {
            values[i] = std::stoi(argv[arg_idx++]);
        }
    };

    // The call order below defines the command-line layout; do not reorder.
    read_spatial_values(params.filter_spatial_lengths_);
    read_spatial_values(params.input_spatial_lengths_);
    read_spatial_values(params.conv_filter_strides_);
    read_spatial_values(params.conv_filter_dilations_);
    read_spatial_values(params.input_left_pads_);
    read_spatial_values(params.input_right_pads_);

    return params;
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck::utils::conv;
|
||||
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
bool time_kernel = false;
|
||||
int num_dim_spatial = 2;
|
||||
|
||||
ck::utils::conv::ConvParams params;
|
||||
|
||||
if(argc >= 5)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
num_dim_spatial = std::stoi(argv[4]);
|
||||
}
|
||||
|
||||
if(argc >= 6)
|
||||
{
|
||||
params = parse_conv_params(num_dim_spatial, argc, argv);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths_),
|
||||
std::end(params.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths_),
|
||||
std::end(params.filter_spatial_lengths_));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
|
||||
Tensor<InDataType> input(get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
|
||||
Tensor<WeiDataType> weights(get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
|
||||
Tensor<OutDataType> host_output(
|
||||
get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
|
||||
Tensor<OutDataType> device_output(
|
||||
get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
|
||||
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weights: " << weights.mDesc << std::endl;
|
||||
std::cout << "output: " << host_output.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
input.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
|
||||
weights.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||
break;
|
||||
case 2:
|
||||
input.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
|
||||
weights.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
input.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
|
||||
weights.GenerateTensorValue(GeneratorTensor_1<WeiDataType>{1});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace());
|
||||
|
||||
in_device_buf.ToDevice(input.mData.data());
|
||||
wei_device_buf.ToDevice(weights.mData.data());
|
||||
|
||||
// do GEMM
|
||||
auto conv = get_conv_instance(num_dim_spatial);
|
||||
auto invoker = conv->MakeInvokerPointer();
|
||||
auto argument =
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
if(!conv->IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = get_flops(
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
std::size_t num_btype =
|
||||
get_btype<InDataType, WeiDataType, OutDataType>(params.N_,
|
||||
params.C_,
|
||||
params.K_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto verify_f = [&input, &weights, &host_output, ¶ms, &out_device_buf, &device_output](
|
||||
const auto& ref_conv) {
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weights,
|
||||
host_output,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
ck::utils::check_err(
|
||||
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
};
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
auto ref_conv = ReferenceConvNDFwdInstance<3>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
auto ref_conv = ReferenceConvNDFwdInstance<2>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
auto ref_conv = ReferenceConvNDFwdInstance<1>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "conv_fwd_util.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "device.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
@@ -45,10 +45,10 @@ template <ck::index_t NumDimSpatial>
|
||||
using DeviceConvNDFwdInstance = ck::tensor_operation::device::
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
|
||||
// clang-format off
|
||||
InDataType, //
|
||||
InDataType, //
|
||||
WeiDataType, //
|
||||
OutDataType, //
|
||||
AccDataType, //
|
||||
AccDataType, //
|
||||
InElementOp, // Input Elementwise Operation
|
||||
WeiElementOp, // Weights Elementwise Operation
|
||||
OutElementOp, // Output Elementwise Operation
|
||||
@@ -112,7 +112,7 @@ void print_use_msg()
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: run kernel # of times (>1)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg4: N spatial dimensions (default 2)\n"
|
||||
<< "Following arguments (depending on number of spatial dims):\n"
|
||||
<< " N, K, C, \n"
|
||||
@@ -139,40 +139,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, cha
|
||||
ck::utils::conv::ConvParams params;
|
||||
int arg_idx = 5;
|
||||
|
||||
params.num_dim_spatial = num_dim_spatial;
|
||||
params.N = std::stoi(argv[arg_idx++]);
|
||||
params.K = std::stoi(argv[arg_idx++]);
|
||||
params.C = std::stoi(argv[arg_idx++]);
|
||||
params.num_dim_spatial_ = num_dim_spatial;
|
||||
params.N_ = std::stoi(argv[arg_idx++]);
|
||||
params.K_ = std::stoi(argv[arg_idx++]);
|
||||
params.C_ = std::stoi(argv[arg_idx++]);
|
||||
|
||||
params.filter_spatial_lengths.resize(num_dim_spatial);
|
||||
params.filter_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_spatial_lengths.resize(num_dim_spatial);
|
||||
params.input_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_strides.resize(num_dim_spatial);
|
||||
params.conv_filter_strides_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_dilations.resize(num_dim_spatial);
|
||||
params.conv_filter_dilations_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_left_pads.resize(num_dim_spatial);
|
||||
params.input_left_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_right_pads.resize(num_dim_spatial);
|
||||
params.input_right_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return params;
|
||||
@@ -184,9 +184,9 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck::utils::conv;
|
||||
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
int num_dim_spatial = 2;
|
||||
|
||||
ck::utils::conv::ConvParams params;
|
||||
@@ -195,7 +195,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
num_dim_spatial = std::stoi(argv[4]);
|
||||
}
|
||||
|
||||
@@ -204,21 +204,21 @@ int main(int argc, char* argv[])
|
||||
params = parse_conv_params(num_dim_spatial, argc, argv);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths),
|
||||
std::end(params.input_spatial_lengths));
|
||||
std::begin(params.input_spatial_lengths_),
|
||||
std::end(params.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths),
|
||||
std::end(params.filter_spatial_lengths));
|
||||
std::begin(params.filter_spatial_lengths_),
|
||||
std::end(params.filter_spatial_lengths_));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.K)};
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
@@ -258,16 +258,16 @@ int main(int argc, char* argv[])
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
@@ -279,16 +279,16 @@ int main(int argc, char* argv[])
|
||||
"not support this Conv problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker->Run(argument.get(), nrepeat);
|
||||
float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = get_flops(
|
||||
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
std::size_t num_btype = get_btype<InDataType, WeiDataType, OutDataType>(
|
||||
params.N,
|
||||
params.C,
|
||||
params.K,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.C_,
|
||||
params.K_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
@@ -304,40 +304,38 @@ int main(int argc, char* argv[])
|
||||
auto ref_argument = ref_conv.MakeArgument(input,
|
||||
weights,
|
||||
host_output,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
out_device_buf.FromDevice(device_output.mData.data());
|
||||
ck::utils::check_err(
|
||||
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
return ck::utils::check_err(
|
||||
host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f) ? 0 : 1;
|
||||
};
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
auto ref_conv = ReferenceConvNDFwdInstance<3>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
case 2: {
|
||||
auto ref_conv = ReferenceConvNDFwdInstance<2>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
case 1: {
|
||||
auto ref_conv = ReferenceConvNDFwdInstance<1>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
add_example_executable(example_conv2d_bwd_data_xdl conv2d_bwd_data_xdl.cpp)
|
||||
target_link_libraries(example_conv2d_bwd_data_xdl PRIVATE conv_fwd_util)
|
||||
target_link_libraries(example_conv2d_bwd_data_xdl PRIVATE conv_util)
|
||||
|
||||
@@ -77,9 +77,9 @@ using ReferenceConvBwdInstance = ck::tensor_operation::host::ReferenceConvBwdDat
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// Conv shape
|
||||
ck::index_t N = 128;
|
||||
@@ -102,13 +102,13 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 19)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
@@ -130,7 +130,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
|
||||
"RightPx\n");
|
||||
exit(0);
|
||||
@@ -214,7 +214,7 @@ int main(int argc, char* argv[])
|
||||
"not support this Conv problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, nrepeat);
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
|
||||
|
||||
@@ -249,6 +249,10 @@ int main(int argc, char* argv[])
|
||||
|
||||
in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());
|
||||
|
||||
ck::utils::check_err(in_n_c_hi_wi_device_result.mData, in_n_c_hi_wi_host_result.mData);
|
||||
return ck::utils::check_err(in_n_c_hi_wi_device_result.mData,
|
||||
in_n_c_hi_wi_host_result.mData)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
add_example_executable(example_conv2d_bwd_weight_xdl conv2d_bwd_weight_xdl.cpp)
|
||||
target_link_libraries(example_conv2d_bwd_weight_xdl PRIVATE conv_fwd_util)
|
||||
target_link_libraries(example_conv2d_bwd_weight_xdl PRIVATE conv_util)
|
||||
|
||||
@@ -82,9 +82,9 @@ using ReferenceConvBwdWeightInstance =
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
int do_log = 0;
|
||||
int split_k = 4;
|
||||
|
||||
@@ -109,7 +109,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
do_log = std::stoi(argv[4]);
|
||||
split_k = std::stoi(argv[5]);
|
||||
}
|
||||
@@ -117,7 +117,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
do_log = std::stoi(argv[4]);
|
||||
split_k = std::stoi(argv[5]);
|
||||
|
||||
@@ -141,7 +141,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: is show log (0=no, 1=yes)\n");
|
||||
printf("arg5: split-k \n");
|
||||
printf("arg6 to 19: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
|
||||
@@ -246,7 +246,7 @@ int main(int argc, char* argv[])
|
||||
return 1;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, nrepeat);
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X;
|
||||
|
||||
@@ -291,6 +291,9 @@ int main(int argc, char* argv[])
|
||||
LogRangeAsType<float>(std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
ck::utils::check_err(wei_k_c_y_x_device_result.mData, wei_k_c_y_x_host_result.mData);
|
||||
return ck::utils::check_err(wei_k_c_y_x_device_result.mData, wei_k_c_y_x_host_result.mData)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
add_example_executable(example_reduce_blockwise reduce_blockwise.cpp)
|
||||
add_example_executable(example_reduce_blockwise_two_call reduce_blockwise_two_call.cpp)
|
||||
|
||||
@@ -5,23 +5,37 @@
|
||||
# -D <xxx> : input 4-d tensor lengths
|
||||
# -v <x> : verification (0=no, 1=yes)
|
||||
#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
|
||||
#arg2: run kernel # of times (>1)
|
||||
./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 10
|
||||
#arg2: time kernel (0=no, 1=yes)
|
||||
./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1
|
||||
```
|
||||
|
||||
Result
|
||||
```
|
||||
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 3 times...
|
||||
Perf: 0.23536 ms, 267.32 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
|
||||
error: 0
|
||||
max_diff: 0, 529, 529
|
||||
root@dc-smc-18:/data/composable_kernel/Build3# bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 10
|
||||
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1
|
||||
launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up 1 time
|
||||
Start running 10 times...
|
||||
Perf: 0.23392 ms, 268.966 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
|
||||
error: 0
|
||||
max_diff: 0, 528, 528
|
||||
Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1>
|
||||
```
|
||||
|
||||
# Instructions for ```example_reduce_blockwise_two_call```
|
||||
|
||||
## Run ```example_reduce_blockwise_two_call```
|
||||
```bash
|
||||
#arg1: verification (0=no, 1=yes)
|
||||
#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
|
||||
#arg3: time kernel (0=no, 1=yes)
|
||||
./bin/example_reduce_blockwise_two_call 1 2 1
|
||||
```
|
||||
|
||||
Result
|
||||
```
|
||||
./bin/example_reduce_blockwise_two_call 1 2 1
|
||||
launch_and_time_kernel: grid_dim {204800, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up 1 time
|
||||
Start running 10 times...
|
||||
launch_and_time_kernel: grid_dim {6400, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up 1 time
|
||||
Start running 10 times...
|
||||
Perf: 2.1791 ms, 771.42 GB/s, DeviceReduceBlockWise<256,M_C32_S1,K_C8_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> => DeviceReduceBlockWise<256,M_C256_S1,K_C1_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1>
|
||||
```
|
||||
|
||||
@@ -12,8 +12,8 @@
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_reduce_blockwise.hpp"
|
||||
#include "host_reduce_util.hpp"
|
||||
#include "device_reduce_multiblock.hpp"
|
||||
#include "host_common_util.hpp"
|
||||
#include "host_reduction.hpp"
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
@@ -30,96 +30,53 @@ constexpr int Rank = 4;
|
||||
constexpr int NumReduceDim = 3;
|
||||
|
||||
constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
|
||||
constexpr NanPropagation NanOpt = NanPropagation::PROPAGATE_NAN;
|
||||
constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true;
|
||||
constexpr ReduceTensorIndices IndicesOpt = ReduceTensorIndices::NO_INDICES;
|
||||
constexpr bool PropagateNan = true;
|
||||
constexpr bool OutputIndex = false;
|
||||
|
||||
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
|
||||
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
using DeviceReduceInstance = DeviceReduceBlockWise<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
PropagateNan,
|
||||
false,
|
||||
256,
|
||||
4,
|
||||
64,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1>;
|
||||
using DeviceReduceInstance = DeviceReduceMultiBlock<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
PropagateNan,
|
||||
OutputIndex,
|
||||
false, // HaveIndexInputIfOutputIndex
|
||||
256,
|
||||
4,
|
||||
64,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
1>;
|
||||
|
||||
static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
|
||||
{"scales", required_argument, nullptr, 'S'},
|
||||
{"verify", required_argument, nullptr, 'v'},
|
||||
{"help", no_argument, nullptr, '?'},
|
||||
{nullptr, 0, nullptr, 0}};
|
||||
|
||||
class SimpleAppArgs
|
||||
{
|
||||
template <typename T>
|
||||
static T getSingleValueFromString(const std::string& valueStr)
|
||||
{
|
||||
std::istringstream iss(valueStr);
|
||||
|
||||
T ret;
|
||||
|
||||
iss >> ret;
|
||||
|
||||
return (ret);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static std::vector<T> getTypeValuesFromString(const char* cstr_values)
|
||||
{
|
||||
std::string valuesStr(cstr_values);
|
||||
|
||||
std::vector<T> values;
|
||||
std::size_t pos = 0;
|
||||
std::size_t new_pos;
|
||||
|
||||
new_pos = valuesStr.find(',', pos);
|
||||
while(new_pos != std::string::npos)
|
||||
{
|
||||
const std::string sliceStr = valuesStr.substr(pos, new_pos - pos);
|
||||
|
||||
T val = getSingleValueFromString<T>(sliceStr);
|
||||
|
||||
values.push_back(val);
|
||||
|
||||
pos = new_pos + 1;
|
||||
new_pos = valuesStr.find(',', pos);
|
||||
};
|
||||
|
||||
std::string sliceStr = valuesStr.substr(pos);
|
||||
T val = getSingleValueFromString<T>(sliceStr);
|
||||
|
||||
values.push_back(val);
|
||||
|
||||
return (values);
|
||||
};
|
||||
|
||||
private:
|
||||
int option_index = 0;
|
||||
|
||||
public:
|
||||
std::vector<size_t> inLengths;
|
||||
std::vector<float> scales;
|
||||
std::vector<size_t> inLengths = {16, 64, 32, 960};
|
||||
std::vector<float> scales = {1.0f, 0.0f};
|
||||
|
||||
bool do_verification = false;
|
||||
|
||||
int init_method = 1;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
public:
|
||||
void show_usage(const char* cmd)
|
||||
@@ -127,24 +84,24 @@ class SimpleAppArgs
|
||||
std::cout << "Usage of " << cmd << std::endl;
|
||||
std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths"
|
||||
<< std::endl;
|
||||
std::cout << "--scales or -S, comma separated two float values for alpha and beta"
|
||||
<< std::endl;
|
||||
std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
|
||||
"comparing with the host-based reduction"
|
||||
<< std::endl;
|
||||
std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer "
|
||||
"value, 3=decimal value)"
|
||||
<< std::endl;
|
||||
std::cout << "Arg2 -- number of repeats to run the kernel" << std::endl;
|
||||
std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl;
|
||||
};
|
||||
|
||||
int processArgs(int argc, char* argv[])
|
||||
{
|
||||
unsigned int ch;
|
||||
using ck::host_common::getTypeValuesFromString;
|
||||
|
||||
int ch;
|
||||
|
||||
while(1)
|
||||
{
|
||||
ch = getopt_long(argc, argv, "D:S:v:l:", long_options, &option_index);
|
||||
ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index);
|
||||
if(ch == -1)
|
||||
break;
|
||||
switch(ch)
|
||||
@@ -155,12 +112,6 @@ class SimpleAppArgs
|
||||
|
||||
inLengths = getTypeValuesFromString<size_t>(optarg);
|
||||
break;
|
||||
case 'S':
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
|
||||
scales = getTypeValuesFromString<float>(optarg);
|
||||
break;
|
||||
case 'v':
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
@@ -182,7 +133,7 @@ class SimpleAppArgs
|
||||
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
|
||||
|
||||
init_method = std::atoi(argv[optind++]);
|
||||
nrepeat = std::atoi(argv[optind]);
|
||||
time_kernel = static_cast<bool>(std::atoi(argv[optind]));
|
||||
|
||||
if(scales.empty())
|
||||
{
|
||||
@@ -196,23 +147,21 @@ class SimpleAppArgs
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck::host_reduce;
|
||||
|
||||
const std::vector<int> reduceDims{0, 1, 2};
|
||||
const std::vector<int> invariantDims{3};
|
||||
|
||||
SimpleAppArgs args;
|
||||
|
||||
if(args.processArgs(argc, argv) < 0)
|
||||
return (-1);
|
||||
if(argc > 1)
|
||||
{
|
||||
if(args.processArgs(argc, argv) < 0)
|
||||
return (-1);
|
||||
};
|
||||
|
||||
constexpr bool op_support_indices =
|
||||
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
|
||||
ReduceOpId == ReduceTensorOp::AMAX);
|
||||
|
||||
constexpr bool NeedIndices =
|
||||
(op_support_indices && (IndicesOpt != ReduceTensorIndices::NO_INDICES));
|
||||
|
||||
// if input is half type, no reason to use float for indiced reduction operation and must use
|
||||
// float for non-indiced reduction operation for accuracy
|
||||
constexpr bool invalid_reduce_1 =
|
||||
@@ -226,8 +175,7 @@ int main(int argc, char* argv[])
|
||||
(op_support_indices && !std::is_same<AccDataType, float>::value);
|
||||
|
||||
// indices option can only be used when it is really needed
|
||||
constexpr bool invalid_reduce_3 =
|
||||
(!op_support_indices && IndicesOpt != ReduceTensorIndices::NO_INDICES);
|
||||
constexpr bool invalid_reduce_3 = (!op_support_indices && OutputIndex);
|
||||
|
||||
constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3);
|
||||
|
||||
@@ -295,51 +243,65 @@ int main(int argc, char* argv[])
|
||||
if(beta != 0.0f)
|
||||
out_dev.ToDevice(out.mData.data());
|
||||
|
||||
size_t indicesSizeInBytes = NeedIndices ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0;
|
||||
size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0;
|
||||
|
||||
DeviceMem out_indices_dev(indicesSizeInBytes);
|
||||
DeviceMem out_index_dev(indicesSizeInBytes);
|
||||
|
||||
InElementwiseOperation in_elementwise_op;
|
||||
AccElementwiseOperation acc_elementwise_op;
|
||||
|
||||
std::tie(in_elementwise_op, acc_elementwise_op) =
|
||||
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
|
||||
static_cast<int32_t>(reduce_total_length));
|
||||
|
||||
if(args.do_verification)
|
||||
{
|
||||
ReductionHost<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
ReduceOpId,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
PropagateNan,
|
||||
NeedIndices>
|
||||
OutputIndex>
|
||||
hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims);
|
||||
|
||||
hostReduce.Run(
|
||||
alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data());
|
||||
hostReduce.Run(alpha,
|
||||
in.mData.data(),
|
||||
beta,
|
||||
out_ref.mData.data(),
|
||||
out_indices_ref.mData.data(),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
const auto i_inLengths = to_int_vector(args.inLengths);
|
||||
const auto i_inStrides = to_int_vector(inStrides);
|
||||
const auto i_outLengths = to_int_vector(outLengths);
|
||||
const auto i_outStrides = to_int_vector(outStrides);
|
||||
std::vector<ck::index_t> i_inLengths;
|
||||
std::vector<ck::index_t> i_inStrides;
|
||||
std::vector<ck::index_t> i_outLengths;
|
||||
std::vector<ck::index_t> i_outStrides;
|
||||
|
||||
i_inLengths.assign(args.inLengths.begin(), args.inLengths.end());
|
||||
i_inStrides.assign(inStrides.begin(), inStrides.end());
|
||||
i_outLengths.assign(outLengths.begin(), outLengths.end());
|
||||
i_outStrides.assign(outStrides.begin(), outStrides.end());
|
||||
|
||||
auto reduce = DeviceReduceInstance{};
|
||||
|
||||
auto wsSizeInBytes = reduce.GetWorkspaceSizeInBytes(i_inLengths, reduceDims);
|
||||
|
||||
DeviceMem ws_dev(wsSizeInBytes);
|
||||
|
||||
auto argument_ptr =
|
||||
reduce.MakeArgumentPointer(i_inLengths,
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer(),
|
||||
out_indices_dev.GetDeviceBuffer(),
|
||||
ws_dev.GetDeviceBuffer(),
|
||||
InElementwiseOperation{static_cast<int>(reduce_total_length)},
|
||||
AccElementwiseOperation{static_cast<int>(reduce_total_length)});
|
||||
auto argument_ptr = reduce.MakeArgumentPointer(i_inLengths,
|
||||
i_inStrides,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
out_index_dev.GetDeviceBuffer(),
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
|
||||
if(!reduce.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
@@ -352,7 +314,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
auto invoker_ptr = reduce.MakeInvokerPointer();
|
||||
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), args.nrepeat);
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel});
|
||||
|
||||
std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InDataType) +
|
||||
invariant_total_length * sizeof(OutDataType);
|
||||
@@ -362,16 +324,19 @@ int main(int argc, char* argv[])
|
||||
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name
|
||||
<< std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(args.do_verification)
|
||||
{
|
||||
out_dev.FromDevice(out.mData.data());
|
||||
ck::utils::check_err(out.mData, out_ref.mData);
|
||||
pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
|
||||
|
||||
if(NeedIndices)
|
||||
if(OutputIndex)
|
||||
{
|
||||
out_indices_dev.FromDevice(out_indices.mData.data());
|
||||
ck::utils::check_err(out_indices.mData, out_indices_ref.mData);
|
||||
;
|
||||
out_index_dev.FromDevice(out_indices.mData.data());
|
||||
pass = pass && ck::utils::check_err(out_indices.mData, out_indices_ref.mData);
|
||||
};
|
||||
};
|
||||
|
||||
return (pass ? 0 : 1);
|
||||
}
|
||||
|
||||
301
example/12_reduce/reduce_blockwise_two_call.cpp
Normal file
301
example/12_reduce/reduce_blockwise_two_call.cpp
Normal file
@@ -0,0 +1,301 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <getopt.h>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_reduce_multiblock.hpp"
|
||||
#include "host_common_util.hpp"
|
||||
#include "host_reduction.hpp"
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
|
||||
using namespace ck;
|
||||
using namespace ck::tensor_operation::device;
|
||||
|
||||
using InOutDataType = ck::half_t;
|
||||
using InOutDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
|
||||
constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2;
|
||||
constexpr bool PropagateNan = true;
|
||||
constexpr bool OutputIndex = false;
|
||||
|
||||
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
using PassThroughOp = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using DeviceReduceInstance_1 = DeviceReduceMultiBlock<InOutDataType,
|
||||
AccDataType,
|
||||
InOutDataType,
|
||||
5, // Rank
|
||||
1, // NumReduceDim
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
PassThroughOp,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
PropagateNan,
|
||||
OutputIndex,
|
||||
false, // HaveIndexInputIfOutputIndex
|
||||
256,
|
||||
32,
|
||||
8,
|
||||
1,
|
||||
1,
|
||||
1, // vector dim
|
||||
1,
|
||||
1>;
|
||||
|
||||
using DeviceReduceInstance_2 = DeviceReduceMultiBlock<InOutDataType,
|
||||
AccDataType,
|
||||
InOutDataType,
|
||||
4, // Rank
|
||||
1, // NumReduceDim
|
||||
ReduceOperation,
|
||||
PassThroughOp,
|
||||
AccElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
PropagateNan,
|
||||
OutputIndex,
|
||||
false, // HaveIndexInputIfOutputIndex
|
||||
256,
|
||||
128,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
1, // vector dim
|
||||
1,
|
||||
1>;
|
||||
|
||||
static bool do_verify;
|
||||
static int init_method;
|
||||
static float alpha;
|
||||
static float beta;
|
||||
static bool time_kernel;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// used by the device reduction
|
||||
const std::vector<int> reduceDims_1 = {4};
|
||||
const std::vector<int> invariantDims_1 = {0, 1, 2, 3};
|
||||
|
||||
const std::vector<int> reduceDims_2 = {3};
|
||||
const std::vector<int> invariantDims_2 = {0, 1, 2};
|
||||
|
||||
// used by the host reduction
|
||||
const std::vector<int> reduceDims = {3, 4};
|
||||
const std::vector<int> invariantDims = {0, 1, 2};
|
||||
|
||||
const std::vector<size_t> inLengths_1 = {64, 320, 80, 4, 128};
|
||||
|
||||
// input lengths of the second reduction, which is also the output lengths of the first
|
||||
// reduction
|
||||
const std::vector<size_t> inLengths_2 = {64, 320, 80, 4};
|
||||
|
||||
const std::vector<size_t> outLengths = {64, 320, 80};
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
do_verify = true;
|
||||
init_method = 2;
|
||||
time_kernel = true;
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verify = static_cast<bool>(argv[1]);
|
||||
init_method = atoi(argv[2]);
|
||||
time_kernel = static_cast<bool>(atoi(argv[3]));
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream ostr;
|
||||
|
||||
ostr << "Wrong parameter! " << std::endl
|
||||
<< "Usage: " << argv[0] << "[verify 0/1] init_method time_kernel" << std::endl;
|
||||
|
||||
throw std::runtime_error(ostr.str());
|
||||
};
|
||||
|
||||
alpha = 1.0f;
|
||||
beta = 0.0f;
|
||||
|
||||
Tensor<InOutDataType> in_1(inLengths_1);
|
||||
|
||||
Tensor<InOutDataType> out_ref(outLengths);
|
||||
Tensor<InOutDataType> in_2(inLengths_2); // also the output tensor of the first reduction
|
||||
Tensor<InOutDataType> out(outLengths);
|
||||
|
||||
auto inStrides_1 = in_1.mDesc.GetStrides();
|
||||
auto inStrides_2 = in_2.mDesc.GetStrides();
|
||||
auto outStrides = out.mDesc.GetStrides();
|
||||
|
||||
size_t invariant_total_length = out.mDesc.GetElementSize();
|
||||
size_t reduce_total_length = in_1.mDesc.GetElementSize() / invariant_total_length;
|
||||
|
||||
std::size_t num_thread = 1;
|
||||
|
||||
if(do_verify)
|
||||
{
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
in_1.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
|
||||
if(beta != 0.0f)
|
||||
out_ref.GenerateTensorValue(GeneratorTensor_1<InOutDataType>{1}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in_1.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
|
||||
if(beta != 0.0f)
|
||||
out_ref.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in_1.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-5.0, 5.0}, num_thread);
|
||||
if(beta != 0.0f)
|
||||
out_ref.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-5.0, 5.0},
|
||||
num_thread);
|
||||
}
|
||||
|
||||
if(beta != 0.0f)
|
||||
for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++)
|
||||
out.mData[i] = out_ref.mData[i];
|
||||
};
|
||||
|
||||
DeviceMem in_1_dev(sizeof(InOutDataType) * in_1.mDesc.GetElementSpace());
|
||||
DeviceMem in_2_dev(sizeof(InOutDataType) * in_2.mDesc.GetElementSpace());
|
||||
DeviceMem out_dev(sizeof(InOutDataType) * out.mDesc.GetElementSpace());
|
||||
|
||||
in_1_dev.ToDevice(in_1.mData.data());
|
||||
|
||||
if(beta != 0.0f)
|
||||
out_dev.ToDevice(out.mData.data());
|
||||
|
||||
InElementwiseOperation in_elementwise_op;
|
||||
AccElementwiseOperation acc_elementwise_op;
|
||||
|
||||
std::tie(in_elementwise_op, acc_elementwise_op) =
|
||||
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(
|
||||
static_cast<int32_t>(reduce_total_length));
|
||||
|
||||
if(do_verify)
|
||||
{
|
||||
ReductionHost<InOutDataType,
|
||||
AccDataType,
|
||||
InOutDataType,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
5, // Rank
|
||||
2, // NumReduceDim
|
||||
PropagateNan,
|
||||
OutputIndex>
|
||||
hostReduce(in_1.mDesc, out_ref.mDesc, invariantDims, reduceDims);
|
||||
|
||||
hostReduce.Run(alpha,
|
||||
in_1.mData.data(),
|
||||
beta,
|
||||
out_ref.mData.data(),
|
||||
nullptr,
|
||||
in_elementwise_op,
|
||||
acc_elementwise_op);
|
||||
};
|
||||
|
||||
std::vector<ck::index_t> i_inLengths_1;
|
||||
std::vector<ck::index_t> i_inStrides_1;
|
||||
std::vector<ck::index_t> i_inLengths_2;
|
||||
std::vector<ck::index_t> i_inStrides_2;
|
||||
std::vector<ck::index_t> i_outLengths;
|
||||
std::vector<ck::index_t> i_outStrides;
|
||||
|
||||
i_inLengths_1.assign(inLengths_1.begin(), inLengths_1.end());
|
||||
i_inStrides_1.assign(inStrides_1.begin(), inStrides_1.end());
|
||||
i_inLengths_2.assign(inLengths_2.begin(), inLengths_2.end());
|
||||
i_inStrides_2.assign(inStrides_2.begin(), inStrides_2.end());
|
||||
i_outLengths.assign(outLengths.begin(), outLengths.end());
|
||||
i_outStrides.assign(outStrides.begin(), outStrides.end());
|
||||
|
||||
auto reduce_1 = DeviceReduceInstance_1{};
|
||||
|
||||
auto argument_ptr_1 = reduce_1.MakeArgumentPointer(i_inLengths_1,
|
||||
i_inStrides_1,
|
||||
i_inLengths_2,
|
||||
i_inStrides_2,
|
||||
reduceDims_1,
|
||||
1.0f,
|
||||
0.0f,
|
||||
in_1_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
in_elementwise_op,
|
||||
PassThroughOp{});
|
||||
|
||||
if(!reduce_1.IsSupportedArgument(argument_ptr_1.get()))
|
||||
{
|
||||
std::cout
|
||||
<< "The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
|
||||
<< std::endl;
|
||||
};
|
||||
|
||||
auto invoker_ptr_1 = reduce_1.MakeInvokerPointer();
|
||||
|
||||
auto reduce_2 = DeviceReduceInstance_2{};
|
||||
|
||||
auto argument_ptr_2 = reduce_2.MakeArgumentPointer(i_inLengths_2,
|
||||
i_inStrides_2,
|
||||
i_outLengths,
|
||||
i_outStrides,
|
||||
reduceDims_2,
|
||||
alpha,
|
||||
beta,
|
||||
in_2_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
out_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
PassThroughOp{},
|
||||
acc_elementwise_op);
|
||||
|
||||
if(!reduce_2.IsSupportedArgument(argument_ptr_2.get()))
|
||||
{
|
||||
std::cout
|
||||
<< "The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
|
||||
<< std::endl;
|
||||
};
|
||||
|
||||
auto invoker_ptr_2 = reduce_2.MakeInvokerPointer();
|
||||
|
||||
float avg_time_1 = invoker_ptr_1->Run(argument_ptr_1.get(), StreamConfig{nullptr, time_kernel});
|
||||
float avg_time_2 = invoker_ptr_2->Run(argument_ptr_2.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) +
|
||||
invariant_total_length * sizeof(InOutDataType);
|
||||
|
||||
float gb_per_sec = num_bytes / 1.E6 / (avg_time_1 + avg_time_2);
|
||||
|
||||
std::cout << "Perf: " << avg_time_1 + avg_time_2 << " ms, " << gb_per_sec << " GB/s, "
|
||||
<< reduce_1.GetTypeString() << " => " << reduce_2.GetTypeString() << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_verify)
|
||||
{
|
||||
out_dev.FromDevice(out.mData.data());
|
||||
pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
|
||||
};
|
||||
|
||||
return (pass ? 0 : 1);
|
||||
}
|
||||
@@ -1 +1,3 @@
|
||||
add_example_executable(example_pool2d_fwd pool2d_fwd.cpp)
|
||||
add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp)
|
||||
add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp)
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
# Instructions for ```example_pool2d_fwd``` Example
|
||||
# Instructions for ```example_pool2d_fwd``` Examples
|
||||
|
||||
## Run ```example_pool2d_fwd```
|
||||
## Run ```example_pool2d_fwd_fp16```
|
||||
```bash
|
||||
#arg1: verification (0=no, 1=yes)
|
||||
#arg2: initialization (0=no init, 1=single integer value, 2=integer values in a range, 3=decimal values)
|
||||
#arg3: run kernel # of times (>1)
|
||||
#arg3: time kernel (0=no, 1=yes)
|
||||
#arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx
|
||||
./bin/example_pool2d_fwd 1 1 10
|
||||
./bin/example_pool2d_fwd_fp16 1 1 1
|
||||
```
|
||||
|
||||
Result
|
||||
@@ -14,9 +14,28 @@ Result
|
||||
in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
|
||||
out_n_c_ho_wo: dim 4, lengths {128, 192, 36, 36}, strides {248832, 1, 6912, 192}
|
||||
launch_and_time_kernel: grid_dim {124416, 1, 1}, block_dim {64, 1, 1}
|
||||
Warm up
|
||||
Warm up 1 time
|
||||
Start running 10 times...
|
||||
Perf: 0.415453 ms, 1.37996 TFlops, 749.726 GB/s
|
||||
error: 0
|
||||
max_diff: 0, 1, 1
|
||||
Perf: 0.397436 ms, 1.44252 TFlops, 783.713 GB/s
|
||||
```
|
||||
|
||||
## Run ```example_pool2d_fwd_fp32```
|
||||
```bash
|
||||
#arg1: verification (0=no, 1=yes)
|
||||
#arg2: initialization (0=no init, 1=single integer value, 2=integer values in a range, 3=decimal values)
|
||||
#arg3: time kernel (0=no, 1=yes)
|
||||
#arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx
|
||||
./bin/example_pool2d_fwd_fp32 1 1 1
|
||||
```
|
||||
|
||||
|
||||
Result
|
||||
```
|
||||
./bin/example_pool2d_fwd_fp32 1 1 1
|
||||
in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192}
|
||||
out_n_c_ho_wo: dim 4, lengths {128, 192, 36, 36}, strides {248832, 1, 6912, 192}
|
||||
launch_and_time_kernel: grid_dim {124416, 1, 1}, block_dim {64, 1, 1}
|
||||
Warm up 1 time
|
||||
Start running 10 times...
|
||||
Perf: 1.01823 ms, 0.563045 TFlops, 611.8 GB/s
|
||||
```
|
||||
|
||||
@@ -1,315 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "host_reduce_util.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "device_pool2d_fwd_nhwc_nhwc.hpp"
|
||||
|
||||
using InDataType = ck::half_t;
|
||||
using OutDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
|
||||
using InLayout = ck::tensor_layout::convolution::NHWC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NHWC;
|
||||
|
||||
#if 1
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
#else
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
#endif
|
||||
|
||||
static constexpr bool NeedIndices = false;
|
||||
static constexpr bool PropagateNan = false;
|
||||
|
||||
using DevicePoolFwdInstance =
|
||||
ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
|
||||
InDataType, // InDataType
|
||||
OutDataType, // OutDataType
|
||||
AccDataType, // AccDataType
|
||||
ReduceOpId,
|
||||
NeedIndices,
|
||||
64, // BlockSize
|
||||
64, // ReduceMThreadClusterSize
|
||||
1, // ReduceKThreadClusterSize
|
||||
4, // ReduceMThreadSliceSize
|
||||
1, // ReduceKThreadSliceSize
|
||||
4>; // InSrcOutDstVectorSize
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
ck::ReduceTensorOp ReduceOpId,
|
||||
bool PropagateNan,
|
||||
bool NeedIndices>
|
||||
static void pool_host_verify(const Tensor<InDataType>& in,
|
||||
Tensor<OutDataType>& out,
|
||||
Tensor<int>& out_indices,
|
||||
const std::array<ck::index_t, 2>& window_spatial_lengths,
|
||||
const std::array<ck::index_t, 2>& window_strides,
|
||||
const std::array<ck::index_t, 2>& in_left_pads,
|
||||
const std::array<ck::index_t, 2>& /*in_right_pads*/)
|
||||
{
|
||||
using namespace ck::host_reduce;
|
||||
|
||||
const int divider = window_spatial_lengths[0] * window_spatial_lengths[1];
|
||||
|
||||
const auto PreUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
|
||||
const auto PosUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
|
||||
|
||||
if constexpr(!NeedIndices)
|
||||
{
|
||||
auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();
|
||||
|
||||
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
|
||||
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
|
||||
|
||||
for(int y = 0; y < window_spatial_lengths[0]; ++y)
|
||||
{
|
||||
int hi = ho * window_strides[0] + y - in_left_pads[0];
|
||||
for(int x = 0; x < window_spatial_lengths[1]; ++x)
|
||||
{
|
||||
int wi = wo * window_strides[1] + x - in_left_pads[1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
|
||||
|
||||
PreUnaryOp(currVal);
|
||||
|
||||
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PosUnaryOp(accuVal);
|
||||
|
||||
out(n, c, ho, wo) = accuVal;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
auto opReduce = ReduceOpFn2<AccDataType, ReduceOpId>();
|
||||
|
||||
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
|
||||
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
|
||||
int accuIndex = 0;
|
||||
|
||||
for(int y = 0; y < window_spatial_lengths[0]; ++y)
|
||||
{
|
||||
int hi = ho * window_strides[0] + y - in_left_pads[0];
|
||||
for(int x = 0; x < window_spatial_lengths[1]; ++x)
|
||||
{
|
||||
int wi = wo * window_strides[1] + x - in_left_pads[1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
|
||||
int currIndex = y * window_spatial_lengths[1] + x;
|
||||
|
||||
PreUnaryOp(currVal);
|
||||
|
||||
binop_with_nan_check2<AccDataType, PropagateNan>(
|
||||
opReduce, accuVal, currVal, accuIndex, currIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PosUnaryOp(accuVal);
|
||||
|
||||
out(n, c, ho, wo) = accuVal;
|
||||
out_indices(n, c, ho, wo) = accuIndex;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
};
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck::host_reduce;
|
||||
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
|
||||
// Pool shape
|
||||
ck::index_t N = 128;
|
||||
ck::index_t C = 192;
|
||||
ck::index_t Y = 3;
|
||||
ck::index_t X = 3;
|
||||
ck::index_t Hi = 71;
|
||||
ck::index_t Wi = 71;
|
||||
ck::index_t window_stride_h = 2;
|
||||
ck::index_t window_stride_w = 2;
|
||||
ck::index_t in_left_pad_h = 1;
|
||||
ck::index_t in_left_pad_w = 1;
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 16)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
|
||||
N = std::stoi(argv[4]);
|
||||
C = std::stoi(argv[5]);
|
||||
Y = std::stoi(argv[6]);
|
||||
X = std::stoi(argv[7]);
|
||||
Hi = std::stoi(argv[8]);
|
||||
Wi = std::stoi(argv[9]);
|
||||
window_stride_h = std::stoi(argv[10]);
|
||||
window_stride_w = std::stoi(argv[11]);
|
||||
in_left_pad_h = std::stoi(argv[12]);
|
||||
in_left_pad_w = std::stoi(argv[13]);
|
||||
in_right_pad_h = std::stoi(argv[14]);
|
||||
in_right_pad_w = std::stoi(argv[15]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
|
||||
"RightPx\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
|
||||
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
|
||||
|
||||
const std::array<ck::index_t, 2> window_spatial_lengths{{Y, X}};
|
||||
const std::array<ck::index_t, 2> window_strides{{window_stride_h, window_stride_w}};
|
||||
const std::array<ck::index_t, 2> input_left_pads{{in_left_pad_h, in_left_pad_w}};
|
||||
const std::array<ck::index_t, 2> input_right_pads{{in_right_pad_h, in_right_pad_w}};
|
||||
|
||||
// tensor layout
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) {
|
||||
if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCHW>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
|
||||
std::vector<std::size_t>({C_ * H * W, H * W, W, 1}));
|
||||
}
|
||||
else if constexpr(ck::is_same<decltype(layout),
|
||||
ck::tensor_layout::convolution::NHWC>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
|
||||
std::vector<std::size_t>({C_ * H * W, 1, W * C_, C_}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}));
|
||||
Tensor<OutDataType> out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
|
||||
Tensor<int> out_indices_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
|
||||
Tensor<OutDataType> out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
|
||||
Tensor<int> out_indices_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
|
||||
|
||||
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
|
||||
std::cout << "out_n_c_ho_wo: " << out_n_c_ho_wo_host.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}); break;
|
||||
case 2: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); break;
|
||||
default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_c_ho_wo_device.mDesc.GetElementSpace());
|
||||
DeviceMem out_indices_device_buf(sizeof(int) *
|
||||
out_indices_n_c_ho_wo_device.mDesc.GetElementSpace());
|
||||
|
||||
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
|
||||
auto pool = DevicePoolFwdInstance{};
|
||||
auto invoker_ptr = pool.MakeInvokerPointer();
|
||||
auto argument_ptr =
|
||||
pool.MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
static_cast<int*>(out_indices_device_buf.GetDeviceBuffer()),
|
||||
N,
|
||||
C,
|
||||
std::array<ck::index_t, 2>{{Hi, Wi}},
|
||||
std::array<ck::index_t, 2>{{Y, X}},
|
||||
std::array<ck::index_t, 2>{{Ho, Wo}},
|
||||
window_strides,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
if(!pool.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
throw std::runtime_error("wrong! device_op with the specified compilation parameters does "
|
||||
"not support this problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
|
||||
|
||||
std::size_t flop = std::size_t(2) * N * C * Ho * Wo * Y * X;
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(OutDataType) * (N * C * Ho * Wo);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
pool_host_verify<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
ReduceOpId,
|
||||
PropagateNan,
|
||||
NeedIndices>(in_n_c_hi_wi,
|
||||
out_n_c_ho_wo_host,
|
||||
out_indices_n_c_ho_wo_host,
|
||||
window_spatial_lengths,
|
||||
window_strides,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data());
|
||||
|
||||
ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData);
|
||||
|
||||
if constexpr(NeedIndices)
|
||||
{
|
||||
out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data());
|
||||
|
||||
// ck::utils::check_err(out_indices_n_c_ho_wo_device.mData,
|
||||
// out_indices_n_c_ho_wo_host.mData);;
|
||||
};
|
||||
}
|
||||
}
|
||||
280
example/13_pool2d_fwd/pool2d_fwd_common.hpp
Normal file
280
example/13_pool2d_fwd/pool2d_fwd_common.hpp
Normal file
@@ -0,0 +1,280 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
#include "reduction_functions_accumulate.hpp"
|
||||
|
||||
#include "device_pool2d_fwd_nhwc_nhwc.hpp"
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
ck::ReduceTensorOp ReduceOpId,
|
||||
bool PropagateNan,
|
||||
bool OutputIndex>
|
||||
static void pool_host_verify(const Tensor<InDataType>& in,
|
||||
Tensor<OutDataType>& out,
|
||||
Tensor<IndexDataType>& out_indices,
|
||||
const std::array<ck::index_t, 2>& window_spatial_lengths,
|
||||
const std::array<ck::index_t, 2>& window_strides,
|
||||
const std::array<ck::index_t, 2>& in_left_pads,
|
||||
const std::array<ck::index_t, 2>& /*in_right_pads*/)
|
||||
{
|
||||
const int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];
|
||||
|
||||
using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
|
||||
|
||||
auto elementwise_ops =
|
||||
ck::reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
|
||||
|
||||
auto in_elementwise_op = std::get<0>(elementwise_ops);
|
||||
auto acc_elementwise_op = std::get<1>(elementwise_ops);
|
||||
|
||||
if constexpr(!OutputIndex)
|
||||
{
|
||||
using Accumulation =
|
||||
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
|
||||
|
||||
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
|
||||
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
|
||||
|
||||
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
|
||||
{
|
||||
ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0];
|
||||
for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x)
|
||||
{
|
||||
ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1];
|
||||
if(hi >= 0 && hi < static_cast<ck::index_t>(in.mDesc.GetLengths()[2]) &&
|
||||
wi >= 0 && wi < static_cast<ck::index_t>(in.mDesc.GetLengths()[3]))
|
||||
{
|
||||
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
|
||||
|
||||
in_elementwise_op(currVal, currVal);
|
||||
|
||||
Accumulation::Calculate(accuVal, currVal);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
acc_elementwise_op(accuVal, accuVal);
|
||||
|
||||
out(n, c, ho, wo) = accuVal;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
|
||||
ReduceOperation,
|
||||
AccDataType,
|
||||
IndexDataType>;
|
||||
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
|
||||
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
|
||||
IndexDataType accuIndex = 0;
|
||||
|
||||
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
|
||||
{
|
||||
ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0];
|
||||
for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x)
|
||||
{
|
||||
ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1];
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
|
||||
IndexDataType currIndex = y * window_spatial_lengths[1] + x;
|
||||
|
||||
in_elementwise_op(currVal, currVal);
|
||||
|
||||
Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
acc_elementwise_op(accuVal, accuVal);
|
||||
|
||||
out(n, c, ho, wo) = accuVal;
|
||||
out_indices(n, c, ho, wo) = accuIndex;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||
};
|
||||
}
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
typename InLayout,
|
||||
typename OutLayout,
|
||||
ck::ReduceTensorOp ReduceOpId,
|
||||
bool PropagateNan,
|
||||
bool OutputIndex>
|
||||
bool pool_test(bool do_verification,
|
||||
int init_method,
|
||||
bool time_kernel,
|
||||
ck::index_t N,
|
||||
ck::index_t C,
|
||||
ck::index_t Y,
|
||||
ck::index_t X,
|
||||
ck::index_t Hi,
|
||||
ck::index_t Wi,
|
||||
ck::index_t window_stride_h,
|
||||
ck::index_t window_stride_w,
|
||||
ck::index_t in_left_pad_h,
|
||||
ck::index_t in_left_pad_w,
|
||||
ck::index_t in_right_pad_h,
|
||||
ck::index_t in_right_pad_w)
|
||||
{
|
||||
using DevicePoolFwdInstance =
|
||||
ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
|
||||
InDataType, // InDataType
|
||||
OutDataType, // OutDataType
|
||||
AccDataType, // AccDataType
|
||||
ReduceOpId,
|
||||
OutputIndex,
|
||||
64, // BlockSize
|
||||
64, // ReduceMThreadClusterSize
|
||||
1, // ReduceKThreadClusterSize
|
||||
4, // ReduceMThreadSliceSize
|
||||
1, // ReduceKThreadSliceSize
|
||||
4>; // InSrcOutDstVectorSize
|
||||
|
||||
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
|
||||
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
|
||||
|
||||
const std::array<ck::index_t, 2> window_spatial_lengths{{Y, X}};
|
||||
const std::array<ck::index_t, 2> window_strides{{window_stride_h, window_stride_w}};
|
||||
const std::array<ck::index_t, 2> input_left_pads{{in_left_pad_h, in_left_pad_w}};
|
||||
const std::array<ck::index_t, 2> input_right_pads{{in_right_pad_h, in_right_pad_w}};
|
||||
|
||||
// tensor layout
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) {
|
||||
if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCHW>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
|
||||
std::vector<std::size_t>({C_ * H * W, H * W, W, 1}));
|
||||
}
|
||||
else if constexpr(ck::is_same<decltype(layout),
|
||||
ck::tensor_layout::convolution::NHWC>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({N_, C_, H, W}),
|
||||
std::vector<std::size_t>({C_ * H * W, 1, W * C_, C_}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<InDataType> in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{}));
|
||||
Tensor<OutDataType> out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
|
||||
Tensor<IndexDataType> out_indices_n_c_ho_wo_host(
|
||||
f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
|
||||
Tensor<OutDataType> out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
|
||||
Tensor<IndexDataType> out_indices_n_c_ho_wo_device(
|
||||
f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{}));
|
||||
|
||||
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
|
||||
std::cout << "out_n_c_ho_wo: " << out_n_c_ho_wo_host.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}); break;
|
||||
case 2: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}); break;
|
||||
default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_c_ho_wo_device.mDesc.GetElementSpace());
|
||||
DeviceMem out_indices_device_buf(sizeof(IndexDataType) *
|
||||
out_indices_n_c_ho_wo_device.mDesc.GetElementSpace());
|
||||
|
||||
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
|
||||
auto pool = DevicePoolFwdInstance{};
|
||||
auto invoker_ptr = pool.MakeInvokerPointer();
|
||||
auto argument_ptr = pool.MakeArgumentPointer(
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
|
||||
N,
|
||||
C,
|
||||
std::array<ck::index_t, 2>{{Hi, Wi}},
|
||||
std::array<ck::index_t, 2>{{Y, X}},
|
||||
std::array<ck::index_t, 2>{{Ho, Wo}},
|
||||
window_strides,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
if(!pool.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
throw std::runtime_error("wrong! device_op with the specified compilation parameters does "
|
||||
"not support this problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * N * C * Ho * Wo * Y * X;
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(OutDataType) * (N * C * Ho * Wo);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
pool_host_verify<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
ReduceOpId,
|
||||
PropagateNan,
|
||||
OutputIndex>(in_n_c_hi_wi,
|
||||
out_n_c_ho_wo_host,
|
||||
out_indices_n_c_ho_wo_host,
|
||||
window_spatial_lengths,
|
||||
window_strides,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data());
|
||||
|
||||
pass = pass && ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data());
|
||||
|
||||
pass = pass && ck::utils::check_err(out_indices_n_c_ho_wo_device.mData,
|
||||
out_indices_n_c_ho_wo_host.mData);
|
||||
};
|
||||
}
|
||||
|
||||
return (pass);
|
||||
};
|
||||
114
example/13_pool2d_fwd/pool2d_fwd_fp16.cpp
Normal file
114
example/13_pool2d_fwd/pool2d_fwd_fp16.cpp
Normal file
@@ -0,0 +1,114 @@
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "reduction_enums.hpp"
|
||||
|
||||
#include "pool2d_fwd_common.hpp"
|
||||
|
||||
using InDataType = ck::half_t;
|
||||
using OutDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
using InLayout = ck::tensor_layout::convolution::NHWC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NHWC;
|
||||
|
||||
#if 1
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
#else
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
#endif
|
||||
|
||||
static constexpr bool OutputIndex = false;
|
||||
static constexpr bool PropagateNan = false;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification;
|
||||
int init_method;
|
||||
bool time_kernel;
|
||||
|
||||
// Pool shape
|
||||
ck::index_t N = 128;
|
||||
ck::index_t C = 192;
|
||||
ck::index_t Y = 3;
|
||||
ck::index_t X = 3;
|
||||
ck::index_t Hi = 71;
|
||||
ck::index_t Wi = 71;
|
||||
ck::index_t window_stride_h = 2;
|
||||
ck::index_t window_stride_w = 2;
|
||||
ck::index_t in_left_pad_h = 1;
|
||||
ck::index_t in_left_pad_w = 1;
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
do_verification = true;
|
||||
init_method = 1;
|
||||
time_kernel = true;
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = static_cast<bool>(std::stoi(argv[3]));
|
||||
}
|
||||
else if(argc == 16)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = static_cast<bool>(std::stoi(argv[3]));
|
||||
|
||||
N = std::stoi(argv[4]);
|
||||
C = std::stoi(argv[5]);
|
||||
Y = std::stoi(argv[6]);
|
||||
X = std::stoi(argv[7]);
|
||||
Hi = std::stoi(argv[8]);
|
||||
Wi = std::stoi(argv[9]);
|
||||
window_stride_h = std::stoi(argv[10]);
|
||||
window_stride_w = std::stoi(argv[11]);
|
||||
in_left_pad_h = std::stoi(argv[12]);
|
||||
in_left_pad_w = std::stoi(argv[13]);
|
||||
in_right_pad_h = std::stoi(argv[14]);
|
||||
in_right_pad_w = std::stoi(argv[15]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
|
||||
"RightPx\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
bool pass = pool_test<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InLayout,
|
||||
OutLayout,
|
||||
ReduceOpId,
|
||||
PropagateNan,
|
||||
OutputIndex>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
N,
|
||||
C,
|
||||
Y,
|
||||
X,
|
||||
Hi,
|
||||
Wi,
|
||||
window_stride_h,
|
||||
window_stride_w,
|
||||
in_left_pad_h,
|
||||
in_left_pad_w,
|
||||
in_right_pad_h,
|
||||
in_right_pad_w);
|
||||
|
||||
return (pass ? 0 : 1);
|
||||
}
|
||||
114
example/13_pool2d_fwd/pool2d_fwd_fp32.cpp
Normal file
114
example/13_pool2d_fwd/pool2d_fwd_fp32.cpp
Normal file
@@ -0,0 +1,114 @@
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "reduction_enums.hpp"
|
||||
|
||||
#include "pool2d_fwd_common.hpp"
|
||||
|
||||
using InDataType = float;
|
||||
using OutDataType = float;
|
||||
using AccDataType = float;
|
||||
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
using InLayout = ck::tensor_layout::convolution::NHWC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NHWC;
|
||||
|
||||
#if 1
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
#else
|
||||
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
|
||||
#endif
|
||||
|
||||
static constexpr bool OutputIndex = false;
|
||||
static constexpr bool PropagateNan = false;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification;
|
||||
int init_method;
|
||||
bool time_kernel;
|
||||
|
||||
// Pool shape
|
||||
ck::index_t N = 128;
|
||||
ck::index_t C = 192;
|
||||
ck::index_t Y = 3;
|
||||
ck::index_t X = 3;
|
||||
ck::index_t Hi = 71;
|
||||
ck::index_t Wi = 71;
|
||||
ck::index_t window_stride_h = 2;
|
||||
ck::index_t window_stride_w = 2;
|
||||
ck::index_t in_left_pad_h = 1;
|
||||
ck::index_t in_left_pad_w = 1;
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
do_verification = true;
|
||||
init_method = 1;
|
||||
time_kernel = true;
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = static_cast<bool>(std::stoi(argv[3]));
|
||||
}
|
||||
else if(argc == 16)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = static_cast<bool>(std::stoi(argv[3]));
|
||||
|
||||
N = std::stoi(argv[4]);
|
||||
C = std::stoi(argv[5]);
|
||||
Y = std::stoi(argv[6]);
|
||||
X = std::stoi(argv[7]);
|
||||
Hi = std::stoi(argv[8]);
|
||||
Wi = std::stoi(argv[9]);
|
||||
window_stride_h = std::stoi(argv[10]);
|
||||
window_stride_w = std::stoi(argv[11]);
|
||||
in_left_pad_h = std::stoi(argv[12]);
|
||||
in_left_pad_w = std::stoi(argv[13]);
|
||||
in_right_pad_h = std::stoi(argv[14]);
|
||||
in_right_pad_w = std::stoi(argv[15]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
|
||||
"RightPx\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
bool pass = pool_test<InDataType,
|
||||
OutDataType,
|
||||
AccDataType,
|
||||
IndexDataType,
|
||||
InLayout,
|
||||
OutLayout,
|
||||
ReduceOpId,
|
||||
PropagateNan,
|
||||
OutputIndex>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
N,
|
||||
C,
|
||||
Y,
|
||||
X,
|
||||
Hi,
|
||||
Wi,
|
||||
window_stride_h,
|
||||
window_stride_w,
|
||||
in_left_pad_h,
|
||||
in_left_pad_w,
|
||||
in_right_pad_h,
|
||||
in_right_pad_w);
|
||||
|
||||
return (pass ? 0 : 1);
|
||||
}
|
||||
@@ -100,14 +100,19 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, RequantReluRequant>;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
float,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
RequantReluRequant>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
@@ -125,13 +130,13 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
@@ -145,7 +150,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
@@ -219,7 +224,7 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, nrepeat);
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
@@ -244,7 +249,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
|
||||
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -56,29 +56,29 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
int group_count = 4;
|
||||
int group_count = rand() % 16 + 1;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
|
||||
@@ -131,7 +131,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < gemm_shapes.size(); i++)
|
||||
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
|
||||
{
|
||||
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
|
||||
gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{})));
|
||||
@@ -168,7 +168,7 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < gemm_shapes.size(); i++)
|
||||
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
|
||||
{
|
||||
a_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpace()));
|
||||
@@ -189,12 +189,17 @@ int main(int argc, char* argv[])
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
// do GEMM
|
||||
auto argument =
|
||||
gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
|
||||
|
||||
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
@@ -202,7 +207,7 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, nrepeat);
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
@@ -211,9 +216,10 @@ int main(int argc, char* argv[])
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
if(do_verification)
|
||||
{
|
||||
for(int i = 0; i < gemm_shapes.size(); i++)
|
||||
for(std::size_t i = 0; i < gemm_shapes.size(); i++)
|
||||
{
|
||||
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
@@ -227,9 +233,9 @@ int main(int argc, char* argv[])
|
||||
c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData);
|
||||
pass &= ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
add_example_executable(example_gemm_reduce_xdl_fp16 gemm_reduce_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_reduce_xdl_max_fp16 gemm_reduce_xdl_max_fp16.cpp)
|
||||
add_example_executable(example_gemm_reduce_xdl_mean_squaremean_fp16 gemm_reduce_xdl_mean_squaremean_fp16.cpp)
|
||||
|
||||
@@ -1,273 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_gemm_reduce_xdl_cshuffle.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
using DDataType = F32;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using D0ReduceOp = ck::reduce::Add<float>;
|
||||
using D1ReduceOp = ck::reduce::Add<float>;
|
||||
using D1ElementOp = ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
|
||||
|
||||
static constexpr auto GemmSpecialization =
|
||||
ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, D1ElementOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 1;
|
||||
int init_method = 1;
|
||||
int nrepeat = 5;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// do nothing
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> d0_m_host_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
Tensor<DDataType> d1_m_host_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> d0_m_device_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
Tensor<DDataType> d1_m_device_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl;
|
||||
std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto d1_element_op = D1ElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmReduceInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
// warm up
|
||||
invoker.Run(argument);
|
||||
|
||||
// timing
|
||||
float total_time = 0;
|
||||
|
||||
for(int i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
// init DO, D1 to 0
|
||||
d0_device_buf.SetZero();
|
||||
d1_device_buf.SetZero();
|
||||
|
||||
KernelTimer timer;
|
||||
|
||||
timer.Start();
|
||||
|
||||
invoker.Run(argument);
|
||||
|
||||
timer.End();
|
||||
|
||||
total_time += timer.GetElapsedTime();
|
||||
}
|
||||
|
||||
float ave_time = total_time / nrepeat;
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
d0_device_buf.FromDevice(d0_m_device_result.mData.data());
|
||||
d1_device_buf.FromDevice(d1_m_device_result.mData.data());
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
auto d0_reduce_op = D0ReduceOp{};
|
||||
auto d1_reduce_op = D1ReduceOp{};
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
float d0_acc = d0_reduce_op.GetReductionZeroVal();
|
||||
float d1_acc = d1_reduce_op.GetReductionZeroVal();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
float d0_val = ck::type_convert<float>(c_m_n_host_result(m, n));
|
||||
float d1_val;
|
||||
|
||||
d1_element_op(d1_val, d0_val);
|
||||
d0_reduce_op(d0_acc, d0_val);
|
||||
d1_reduce_op(d1_acc, d1_val);
|
||||
}
|
||||
|
||||
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
|
||||
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
|
||||
}
|
||||
|
||||
check_error(c_m_n_host_result, c_m_n_device_result);
|
||||
check_error(d0_m_host_result, d0_m_device_result);
|
||||
check_error(d1_m_host_result, d1_m_device_result);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
268
example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp
Normal file
268
example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp
Normal file
@@ -0,0 +1,268 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_gemm_reduce_xdl_cshuffle.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using F64 = double;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
using GemmAccDataType = F32;
|
||||
using ReduceAccDataType = F32;
|
||||
using DDataType = F64;
|
||||
using DPtrsGlobal = ck::Tuple<DDataType*>;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using DsReduceOp = ck::Tuple<ck::reduce::Max>;
|
||||
using DsElementOp = ck::Tuple<ck::tensor_operation::element_wise::PassThrough>;
|
||||
using DGlobalMemOp =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
|
||||
|
||||
static constexpr auto GemmSpecialization =
|
||||
ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DsReduceOp, DsElementOp, DsElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
GemmAccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
// Prints a one-line performance summary for the fused GEMM + max-reduction kernel.
//
//   gemm_reduce_time : measured kernel time; printed with the unit "ms"
//   M, N, K          : GEMM problem size
//
// The operation count is 2 * M * N * K. The traffic estimate adds up the bytes of
// A (M x K), B (K x N), C (M x N) and one length-M reduction output of DDataType.
template <typename ADataType, typename BDataType, typename CDataType, typename DDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
{
    const std::size_t flop_count = std::size_t(2) * M * N * K;

    const std::size_t bytes_a = sizeof(ADataType) * M * K;
    const std::size_t bytes_b = sizeof(BDataType) * K * N;
    const std::size_t bytes_c = sizeof(CDataType) * M * N;
    const std::size_t bytes_d = sizeof(DDataType) * M;
    const std::size_t byte_count = bytes_a + bytes_b + bytes_c + bytes_d;

    // With the time in milliseconds, flop / 1e9 / ms is TFlops and byte / 1e6 / ms is GB/s.
    const float tflops     = static_cast<float>(flop_count) / 1.E9 / gemm_reduce_time;
    const float gb_per_sec = byte_count / 1.E6 / gemm_reduce_time;

    std::cout << "gemm + reduceMax Perf: " << gemm_reduce_time << " ms, " << tflops << " TFlops, "
              << gb_per_sec << " GB/s, " << std::endl;
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// do nothing
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> d_m_host_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> d_m_device_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
std::cout << "d_m: " << d_m_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem d_device_buf(sizeof(DDataType) * d_m_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto ds_element_op = DsElementOp{};
|
||||
auto p_ds_global = ck::make_tuple(static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()));
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmReduceInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
p_ds_global,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
ds_element_op,
|
||||
ds_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
// [CAUSION]: launch_and_time_kernel will not initialize D.
|
||||
// If we evaluate kernel multiple time but without initialize D. Verification will fail
|
||||
d_device_buf.SetValue(ck::NumericLimits<DDataType>::Lowest());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
d_device_buf.FromDevice(d_m_device_result.mData.data());
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
auto d_reduce_op = DsReduceOp{}[ck::Number<0>{}];
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
ReduceAccDataType curr_val =
|
||||
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
|
||||
d_reduce_op(d_acc, curr_val);
|
||||
};
|
||||
|
||||
d_m_host_result(m) = d_acc;
|
||||
}
|
||||
|
||||
pass = ck::utils::check_err(c_m_n_device_result.mData,
|
||||
c_m_n_host_result.mData,
|
||||
"Error: Incorrect results c") &&
|
||||
ck::utils::check_err(d_m_device_result.mData,
|
||||
d_m_host_result.mData,
|
||||
"Error: Incorrect results d",
|
||||
1e-3,
|
||||
1e-3);
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
float gemm_reduceMax_ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
|
||||
|
||||
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, DDataType>(
|
||||
gemm_reduceMax_ave_time, M, N, K);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
305
example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp
Normal file
305
example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp
Normal file
@@ -0,0 +1,305 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_gemm_reduce_xdl_cshuffle.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
#include "reduction_operator.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
using GemmAccDataType = F32;
|
||||
using ReduceAccDataType = F32;
|
||||
using DDataType = F32;
|
||||
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using D0ReduceOp = ck::reduce::Add;
|
||||
using D1ReduceOp = ck::reduce::Add;
|
||||
using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
|
||||
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
|
||||
using DGlobalMemOp =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
|
||||
ck::InMemoryDataOperationEnum::AtomicAdd>;
|
||||
|
||||
static constexpr auto GemmSpecialization =
|
||||
ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
GemmAccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
// Prints a one-line performance summary for the fused GEMM + mean + mean-square kernel.
//
//   gemm_reduce_time : measured kernel time; printed with the unit "ms"
//   M, N, K          : GEMM problem size
//
// The operation count is 2 * M * N * K. The traffic estimate adds up the bytes of
// A (M x K), B (K x N), C (M x N) and two length-M reduction outputs of DDataType.
template <typename ADataType, typename BDataType, typename CDataType, typename DDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
{
    const std::size_t flop_count = std::size_t(2) * M * N * K;

    const std::size_t bytes_a  = sizeof(ADataType) * M * K;
    const std::size_t bytes_b  = sizeof(BDataType) * K * N;
    const std::size_t bytes_c  = sizeof(CDataType) * M * N;
    const std::size_t bytes_d0 = sizeof(DDataType) * M;
    const std::size_t bytes_d1 = sizeof(DDataType) * M;
    const std::size_t byte_count = bytes_a + bytes_b + bytes_c + bytes_d0 + bytes_d1;

    // With the time in milliseconds, flop / 1e9 / ms is TFlops and byte / 1e6 / ms is GB/s.
    const float tflops     = static_cast<float>(flop_count) / 1.E9 / gemm_reduce_time;
    const float gb_per_sec = byte_count / 1.E6 / gemm_reduce_time;

    std::cout << "gemm + reduce_mean + reduce_mean_square Perf: " << gemm_reduce_time << " ms, "
              << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// do nothing
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> d0_m_host_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
Tensor<DDataType> d1_m_host_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> d0_m_device_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
Tensor<DDataType> d1_m_device_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl;
|
||||
std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
|
||||
|
||||
auto dxs_in_element_op = DxsInElementOps{};
|
||||
auto dxs_out_element_op = DxsOutElementOps{N, N};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmReduceInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
dxs_global,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
// init DO, D1 to 0
|
||||
d0_device_buf.SetZero();
|
||||
d1_device_buf.SetZero();
|
||||
|
||||
// if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
|
||||
// will not be correct. need to set time_kernel = false for correctness test
|
||||
invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
d0_device_buf.FromDevice(d0_m_device_result.mData.data());
|
||||
d1_device_buf.FromDevice(d1_m_device_result.mData.data());
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
auto d0_reduce_op = D0ReduceOp{};
|
||||
auto d1_reduce_op = D1ReduceOp{};
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
|
||||
ReduceAccDataType d0_val;
|
||||
ReduceAccDataType d1_val;
|
||||
|
||||
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
|
||||
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
|
||||
d0_reduce_op(d0_acc, d0_val);
|
||||
d1_reduce_op(d1_acc, d1_val);
|
||||
}
|
||||
|
||||
dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc);
|
||||
dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc);
|
||||
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
|
||||
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
|
||||
}
|
||||
|
||||
pass = ck::utils::check_err(c_m_n_device_result.mData,
|
||||
c_m_n_host_result.mData,
|
||||
"Error: Incorrect results c") &&
|
||||
ck::utils::check_err(d0_m_device_result.mData,
|
||||
d0_m_host_result.mData,
|
||||
"Error: Incorrect results d0",
|
||||
1e-4,
|
||||
1e-5) &&
|
||||
ck::utils::check_err(d1_m_device_result.mData,
|
||||
d1_m_host_result.mData,
|
||||
"Error: Incorrect results d1",
|
||||
1e-3,
|
||||
1e-5);
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
|
||||
|
||||
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, DDataType>(ave_time, M, N, K);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
@@ -1,2 +1,2 @@
|
||||
add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp)
|
||||
target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_fwd_util)
|
||||
target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_util)
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include <half.hpp>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "conv_fwd_util.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
@@ -87,7 +87,7 @@ void print_use_msg()
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
|
||||
<< "arg3: run kernel # of times (>1)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg4: N spatial dimensions (default 2)\n"
|
||||
<< "Following arguments (depending on number of spatial dims):\n"
|
||||
<< " N, K, C, \n"
|
||||
@@ -105,40 +105,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
|
||||
ck::utils::conv::ConvParams params;
|
||||
int arg_idx = 5;
|
||||
|
||||
params.num_dim_spatial = num_dim_spatial;
|
||||
params.N = std::stoi(argv[arg_idx++]);
|
||||
params.K = std::stoi(argv[arg_idx++]);
|
||||
params.C = std::stoi(argv[arg_idx++]);
|
||||
params.num_dim_spatial_ = num_dim_spatial;
|
||||
params.N_ = std::stoi(argv[arg_idx++]);
|
||||
params.K_ = std::stoi(argv[arg_idx++]);
|
||||
params.C_ = std::stoi(argv[arg_idx++]);
|
||||
|
||||
params.filter_spatial_lengths.resize(num_dim_spatial);
|
||||
params.filter_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_spatial_lengths.resize(num_dim_spatial);
|
||||
params.input_spatial_lengths_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_strides.resize(num_dim_spatial);
|
||||
params.conv_filter_strides_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.conv_filter_dilations.resize(num_dim_spatial);
|
||||
params.conv_filter_dilations_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_left_pads.resize(num_dim_spatial);
|
||||
params.input_left_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_left_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
params.input_right_pads.resize(num_dim_spatial);
|
||||
params.input_right_pads_.resize(num_dim_spatial);
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
params.input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
params.input_right_pads_[i] = std::stoi(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return params;
|
||||
@@ -165,25 +165,25 @@ DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial)
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 0;
|
||||
int init_method = 0;
|
||||
int nrepeat = 5;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
int num_dim_spatial = 2;
|
||||
|
||||
ck::utils::conv::ConvParams params;
|
||||
params.C = 128;
|
||||
params.C_ = 128;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc > 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
num_dim_spatial = std::stoi(argv[4]);
|
||||
// check args number
|
||||
int conv_args = 3 + num_dim_spatial * 6;
|
||||
@@ -202,21 +202,21 @@ int main(int argc, char* argv[])
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths),
|
||||
std::end(params.input_spatial_lengths));
|
||||
std::begin(params.input_spatial_lengths_),
|
||||
std::end(params.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K),
|
||||
static_cast<std::size_t>(params.C)};
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths),
|
||||
std::end(params.filter_spatial_lengths));
|
||||
std::begin(params.filter_spatial_lengths_),
|
||||
std::end(params.filter_spatial_lengths_));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N),
|
||||
static_cast<std::size_t>(params.K)};
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
@@ -263,16 +263,16 @@ int main(int argc, char* argv[])
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N,
|
||||
params.K,
|
||||
params.C,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
@@ -284,16 +284,16 @@ int main(int argc, char* argv[])
|
||||
"not support this Conv problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker->Run(argument.get(), nrepeat);
|
||||
float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = ck::utils::conv::get_flops(
|
||||
params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths);
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
std::size_t num_btype = ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
|
||||
params.N,
|
||||
params.C,
|
||||
params.K,
|
||||
params.input_spatial_lengths,
|
||||
params.filter_spatial_lengths,
|
||||
params.N_,
|
||||
params.C_,
|
||||
params.K_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
@@ -310,10 +310,10 @@ int main(int argc, char* argv[])
|
||||
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result,
|
||||
wei_k_c_y_x,
|
||||
out_n_k_ho_wo,
|
||||
params.conv_filter_strides,
|
||||
params.conv_filter_dilations,
|
||||
params.input_left_pads,
|
||||
params.input_right_pads,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
@@ -322,29 +322,30 @@ int main(int argc, char* argv[])
|
||||
|
||||
in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data());
|
||||
|
||||
check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result);
|
||||
return ck::utils::check_err(in_n_c_hi_wi_device_result.mData,
|
||||
in_n_c_hi_wi_host_result.mData)
|
||||
? 0
|
||||
: 1;
|
||||
};
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
auto ref_conv = ReferenceConvBwdDataInstance<3>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
case 2: {
|
||||
auto ref_conv = ReferenceConvBwdDataInstance<2>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
case 1: {
|
||||
auto ref_conv = ReferenceConvBwdDataInstance<1>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
@@ -24,10 +25,12 @@ using F32 = float;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
using DDataType = F32;
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
using ReduceAccDataType = F32;
|
||||
using DDataType = F32;
|
||||
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -36,20 +39,29 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using D0ReduceOp = ck::reduce::Add<float>;
|
||||
using D1ReduceOp = ck::reduce::Add<float>;
|
||||
using D1ElementOp = ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
|
||||
using D0ReduceOp = ck::reduce::Add;
|
||||
using D1ReduceOp = ck::reduce::Add;
|
||||
using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
|
||||
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
|
||||
|
||||
using DGlobalMemOp =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
|
||||
ck::InMemoryDataOperationEnum::AtomicAdd>;
|
||||
|
||||
static constexpr auto GemmSpecialization =
|
||||
ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, D1ElementOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
|
||||
@@ -57,18 +69,18 @@ using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = 1;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
int nrepeat = 5;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t M = 2048;
|
||||
ck::index_t N = 1920;
|
||||
ck::index_t K = 2048;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
ck::index_t StrideA = 2048;
|
||||
ck::index_t StrideB = 2048;
|
||||
ck::index_t StrideC = 1920;
|
||||
|
||||
ck::index_t BatchCount = 4;
|
||||
|
||||
@@ -80,13 +92,13 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 11)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
nrepeat = std::stoi(argv[3]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
@@ -96,13 +108,13 @@ int main(int argc, char* argv[])
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
|
||||
BatchCount = std::stoi(argv[9]);
|
||||
BatchCount = std::stoi(argv[10]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, BatchCount\n");
|
||||
exit(0);
|
||||
}
|
||||
@@ -169,12 +181,11 @@ int main(int argc, char* argv[])
|
||||
a_device_buf.ToDevice(a_g_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_g_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto d0_reduce_op = D0ReduceOp{};
|
||||
auto d1_reduce_op = D1ReduceOp{};
|
||||
auto d1_element_op = D1ElementOp{};
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
|
||||
|
||||
// do GEMM
|
||||
auto batched_gemm = DeviceBatchedGemmReduceInstance{};
|
||||
@@ -183,8 +194,7 @@ int main(int argc, char* argv[])
|
||||
batched_gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()),
|
||||
dxs_global,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
@@ -194,7 +204,8 @@ int main(int argc, char* argv[])
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op,
|
||||
DxsInElementOps{},
|
||||
DxsOutElementOps{},
|
||||
BatchCount);
|
||||
|
||||
if(!batched_gemm.IsSupportedArgument(argument))
|
||||
@@ -204,30 +215,13 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
// warm up
|
||||
invoker.Run(argument);
|
||||
// init DO, D1 to 0
|
||||
d0_device_buf.SetZero();
|
||||
d1_device_buf.SetZero();
|
||||
|
||||
// timing
|
||||
float total_time = 0;
|
||||
|
||||
for(int i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
// init DO, D1 to 0
|
||||
d0_device_buf.SetZero();
|
||||
d1_device_buf.SetZero();
|
||||
|
||||
KernelTimer timer;
|
||||
|
||||
timer.Start();
|
||||
|
||||
invoker.Run(argument);
|
||||
|
||||
timer.End();
|
||||
|
||||
total_time += timer.GetElapsedTime();
|
||||
}
|
||||
|
||||
float ave_time = total_time / nrepeat;
|
||||
// if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
|
||||
// will not be correct. need to set time_kernel = false for correctness test
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * BatchCount * M * N * K;
|
||||
std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K +
|
||||
@@ -241,6 +235,7 @@ int main(int argc, char* argv[])
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< batched_gemm.GetTypeString() << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
|
||||
@@ -255,19 +250,25 @@ int main(int argc, char* argv[])
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
auto d0_reduce_op = D0ReduceOp{};
|
||||
auto d1_reduce_op = D1ReduceOp{};
|
||||
|
||||
for(int batch = 0; batch < BatchCount; ++batch)
|
||||
{
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
float d0_acc = d0_reduce_op.GetReductionZeroVal();
|
||||
float d1_acc = d1_reduce_op.GetReductionZeroVal();
|
||||
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
float d0_val = ck::type_convert<float>(c_g_m_n_host_result(m, n));
|
||||
float d1_val;
|
||||
auto c_val =
|
||||
ck::type_convert<ReduceAccDataType>(c_g_m_n_host_result(batch, m, n));
|
||||
ReduceAccDataType d0_val;
|
||||
ReduceAccDataType d1_val;
|
||||
|
||||
d1_element_op(d1_val, d0_val);
|
||||
UnaryIdenticElementOp{}(d0_val, c_val);
|
||||
UnarySquareElementOp{}(d1_val, c_val);
|
||||
d0_reduce_op(d0_acc, d0_val);
|
||||
d1_reduce_op(d1_acc, d1_val);
|
||||
}
|
||||
@@ -277,10 +278,20 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
check_error(c_g_m_n_host_result, c_g_m_n_device_result);
|
||||
check_error(d0_g_m_host_result, d0_g_m_device_result);
|
||||
check_error(d1_g_m_host_result, d1_g_m_device_result);
|
||||
pass = ck::utils::check_err(c_g_m_n_host_result.mData,
|
||||
c_g_m_n_device_result.mData,
|
||||
"Error: Incorrect results c") &&
|
||||
ck::utils::check_err(d0_g_m_device_result.mData,
|
||||
d0_g_m_host_result.mData,
|
||||
"Error: Incorrect results! D0",
|
||||
1e-4,
|
||||
1e-5) &&
|
||||
ck::utils::check_err(d1_g_m_device_result.mData,
|
||||
d1_g_m_host_result.mData,
|
||||
"Error: Incorrect results! D1",
|
||||
1e-3,
|
||||
1e-5);
|
||||
}
|
||||
|
||||
return 0;
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
|
||||
4
example/19_binary_elementwise/CMakeLists.txt
Normal file
4
example/19_binary_elementwise/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
# Examples for the binary elementwise device operation, one executable per source file:
# two broadcast-add variants (2D and 3D) and two plain elementwise adds (1D and 4D).
# NOTE(review): add_example_executable is a project helper defined outside this file;
# presumably it takes <target name> <source file> - confirm against its definition.
add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp)
|
||||
add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp)
|
||||
add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp)
|
||||
add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp)
|
||||
164
example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp
Normal file
164
example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp
Normal file
@@ -0,0 +1,164 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
|
||||
#include "device_tensor.hpp"
|
||||
#include "binary_element_wise_operation.hpp"
|
||||
#include "device_binary_elementwise.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using ABDataType = F16;
|
||||
using CDataType = F16;
|
||||
using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
ABDataType,
|
||||
CDataType,
|
||||
EltwiseComputeDataType,
|
||||
Add,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
8,
|
||||
8>;
|
||||
|
||||
template <typename HostTensorA,
|
||||
typename HostTensorB,
|
||||
typename HostTensorC,
|
||||
typename ComputeDataType,
|
||||
typename Functor,
|
||||
int broadcastDim>
|
||||
void host_broadcast2D(
|
||||
HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, int N, Functor functor)
|
||||
{
|
||||
using ctype = ck::remove_reference_t<decltype(C(0, 0))>;
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
ComputeDataType Amn = ck::type_convert<ComputeDataType>(A(m, n));
|
||||
ComputeDataType Cmn = 0;
|
||||
if constexpr(broadcastDim == 0)
|
||||
{
|
||||
ComputeDataType Bn = ck::type_convert<ComputeDataType>(B(n));
|
||||
functor(Cmn, Amn, Bn);
|
||||
}
|
||||
else
|
||||
{
|
||||
ComputeDataType Bm = ck::type_convert<ComputeDataType>(B(m));
|
||||
functor(Cmn, Amn, Bm);
|
||||
}
|
||||
C(m, n) = ck::type_convert<ctype>(Cmn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t Stride = 1024;
|
||||
|
||||
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({len}),
|
||||
std::vector<std::size_t>({stride}));
|
||||
};
|
||||
|
||||
auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
};
|
||||
|
||||
Tensor<ABDataType> a_m_n(f_host_tensor_descriptor2d(M, N, Stride));
|
||||
Tensor<ABDataType> b_n(f_host_tensor_descriptor1d(N, 1));
|
||||
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, Stride));
|
||||
|
||||
a_m_n.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
|
||||
b_n.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a_m_n_device_buf(sizeof(ABDataType) * a_m_n.mDesc.GetElementSpace());
|
||||
DeviceMem b_n_device_buf(sizeof(ABDataType) * b_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
|
||||
|
||||
a_m_n_device_buf.ToDevice(a_m_n.mData.data());
|
||||
b_n_device_buf.ToDevice(b_n.mData.data());
|
||||
|
||||
auto broadcastAdd = DeviceElementwiseAddInstance{};
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(a_m_n_device_buf.GetDeviceBuffer(),
|
||||
b_n_device_buf.GetDeviceBuffer(),
|
||||
c_m_n_device_buf.GetDeviceBuffer(),
|
||||
{M, N},
|
||||
{Stride, 1},
|
||||
{0, 1}, // broadcast in first dimension
|
||||
{Stride, 1},
|
||||
Add{});
|
||||
|
||||
if(!broadcastAdd.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error("The runtime parameters seems not supported by the "
|
||||
"DeviceBinaryElementwise instance, exiting!");
|
||||
};
|
||||
|
||||
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
|
||||
float ave_time =
|
||||
broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
if(do_verification)
|
||||
{
|
||||
c_m_n_device_buf.FromDevice(c_m_n.mData.data());
|
||||
Tensor<CDataType> host_c_m_n(f_host_tensor_descriptor2d(M, N, Stride));
|
||||
|
||||
host_broadcast2D<Tensor<ABDataType>,
|
||||
Tensor<ABDataType>,
|
||||
Tensor<CDataType>,
|
||||
EltwiseComputeDataType,
|
||||
Add,
|
||||
0>(host_c_m_n, a_m_n, b_n, M, N, Add{});
|
||||
|
||||
pass &= ck::utils::check_err(
|
||||
c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results c", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
123
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
Normal file
123
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
Normal file
@@ -0,0 +1,123 @@
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
|
||||
#include "device_tensor.hpp"
|
||||
#include "binary_element_wise_operation.hpp"
|
||||
#include "device_binary_elementwise.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using ABDataType = F16;
|
||||
using CDataType = F16;
|
||||
using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
ABDataType,
|
||||
CDataType,
|
||||
EltwiseComputeDataType,
|
||||
Add,
|
||||
3,
|
||||
8,
|
||||
1,
|
||||
8,
|
||||
8>;
|
||||
|
||||
template <typename HostTensorA,
|
||||
typename HostTensorB,
|
||||
typename HostTensorC,
|
||||
typename ComputeDataType,
|
||||
typename Functor>
|
||||
void host_broadcast3D_am_bmnk(HostTensorC& C,
|
||||
const HostTensorA& A,
|
||||
const HostTensorB& B,
|
||||
const std::vector<std::size_t>& shape,
|
||||
Functor functor)
|
||||
{
|
||||
using ctype = ck::remove_reference_t<decltype(C(0, 0))>;
|
||||
|
||||
for(std::size_t m = 0; m < shape[0]; ++m)
|
||||
for(std::size_t n = 0; n < shape[1]; ++n)
|
||||
for(std::size_t k = 0; k < shape[2]; ++k)
|
||||
{
|
||||
ComputeDataType a_val = ck::type_convert<ComputeDataType>(A(m));
|
||||
ComputeDataType b_val = ck::type_convert<ComputeDataType>(B(m, n, k));
|
||||
ComputeDataType c_val = 0;
|
||||
functor(c_val, a_val, b_val);
|
||||
C(m, n, k) = ck::type_convert<ctype>(c_val);
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
std::vector<std::size_t> mnk = {4, 16, 32};
|
||||
ck::index_t M = mnk[0];
|
||||
|
||||
Tensor<ABDataType> a_m({M});
|
||||
Tensor<ABDataType> b_m_n_k(mnk);
|
||||
Tensor<CDataType> c_m_n_k(mnk);
|
||||
|
||||
a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
|
||||
b_m_n_k.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpace());
|
||||
|
||||
a_m_device_buf.ToDevice(a_m.mData.data());
|
||||
b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data());
|
||||
|
||||
auto broadcastAdd = DeviceElementwiseAddInstance{};
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(
|
||||
a_m_device_buf.GetDeviceBuffer(),
|
||||
b_m_n_k_device_buf.GetDeviceBuffer(),
|
||||
c_m_n_k_device_buf.GetDeviceBuffer(),
|
||||
std::vector<ck::index_t>{mnk.begin(), mnk.end()},
|
||||
{1, 0, 0}, // broadcast A on second and third dimension
|
||||
std::vector<ck::index_t>{b_m_n_k.mDesc.GetStrides().begin(),
|
||||
b_m_n_k.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{c_m_n_k.mDesc.GetStrides().begin(),
|
||||
c_m_n_k.mDesc.GetStrides().end()},
|
||||
Add{});
|
||||
|
||||
if(!broadcastAdd.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error("The runtime parameters seems not supported by the "
|
||||
"DeviceBinaryElementwise instance, exiting!");
|
||||
};
|
||||
|
||||
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
|
||||
float ave_time =
|
||||
broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
if(do_verification)
|
||||
{
|
||||
c_m_n_k_device_buf.FromDevice(c_m_n_k.mData.data());
|
||||
Tensor<CDataType> host_c_m_n_k(mnk);
|
||||
|
||||
host_broadcast3D_am_bmnk<Tensor<ABDataType>,
|
||||
Tensor<ABDataType>,
|
||||
Tensor<CDataType>,
|
||||
EltwiseComputeDataType,
|
||||
Add>(host_c_m_n_k, a_m, b_m_n_k, mnk, Add{});
|
||||
|
||||
pass &= ck::utils::check_err(
|
||||
c_m_n_k.mData, host_c_m_n_k.mData, "Error: Incorrect results c", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
144
example/19_binary_elementwise/elementwise_add_1d.cpp
Normal file
144
example/19_binary_elementwise/elementwise_add_1d.cpp
Normal file
@@ -0,0 +1,144 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
|
||||
#include "device_tensor.hpp"
|
||||
#include "binary_element_wise_operation.hpp"
|
||||
#include "device_binary_elementwise.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using ABDataType = F16;
|
||||
using CDataType = F16;
|
||||
using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
ABDataType,
|
||||
CDataType,
|
||||
EltwiseComputeDataType,
|
||||
Add,
|
||||
1,
|
||||
8,
|
||||
8,
|
||||
8,
|
||||
8>;
|
||||
|
||||
template <typename HostTensorA,
|
||||
typename HostTensorB,
|
||||
typename HostTensorC,
|
||||
typename ComputeDataType,
|
||||
typename Functor>
|
||||
void host_elementwise1D(
|
||||
HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, Functor functor)
|
||||
{
|
||||
using ctype = ck::remove_reference_t<decltype(C(0))>;
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
ComputeDataType Am = ck::type_convert<ComputeDataType>(A(m));
|
||||
ComputeDataType Bm = ck::type_convert<ComputeDataType>(B(m));
|
||||
ComputeDataType Cm = 0;
|
||||
functor(Cm, Am, Bm);
|
||||
C(m) = ck::type_convert<ctype>(Cm);
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::index_t M = 1024;
|
||||
|
||||
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({len}),
|
||||
std::vector<std::size_t>({stride}));
|
||||
};
|
||||
|
||||
Tensor<ABDataType> a_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<ABDataType> b_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<CDataType> c_m(f_host_tensor_descriptor1d(M, 1));
|
||||
|
||||
a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
|
||||
b_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace());
|
||||
DeviceMem b_m_device_buf(sizeof(ABDataType) * b_m.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpace());
|
||||
|
||||
a_m_device_buf.ToDevice(a_m.mData.data());
|
||||
b_m_device_buf.ToDevice(b_m.mData.data());
|
||||
|
||||
auto broadcastAdd = DeviceElementwiseAddInstance{};
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(),
|
||||
b_m_device_buf.GetDeviceBuffer(),
|
||||
c_m_device_buf.GetDeviceBuffer(),
|
||||
{M},
|
||||
{1},
|
||||
{1},
|
||||
{1},
|
||||
Add{});
|
||||
|
||||
if(!broadcastAdd.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error("The runtime parameters seems not supported by the "
|
||||
"DeviceBinaryElementwise instance, exiting!");
|
||||
};
|
||||
|
||||
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
|
||||
float ave_time =
|
||||
broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
if(do_verification)
|
||||
{
|
||||
c_m_device_buf.FromDevice(c_m.mData.data());
|
||||
Tensor<CDataType> host_c_m(f_host_tensor_descriptor1d(M, 1));
|
||||
|
||||
host_elementwise1D<Tensor<ABDataType>,
|
||||
Tensor<ABDataType>,
|
||||
Tensor<CDataType>,
|
||||
EltwiseComputeDataType,
|
||||
Add>(host_c_m, a_m, b_m, M, Add{});
|
||||
|
||||
pass &= ck::utils::check_err(
|
||||
c_m.mData, host_c_m.mData, "Error: Incorrect results c", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
146
example/19_binary_elementwise/elementwise_add_4d.cpp
Normal file
146
example/19_binary_elementwise/elementwise_add_4d.cpp
Normal file
@@ -0,0 +1,146 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2020 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
|
||||
#include "device_tensor.hpp"
|
||||
#include "binary_element_wise_operation.hpp"
|
||||
#include "device_binary_elementwise.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using ABDataType = F16;
|
||||
using CDataType = F16;
|
||||
using EltwiseComputeDataType = F32;
|
||||
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using DeviceElementwiseAddInstance =
|
||||
ck::tensor_operation::device::DeviceBinaryElementwise<ABDataType,
|
||||
ABDataType,
|
||||
CDataType,
|
||||
EltwiseComputeDataType,
|
||||
Add,
|
||||
4,
|
||||
8,
|
||||
8,
|
||||
8,
|
||||
8>;
|
||||
|
||||
template <typename HostTensorA,
|
||||
typename HostTensorB,
|
||||
typename HostTensorC,
|
||||
typename ComputeDataType,
|
||||
typename Functor>
|
||||
void host_elementwise4D(HostTensorC& C,
|
||||
const HostTensorA& A,
|
||||
const HostTensorB& B,
|
||||
const std::vector<std::size_t>& shape,
|
||||
Functor functor)
|
||||
{
|
||||
using ctype = ck::remove_reference_t<decltype(C(0, 0, 0, 0))>;
|
||||
|
||||
for(std::size_t n = 0; n < shape[0]; ++n)
|
||||
for(std::size_t c = 0; c < shape[1]; ++c)
|
||||
for(std::size_t h = 0; h < shape[2]; ++h)
|
||||
for(std::size_t w = 0; w < shape[3]; ++w)
|
||||
{
|
||||
ComputeDataType a_val = ck::type_convert<ComputeDataType>(A(n, c, h, w));
|
||||
ComputeDataType b_val = ck::type_convert<ComputeDataType>(B(n, c, h, w));
|
||||
ComputeDataType c_val = 0;
|
||||
functor(c_val, a_val, b_val);
|
||||
C(n, c, h, w) = ck::type_convert<ctype>(c_val);
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
std::vector<std::size_t> nchw = {4, 16, 32, 32};
|
||||
|
||||
Tensor<ABDataType> a(nchw);
|
||||
Tensor<ABDataType> b(nchw);
|
||||
Tensor<CDataType> c(nchw);
|
||||
|
||||
a.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
|
||||
b.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ABDataType) * a.mDesc.GetElementSpace());
|
||||
DeviceMem b_device_buf(sizeof(ABDataType) * b.mDesc.GetElementSpace());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c.mDesc.GetElementSpace());
|
||||
|
||||
a_device_buf.ToDevice(a.mData.data());
|
||||
b_device_buf.ToDevice(b.mData.data());
|
||||
|
||||
auto broadcastAdd = DeviceElementwiseAddInstance{};
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(
|
||||
a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
std::vector<ck::index_t>{nchw.begin(), nchw.end()},
|
||||
std::vector<ck::index_t>{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()},
|
||||
Add{});
|
||||
|
||||
if(!broadcastAdd.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error("The runtime parameters seems not supported by the "
|
||||
"DeviceBinaryElementwise instance, exiting!");
|
||||
};
|
||||
|
||||
auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
|
||||
float ave_time =
|
||||
broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c.mData.data());
|
||||
Tensor<CDataType> host_c(nchw);
|
||||
|
||||
host_elementwise4D<Tensor<ABDataType>,
|
||||
Tensor<ABDataType>,
|
||||
Tensor<CDataType>,
|
||||
EltwiseComputeDataType,
|
||||
Add>(host_c, a, b, nchw, Add{});
|
||||
|
||||
pass &=
|
||||
ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results c", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
4
example/20_convnd_bwd_weight_xdl/CMakeLists.txt
Normal file
4
example/20_convnd_bwd_weight_xdl/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
# Register the convnd backward-weight examples. Each executable is named
# example_<name>, is built from <name>.cpp and links against conv_util.
foreach(example_name IN ITEMS convnd_bwd_weight_xdl convnd_bwd_weight_xdl_bf16_splitk)
    add_example_executable(example_${example_name} ${example_name}.cpp)
    target_link_libraries(example_${example_name} PRIVATE conv_util)
endforeach()
|
||||
385
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
Normal file
385
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
Normal file
@@ -0,0 +1,385 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
|
||||
#include "reference_conv_backward_weight.hpp"
|
||||
|
||||
using InDataType = ck::half_t;
|
||||
using WeiDataType = ck::half_t;
|
||||
using OutDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
using DeviceConvBwdWeightBasePtr =
|
||||
ck::tensor_operation::device::DeviceConvBwdWeightPtr<InElementOp, WeiElementOp, OutElementOp>;
|
||||
|
||||
// clang-format off
|
||||
template <ck::index_t NumDimSpatial>
|
||||
using DeviceConvndBwdWeightInstance = ck::tensor_operation::device::
|
||||
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
AccDataType, // AccDataType
|
||||
InElementOp, // InElementwiseOperation
|
||||
WeiElementOp, // WeiElementwiseOperation
|
||||
OutElementOp, // OutElementwiseOperation
|
||||
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
|
||||
NumDimSpatial, // NumDimSpatial
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
2, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
2, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimSpatial>
|
||||
using ReferenceConvBwdWeightInstance =
|
||||
ck::tensor_operation::host::ReferenceConvBwdWeight<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
NumDimSpatial>;
|
||||
|
||||
/// Prints the positional command-line arguments accepted by this example to
/// standard output. Called by main() when the argument count is not one of
/// the accepted forms.
void print_use_msg()
{
    // arg3 previously read "0=n0"; corrected to "0=no" to match arg1/arg4.
    std::cout << "arg1: verification (0=no, 1=yes)\n"
              << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
              << "arg3: time kernel (0=no, 1=yes)\n"
              << "arg4: is show log (0=no, 1=yes)\n"
              << "arg5: split-k \n"
              << "arg6: N spatial dimensions (default 2)\n"
              << "Following arguments (depending on number of spatial dims):\n"
              << " N, K, C, \n"
              << " <filter spatial dimensions>, (ie Y, X for 2D)\n"
              << " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
              << " <strides>, (ie Sy, Sx for 2D)\n"
              << " <dilations>, (ie Dy, Dx for 2D)\n"
              << " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
              << " <right padding>, (ie RightPy, RightPx for 2D)\n"
              << std::endl;
}
|
||||
|
||||
/// Builds a ConvParams from the command line.
///
/// Reading starts at argv[7] and consumes, in order: N, K, C, then
/// `num_dim_spatial` values each for filter lengths, input lengths, strides,
/// dilations, left pads and right pads (3 + 6 * num_dim_spatial values total).
/// The caller is responsible for having checked that argv is long enough.
ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
{
    ck::utils::conv::ConvParams conv_params;
    int next_arg = 7;

    // Consumes the next argument as an int.
    const auto read_int = [&]() { return std::stoi(argv[next_arg++]); };

    // Resizes `values` to one entry per spatial dimension and fills it from
    // the next arguments.
    const auto read_spatial = [&](auto& values) {
        values.resize(num_dim_spatial);
        for(auto& value : values)
        {
            value = read_int();
        }
    };

    conv_params.num_dim_spatial_ = num_dim_spatial;
    conv_params.N_               = read_int();
    conv_params.K_               = read_int();
    conv_params.C_               = read_int();

    read_spatial(conv_params.filter_spatial_lengths_);
    read_spatial(conv_params.input_spatial_lengths_);
    read_spatial(conv_params.conv_filter_strides_);
    read_spatial(conv_params.conv_filter_dilations_);
    read_spatial(conv_params.input_left_pads_);
    read_spatial(conv_params.input_right_pads_);

    return conv_params;
}
|
||||
|
||||
/// Creates the device backward-weight convolution instance for 1, 2 or 3
/// spatial dimensions.
///
/// \throws std::runtime_error for any other value of `num_dim_spatial`.
DeviceConvBwdWeightBasePtr get_conv_instance(int num_dim_spatial)
{
    if(num_dim_spatial == 1)
    {
        return std::make_unique<DeviceConvndBwdWeightInstance<1>>();
    }
    if(num_dim_spatial == 2)
    {
        return std::make_unique<DeviceConvndBwdWeightInstance<2>>();
    }
    if(num_dim_spatial == 3)
    {
        return std::make_unique<DeviceConvndBwdWeightInstance<3>>();
    }

    throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
int num_dim_spatial = 2;
|
||||
int do_log = 0;
|
||||
int split_k = 1;
|
||||
|
||||
ck::utils::conv::ConvParams params;
|
||||
params.C_ = 128;
|
||||
|
||||
if(argc == 6)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
do_log = std::stoi(argv[4]);
|
||||
split_k = std::stoi(argv[5]);
|
||||
}
|
||||
else if(argc > 6)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
do_log = std::stoi(argv[4]);
|
||||
split_k = std::stoi(argv[5]);
|
||||
num_dim_spatial = std::stoi(argv[6]);
|
||||
// check args number
|
||||
int conv_args = 3 + num_dim_spatial * 6;
|
||||
int cmdline_nargs = conv_args + 7;
|
||||
if(cmdline_nargs != argc)
|
||||
{
|
||||
print_use_msg();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
params = parse_conv_params(num_dim_spatial, argv);
|
||||
}
|
||||
else if(argc != 1)
|
||||
{
|
||||
print_use_msg();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths_),
|
||||
std::end(params.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths_),
|
||||
std::end(params.filter_spatial_lengths_));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
|
||||
Tensor<InDataType> in_n_c_hi_wi(
|
||||
ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
|
||||
Tensor<WeiDataType> wei_k_c_y_x_host_result(
|
||||
ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
|
||||
Tensor<WeiDataType> wei_k_c_y_x_device_result(
|
||||
ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
|
||||
Tensor<OutDataType> out_n_k_ho_wo(
|
||||
ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
|
||||
|
||||
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
|
||||
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_device_result.mDesc << std::endl;
|
||||
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
|
||||
|
||||
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
|
||||
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl;
|
||||
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) *
|
||||
wei_k_c_y_x_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
// reset input to zero
|
||||
wei_device_buf.SetZero();
|
||||
|
||||
// do GEMM
|
||||
auto conv = get_conv_instance(num_dim_spatial);
|
||||
auto invoker = conv->MakeInvokerPointer();
|
||||
auto argument =
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{},
|
||||
split_k);
|
||||
|
||||
// alloc work space
|
||||
float ave_time = 0.f;
|
||||
if(!conv->IsSupportedArgument(argument.get()))
|
||||
{
|
||||
std::cout << "wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem"
|
||||
<< std::endl;
|
||||
return 1;
|
||||
}
|
||||
ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = ck::utils::conv::get_flops(
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
std::size_t num_btype = ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
|
||||
params.N_,
|
||||
params.C_,
|
||||
params.K_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto verify_f = [&](const auto& ref_conv) {
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
|
||||
wei_k_c_y_x_host_result,
|
||||
out_n_k_ho_wo,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data());
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return ck::utils::check_err(wei_k_c_y_x_device_result.mData,
|
||||
wei_k_c_y_x_host_result.mData)
|
||||
? 0
|
||||
: 1;
|
||||
};
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
auto ref_conv = ReferenceConvBwdWeightInstance<3>();
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
case 2: {
|
||||
auto ref_conv = ReferenceConvBwdWeightInstance<2>();
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
case 1: {
|
||||
auto ref_conv = ReferenceConvBwdWeightInstance<1>();
|
||||
return verify_f(ref_conv);
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,427 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_unary_elementwise.hpp"
|
||||
#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
|
||||
#include "reference_conv_backward_weight.hpp"
|
||||
|
||||
using InDataType = ck::bhalf_t;
|
||||
using WeiDataType = ck::bhalf_t;
|
||||
using OutDataType = ck::bhalf_t;
|
||||
using AccDataType = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using UnaryTypeConvert = ck::tensor_operation::element_wise::UnaryTypeConvert<ck::bhalf_t, float>;
|
||||
|
||||
using DeviceUnaryElementwiseTypeConvertInstance = ck::tensor_operation::device::
|
||||
DeviceUnaryElementwise<AccDataType, WeiDataType, UnaryTypeConvert, 1, 4>;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
using DeviceConvBwdWeightBasePtr =
|
||||
ck::tensor_operation::device::DeviceConvBwdWeightPtr<InElementOp, WeiElementOp, OutElementOp>;
|
||||
|
||||
// clang-format off
|
||||
template <ck::index_t NumDimSpatial>
|
||||
using DeviceConvndBwdWeightInstance_bf16_splitk = ck::tensor_operation::device::
|
||||
DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
|
||||
InDataType, // InDataType
|
||||
AccDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
AccDataType, // AccDataType
|
||||
InElementOp, // InElementwiseOperation
|
||||
WeiElementOp, // WeiElementwiseOperation
|
||||
OutElementOp, // OutElementwiseOperation
|
||||
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
|
||||
NumDimSpatial, // NumDimSpatial
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
8, // K1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
2, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
2, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4>; // CBlockTransferScalarPerVector_NWaveNPerXdl
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimSpatial>
|
||||
using ReferenceConvBwdWeightInstance =
|
||||
ck::tensor_operation::host::ReferenceConvBwdWeight<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
NumDimSpatial>;
|
||||
|
||||
/// Applies a unary functor to every element of a host tensor's flat storage.
///
/// The element count is the product of all entries of `shape`; for each flat
/// index n below that count, `B.mData[n] = functor(A.mData[n])`.
///
/// \param B       destination; must expose an indexable `mData` with at least
///                that many elements
/// \param A       source; same requirement as B
/// \param shape   tensor extents; an extent of 0 makes the call a no-op
/// \param functor callable taking one source element and returning the value
///                to store
template <typename HostTensorB, typename HostTensorA, typename Functor>
void host_elementwise(HostTensorB& B,
                      const HostTensorA& A,
                      const std::vector<std::size_t>& shape,
                      Functor functor)
{
    // Accumulate in std::size_t: an `int` seed and std::multiplies<int> would
    // truncate the extents and overflow for large tensors.
    const std::size_t tensor_size = std::accumulate(
        shape.begin(), shape.end(), std::size_t{1}, std::multiplies<std::size_t>{});

    // The former __LINE__ debug print was removed: it read A.mData[0] even
    // when the tensor had no elements.
    for(std::size_t n = 0; n < tensor_size; ++n)
    {
        B.mData[n] = functor(A.mData[n]);
    }
}
|
||||
|
||||
/// Prints the positional command-line arguments accepted by the bf16 split-k
/// example to standard output. Called by main() when the arguments are not
/// acceptable (including split-k <= 1).
void print_use_msg()
{
    // arg3 previously read "0=n0"; corrected to "0=no" to match arg1/arg4.
    std::cout << "arg1: verification (0=no, 1=yes)\n"
              << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n"
              << "arg3: time kernel (0=no, 1=yes)\n"
              << "arg4: is show log (0=no, 1=yes)\n"
              << "arg5: split-k : in this example split-k must be larger than 1\n"
              << "arg6: N spatial dimensions (default 2)\n"
              << "Following arguments (depending on number of spatial dims):\n"
              << " N, K, C, \n"
              << " <filter spatial dimensions>, (ie Y, X for 2D)\n"
              << " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
              << " <strides>, (ie Sy, Sx for 2D)\n"
              << " <dilations>, (ie Dy, Dx for 2D)\n"
              << " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
              << " <right padding>, (ie RightPy, RightPx for 2D)\n"
              << std::endl;
}
|
||||
|
||||
/// Builds a ConvParams from the command line.
///
/// Reading starts at argv[7] and consumes, in order: N, K, C, then
/// `num_dim_spatial` values each for filter lengths, input lengths, strides,
/// dilations, left pads and right pads (3 + 6 * num_dim_spatial values total).
/// The caller is responsible for having checked that argv is long enough.
ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[])
{
    ck::utils::conv::ConvParams conv_params;
    int cursor = 7;

    // Consumes the next argument as an int.
    const auto next_int = [&]() { return std::stoi(argv[cursor++]); };

    // Resizes `dst` to one entry per spatial dimension and fills it from the
    // next arguments.
    const auto fill_per_dim = [&](auto& dst) {
        dst.resize(num_dim_spatial);
        for(auto& entry : dst)
        {
            entry = next_int();
        }
    };

    conv_params.num_dim_spatial_ = num_dim_spatial;
    conv_params.N_               = next_int();
    conv_params.K_               = next_int();
    conv_params.C_               = next_int();

    fill_per_dim(conv_params.filter_spatial_lengths_);
    fill_per_dim(conv_params.input_spatial_lengths_);
    fill_per_dim(conv_params.conv_filter_strides_);
    fill_per_dim(conv_params.conv_filter_dilations_);
    fill_per_dim(conv_params.input_left_pads_);
    fill_per_dim(conv_params.input_right_pads_);

    return conv_params;
}
|
||||
|
||||
/// Creates the bf16 split-k device backward-weight convolution instance for
/// 1, 2 or 3 spatial dimensions.
///
/// \throws std::runtime_error for any other value of `num_dim_spatial`.
DeviceConvBwdWeightBasePtr get_conv_instance(int num_dim_spatial)
{
    if(num_dim_spatial == 1)
    {
        return std::make_unique<DeviceConvndBwdWeightInstance_bf16_splitk<1>>();
    }
    if(num_dim_spatial == 2)
    {
        return std::make_unique<DeviceConvndBwdWeightInstance_bf16_splitk<2>>();
    }
    if(num_dim_spatial == 3)
    {
        return std::make_unique<DeviceConvndBwdWeightInstance_bf16_splitk<3>>();
    }

    throw std::runtime_error("Unsupported number of spatial dimensions provided!");
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
int num_dim_spatial = 2;
|
||||
int do_log = 0;
|
||||
int split_k = 2;
|
||||
|
||||
ck::utils::conv::ConvParams params;
|
||||
params.C_ = 128;
|
||||
|
||||
if(argc == 6)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
do_log = std::stoi(argv[4]);
|
||||
split_k = std::stoi(argv[5]);
|
||||
}
|
||||
else if(argc > 6)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
do_log = std::stoi(argv[4]);
|
||||
split_k = std::stoi(argv[5]);
|
||||
num_dim_spatial = std::stoi(argv[6]);
|
||||
// check args number
|
||||
int conv_args = 3 + num_dim_spatial * 6;
|
||||
int cmdline_nargs = conv_args + 7;
|
||||
if(cmdline_nargs != argc)
|
||||
{
|
||||
print_use_msg();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
params = parse_conv_params(num_dim_spatial, argv);
|
||||
}
|
||||
else if(argc != 1)
|
||||
{
|
||||
print_use_msg();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if(split_k <= 1)
|
||||
{
|
||||
print_use_msg();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::vector<std::size_t> input_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
input_dims.insert(std::end(input_dims),
|
||||
std::begin(params.input_spatial_lengths_),
|
||||
std::end(params.input_spatial_lengths_));
|
||||
|
||||
std::vector<std::size_t> filter_dims{static_cast<std::size_t>(params.K_),
|
||||
static_cast<std::size_t>(params.C_)};
|
||||
filter_dims.insert(std::end(filter_dims),
|
||||
std::begin(params.filter_spatial_lengths_),
|
||||
std::end(params.filter_spatial_lengths_));
|
||||
|
||||
const std::vector<ck::index_t>& output_spatial_lengths = params.GetOutputSpatialLengths();
|
||||
std::vector<std::size_t> output_dims{static_cast<std::size_t>(params.N_),
|
||||
static_cast<std::size_t>(params.K_)};
|
||||
output_dims.insert(std::end(output_dims),
|
||||
std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths));
|
||||
|
||||
Tensor<InDataType> in_n_c_hi_wi(
|
||||
ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial));
|
||||
Tensor<WeiDataType> wei_k_c_y_x_host_result(
|
||||
ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
|
||||
Tensor<WeiDataType> wei_k_c_y_x_device_result(
|
||||
ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial));
|
||||
Tensor<OutDataType> out_n_k_ho_wo(
|
||||
ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial));
|
||||
|
||||
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
|
||||
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_device_result.mDesc << std::endl;
|
||||
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
|
||||
|
||||
std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl;
|
||||
std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl;
|
||||
std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-2, 2});
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2<InDataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1});
|
||||
in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1<InDataType>{1});
|
||||
}
|
||||
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) *
|
||||
wei_k_c_y_x_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
out_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
// reset input to zero
|
||||
wei_device_buf.SetZero();
|
||||
|
||||
// do GEMM
|
||||
auto conv = get_conv_instance(num_dim_spatial);
|
||||
auto invoker = conv->MakeInvokerPointer();
|
||||
auto argument =
|
||||
conv->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
params.N_,
|
||||
params.K_,
|
||||
params.C_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{},
|
||||
split_k);
|
||||
|
||||
// alloc work space
|
||||
size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get());
|
||||
if(bwd_weight_workspace_size <= 0)
|
||||
{
|
||||
print_use_msg();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
float conv_ave_time = 0.f;
|
||||
|
||||
DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size);
|
||||
wei_work_space_device_buf.SetZero();
|
||||
conv->SetWorkSpacePointer(argument.get(), wei_work_space_device_buf.GetDeviceBuffer());
|
||||
|
||||
if(!conv->IsSupportedArgument(argument.get()))
|
||||
{
|
||||
std::cout << "wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem"
|
||||
<< std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
conv_ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = ck::utils::conv::get_flops(
|
||||
params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths);
|
||||
std::size_t num_btype = ck::utils::conv::get_btype<InDataType, WeiDataType, OutDataType>(
|
||||
params.N_,
|
||||
params.C_,
|
||||
params.K_,
|
||||
params.input_spatial_lengths_,
|
||||
params.filter_spatial_lengths_,
|
||||
output_spatial_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / conv_ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / conv_ave_time;
|
||||
|
||||
std::cout << "Perf: conv: " << conv_ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s" << std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto verify_f = [&](const auto& ref_conv) {
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi,
|
||||
wei_k_c_y_x_host_result,
|
||||
out_n_k_ho_wo,
|
||||
params.conv_filter_strides_,
|
||||
params.conv_filter_dilations_,
|
||||
params.input_left_pads_,
|
||||
params.input_right_pads_,
|
||||
InElementOp{},
|
||||
WeiElementOp{},
|
||||
OutElementOp{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data());
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return ck::utils::check_err(wei_k_c_y_x_device_result.mData,
|
||||
wei_k_c_y_x_host_result.mData)
|
||||
? 0
|
||||
: 1;
|
||||
};
|
||||
|
||||
switch(num_dim_spatial)
|
||||
{
|
||||
case 3: {
|
||||
auto ref_conv = ReferenceConvBwdWeightInstance<3>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
auto ref_conv = ReferenceConvBwdWeightInstance<2>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
}
|
||||
case 1: {
|
||||
auto ref_conv = ReferenceConvBwdWeightInstance<1>();
|
||||
verify_f(ref_conv);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
2
example/21_gemm_layernorm/CMakeLists.txt
Normal file
2
example/21_gemm_layernorm/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# Example target built from gemm_bias_relu_add_layernorm_xdl_fp16.cpp.
# NOTE(review): add_example_executable is a project helper defined outside this file;
# presumably it creates the executable target - confirm against its definition.
add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp)
|
||||
# Example target built from gemm_layernorm_xdl_fp16.cpp.
# NOTE(review): add_example_executable is a project helper defined outside this file;
# presumably it creates the executable target - confirm against its definition.
add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp)
|
||||
@@ -0,0 +1,425 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_5ary_elementwise.hpp"
|
||||
#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
using C0DataType = F32;
|
||||
using C1DataType = F16;
|
||||
using GemmAccDataType = F32;
|
||||
using ReduceAccDataType = F32;
|
||||
using DDataType = F32;
|
||||
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
|
||||
using GammaDataType = F16;
|
||||
using BetaDataType = F16;
|
||||
using LayerNormOutDataType = F16;
|
||||
using NormalizeComputeDataType = F32;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::Relu;
|
||||
using C1ElementOp = PassThrough;
|
||||
using ReduceSumOp = ck::reduce::Add;
|
||||
using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>;
|
||||
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
|
||||
using DxsGlobalMemOp =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
|
||||
ck::InMemoryDataOperationEnum::AtomicAdd>;
|
||||
|
||||
static constexpr auto GemmSpecialization =
|
||||
ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, C1ElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
GemmAccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
|
||||
|
||||
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
|
||||
using DeviceNormalizeInstance =
|
||||
ck::tensor_operation::device::Device5AryElementwise<CDataType,
|
||||
DDataType,
|
||||
DDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
LayerNormOutDataType,
|
||||
NormalizeComputeDataType,
|
||||
NormalizeFunctor,
|
||||
2,
|
||||
8,
|
||||
8, // scalarPerVector: gemm_out
|
||||
1, // scalarPerVector: reduce_mean
|
||||
1, // scalarPerVector: reduce_mean_square
|
||||
8, // scalarPerVector: Gamma
|
||||
8, // scalarPerVector: Beta
|
||||
8>; // scalarPerVector: LayerNorm_out
|
||||
|
||||
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({len}),
|
||||
std::vector<std::size_t>({stride}));
|
||||
};
|
||||
|
||||
auto f_host_tensor_descriptor2d =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename CDataType,
|
||||
typename DDataType,
|
||||
typename AccDataType,
|
||||
typename C0DataType,
|
||||
typename C1DataType,
|
||||
typename A_functor,
|
||||
typename B_functor,
|
||||
typename C_functor,
|
||||
typename C1_functor>
|
||||
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
const Tensor<ADataType>& a_m_k,
|
||||
const Tensor<ADataType>& b_k_n,
|
||||
const Tensor<C0DataType>& bias_n,
|
||||
const Tensor<C1DataType>& c1_m_n,
|
||||
const Tensor<GammaDataType>& gamma_n,
|
||||
const Tensor<GammaDataType>& beta_n,
|
||||
A_functor a_element_op,
|
||||
B_functor b_element_op,
|
||||
C_functor c_element_op,
|
||||
C1_functor c1_element_op,
|
||||
int M,
|
||||
int N)
|
||||
{
|
||||
|
||||
int StrideC = N;
|
||||
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<DDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
|
||||
auto averageOpInst = UnaryDivElementOp{N};
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
// c = activation(c + bias) + c1_functor(c1)
|
||||
for(int m = 0; m < M; ++m)
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
AccDataType acc =
|
||||
static_cast<AccDataType>(c_m_n(m, n)) + static_cast<AccDataType>(bias_n(n));
|
||||
|
||||
AccDataType c1 = static_cast<AccDataType>(c1_m_n(m, n));
|
||||
|
||||
c_element_op(acc, acc);
|
||||
c1_element_op(c1, c1);
|
||||
acc += c1;
|
||||
c_m_n(m, n) = static_cast<CDataType>(acc);
|
||||
}
|
||||
|
||||
// reduce_mean and reduce_square_mean
|
||||
auto reduceSumOpInst = ReduceSumOp{};
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
auto mean_acc = reduceSumOpInst.GetIdentityValue<AccDataType>();
|
||||
auto square_mean_acc = reduceSumOpInst.GetIdentityValue<AccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
AccDataType c_val = ck::type_convert<AccDataType>(c_m_n(m, n));
|
||||
AccDataType square_c_val = 0;
|
||||
UnarySquareElementOp{}(square_c_val, c_val);
|
||||
|
||||
reduceSumOpInst(mean_acc, c_val);
|
||||
reduceSumOpInst(square_mean_acc, square_c_val);
|
||||
}
|
||||
|
||||
averageOpInst(mean_acc, mean_acc);
|
||||
averageOpInst(square_mean_acc, square_mean_acc);
|
||||
mean_m(m) = ck::type_convert<DDataType>(mean_acc);
|
||||
meanSquare_m(m) = ck::type_convert<DDataType>(square_mean_acc);
|
||||
}
|
||||
|
||||
// LayerNorm
|
||||
auto layerNormInst = NormalizeFunctor{};
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
AccDataType out_acc = 0;
|
||||
layerNormInst(out_acc,
|
||||
static_cast<AccDataType>(c_m_n(m, n)),
|
||||
static_cast<AccDataType>(mean_m(m)),
|
||||
static_cast<AccDataType>(meanSquare_m(m)),
|
||||
static_cast<AccDataType>(gamma_n(n)),
|
||||
static_cast<AccDataType>(beta_n(n)));
|
||||
out_m_n(m, n) = static_cast<DDataType>(out_acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prints throughput figures for the two kernels of the example.
//
// gemm_reduce_time / normalize_time are divided into flop / 1e9 and byte / 1e6
// respectively, and the results are labelled "TFlops" and "GB/s" (labels assume
// the times are in milliseconds, as printed).
// The template parameters are only used for their sizeof() in the byte counts.
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename C0DataType,
          typename C1DataType,
          typename DDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename NormalizeDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, int N, int K)
{
    // Widen once so that every product below is evaluated in std::size_t.
    const std::size_t rows  = M;
    const std::size_t cols  = N;
    const std::size_t depth = K;

    const std::size_t two        = 2;
    const std::size_t flop_total = two * rows * cols * depth + two * rows * cols;

    // Bytes touched by the GEMM + reduction kernel.
    std::size_t gemm_bytes = sizeof(ADataType) * rows * depth;
    gemm_bytes += sizeof(BDataType) * depth * cols;
    gemm_bytes += sizeof(CDataType) * rows * cols;
    gemm_bytes += sizeof(C0DataType) * rows * cols;
    gemm_bytes += sizeof(C1DataType) * rows * cols;
    gemm_bytes += sizeof(DDataType) * rows;
    gemm_bytes += sizeof(DDataType) * rows;

    // Bytes touched by the 5-ary elementwise (normalize) kernel.
    std::size_t normalize_bytes = sizeof(CDataType) * rows * cols;
    normalize_bytes += sizeof(DDataType) * rows;
    normalize_bytes += sizeof(DDataType) * rows;
    normalize_bytes += sizeof(GammaDataType) * cols;
    normalize_bytes += sizeof(BetaDataType) * cols;
    normalize_bytes += sizeof(NormalizeDataType) * rows * cols;

    const float tflops               = static_cast<float>(flop_total) / 1.E9 / gemm_reduce_time;
    const float gemm_gb_per_sec      = gemm_bytes / 1.E6 / gemm_reduce_time;
    const float normalize_gb_per_sec = normalize_bytes / 1.E6 / normalize_time;

    std::cout << "gemm + reduce_mean + reduce_square_mean Perf: " << gemm_reduce_time << " ms, "
              << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl;

    std::cout << "5-ary elementwise Perf: " << normalize_time << " ms, " << normalize_gb_per_sec
              << " GB/s, " << std::endl;
}
|
||||
|
||||
int main()
|
||||
{
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 1024;
|
||||
|
||||
ck::index_t StrideA = 1024;
|
||||
ck::index_t StrideB = 1024;
|
||||
ck::index_t StrideC = 1024;
|
||||
ck::index_t StrideC1 = 1024;
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
Tensor<C0DataType> bias_n(f_host_tensor_descriptor1d(N, 1));
|
||||
Tensor<C1DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<DDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
|
||||
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
|
||||
Tensor<LayerNormOutDataType> layerNorm_m_n(
|
||||
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
|
||||
bias_n.GenerateTensorValue(GeneratorTensor_3<C0DataType>{-1, 1});
|
||||
c1_m_n.GenerateTensorValue(GeneratorTensor_3<C1DataType>{-5, 5});
|
||||
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
|
||||
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
|
||||
DeviceMem bias_device_buf(sizeof(C0DataType) * bias_n.mDesc.GetElementSpace());
|
||||
DeviceMem c1_device_buf(sizeof(C1DataType) * c1_m_n.mDesc.GetElementSpace());
|
||||
DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace());
|
||||
DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) *
|
||||
reduceMeanSquare_m.mDesc.GetElementSpace());
|
||||
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
|
||||
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
|
||||
DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
|
||||
layerNorm_m_n.mDesc.GetElementSpace());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
bias_device_buf.ToDevice(bias_n.mData.data());
|
||||
c1_device_buf.ToDevice(c1_m_n.mData.data());
|
||||
gamma_device_buf.ToDevice(gamma_n.mData.data());
|
||||
beta_device_buf.ToDevice(beta_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto c1_element_op = C1ElementOp{};
|
||||
auto dxs_global =
|
||||
ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));
|
||||
|
||||
auto dxs_in_element_op = DxsInElementOps{};
|
||||
auto dxs_out_element_op = DxsOutElementOps{N, N};
|
||||
|
||||
// Prepare GEMM, reduce_mean, reduce_mean_square
|
||||
auto gemmReduce = DeviceGemmBiasAddReduceInstance{};
|
||||
auto gemmReduce_invoker = gemmReduce.MakeInvoker();
|
||||
auto gemmReduce_argument =
|
||||
gemmReduce.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<C0DataType*>(bias_device_buf.GetDeviceBuffer()),
|
||||
static_cast<C1DataType*>(c1_device_buf.GetDeviceBuffer()),
|
||||
dxs_global,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
StrideC1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
c1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op);
|
||||
|
||||
if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
reduceMean_device_buf.SetZero();
|
||||
reduceMeanSquare_device_buf.SetZero();
|
||||
|
||||
// Prepare LayerNorm
|
||||
auto normalize = DeviceNormalizeInstance{};
|
||||
auto normalize_invoker = normalize.MakeInvoker();
|
||||
auto normalize_argument = normalize.MakeArgument(
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()),
|
||||
static_cast<GammaDataType*>(gamma_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BetaDataType*>(beta_device_buf.GetDeviceBuffer()),
|
||||
static_cast<LayerNormOutDataType*>(layerNorm_device_buf.GetDeviceBuffer()),
|
||||
{M, N},
|
||||
{StrideC, 1},
|
||||
{1, 0},
|
||||
{1, 0},
|
||||
{0, 1},
|
||||
{0, 1},
|
||||
{StrideC, 1},
|
||||
NormalizeFunctor{});
|
||||
|
||||
if(!normalize.IsSupportedArgument(normalize_argument))
|
||||
{
|
||||
throw std::runtime_error("The runtime parameters seems not supported by the "
|
||||
"Device5AryElementwise instance, exiting!");
|
||||
}
|
||||
|
||||
// run kernel
|
||||
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false});
|
||||
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false});
|
||||
|
||||
bool pass = true;
|
||||
{
|
||||
// verification
|
||||
Tensor<LayerNormOutDataType> host_layerNorm_m_n(
|
||||
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
|
||||
host_gemm_layernorm<CDataType, DDataType, ReduceAccDataType>(host_layerNorm_m_n,
|
||||
a_m_k,
|
||||
b_k_n,
|
||||
bias_n,
|
||||
c1_m_n,
|
||||
gamma_n,
|
||||
beta_n,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
c1_element_op,
|
||||
M,
|
||||
N);
|
||||
|
||||
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
|
||||
pass &= ck::utils::check_err(layerNorm_m_n.mData,
|
||||
host_layerNorm_m_n.mData,
|
||||
"Error: Incorrect results layerNorm_m_n",
|
||||
1e-2,
|
||||
1e-2);
|
||||
}
|
||||
|
||||
{
|
||||
// evaluate kernel perf
|
||||
bool time_kernel = true;
|
||||
|
||||
float gemm_reduce_mean_reduce_square_mean_ave_time =
|
||||
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
|
||||
float normalize_ave_time =
|
||||
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
if(time_kernel)
|
||||
DumpGemmLayerNormPerf<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
C0DataType,
|
||||
C1DataType,
|
||||
DDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
LayerNormOutDataType>(
|
||||
gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
379
example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp
Normal file
379
example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp
Normal file
@@ -0,0 +1,379 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_5ary_elementwise.hpp"
|
||||
#include "device_gemm_reduce_xdl_cshuffle.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_gemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
using GemmAccDataType = F32;
|
||||
using ReduceAccDataType = F32;
|
||||
using DDataType = F32;
|
||||
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
|
||||
using GammaDataType = F16;
|
||||
using BetaDataType = F16;
|
||||
using LayerNormOutDataType = F16;
|
||||
using NormalizeComputeDataType = F32;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using ReduceSumOp = ck::reduce::Add;
|
||||
using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>;
|
||||
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
|
||||
using DxsGlobalMemOp =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
|
||||
ck::InMemoryDataOperationEnum::AtomicAdd>;
|
||||
|
||||
static constexpr auto GemmSpecialization =
|
||||
ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
GemmAccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
|
||||
|
||||
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
|
||||
using DeviceNormalizeInstance =
|
||||
ck::tensor_operation::device::Device5AryElementwise<CDataType,
|
||||
DDataType,
|
||||
DDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
LayerNormOutDataType,
|
||||
NormalizeComputeDataType,
|
||||
NormalizeFunctor,
|
||||
2,
|
||||
8,
|
||||
8, // scalarPerVector: gemm_out
|
||||
1, // scalarPerVector: reduce_mean
|
||||
1, // scalarPerVector: reduce_mean_square
|
||||
8, // scalarPerVector: Gamma
|
||||
8, // scalarPerVector: Beta
|
||||
8>; // scalarPerVector: LayerNorm_out
|
||||
|
||||
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({len}),
|
||||
std::vector<std::size_t>({stride}));
|
||||
};
|
||||
|
||||
auto f_host_tensor_descriptor2d =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename CDataType,
|
||||
typename DDataType,
|
||||
typename A_functor,
|
||||
typename B_functor,
|
||||
typename C_functor>
|
||||
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
const Tensor<ADataType>& a_m_k,
|
||||
const Tensor<ADataType>& b_k_n,
|
||||
const Tensor<GammaDataType>& gamma_n,
|
||||
const Tensor<GammaDataType>& beta_n,
|
||||
A_functor a_element_op,
|
||||
B_functor b_element_op,
|
||||
C_functor c_element_op,
|
||||
int M,
|
||||
int N)
|
||||
{
|
||||
using out_type = ck::remove_reference_t<decltype(out_m_n(0, 0))>;
|
||||
|
||||
int StrideC = N;
|
||||
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<DDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
|
||||
auto averageOpInst = UnaryDivElementOp{N};
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
// reduce_mean and reduce_square_mean
|
||||
auto reduceSumOpInst = ReduceSumOp{};
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
auto mean_acc = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>();
|
||||
auto square_mean_acc = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n(m, n));
|
||||
auto square_c_val = reduceSumOpInst.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
UnarySquareElementOp{}(square_c_val, c_val);
|
||||
|
||||
reduceSumOpInst(mean_acc, c_val);
|
||||
reduceSumOpInst(square_mean_acc, square_c_val);
|
||||
}
|
||||
|
||||
averageOpInst(mean_acc, mean_acc);
|
||||
averageOpInst(square_mean_acc, square_mean_acc);
|
||||
mean_m(m) = ck::type_convert<DDataType>(mean_acc);
|
||||
meanSquare_m(m) = ck::type_convert<DDataType>(square_mean_acc);
|
||||
}
|
||||
|
||||
// LayerNorm
|
||||
auto layerNormInst = NormalizeFunctor{};
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
float out_f32 = 0;
|
||||
layerNormInst(out_f32,
|
||||
static_cast<float>(c_m_n(m, n)),
|
||||
static_cast<float>(mean_m(m)),
|
||||
static_cast<float>(meanSquare_m(m)),
|
||||
static_cast<float>(gamma_n(n)),
|
||||
static_cast<float>(beta_n(n)));
|
||||
out_m_n(m, n) = static_cast<out_type>(out_f32);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prints throughput figures for the two kernels of the example.
//
// gemm_reduce_time / normalize_time are divided into flop / 1e9 and byte / 1e6
// respectively, and the results are labelled "TFlops" and "GB/s" (labels assume
// the times are in milliseconds, as printed).
// The template parameters are only used for their sizeof() in the byte counts.
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename DDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename NormalizeDataType>
void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, int N, int K)
{
    // Widen once so that every product below is evaluated in std::size_t.
    const std::size_t rows  = M;
    const std::size_t cols  = N;
    const std::size_t depth = K;

    const std::size_t two        = 2;
    const std::size_t flop_total = two * rows * cols * depth;

    // Bytes touched by the GEMM + reduction kernel.
    std::size_t gemm_bytes = sizeof(ADataType) * rows * depth;
    gemm_bytes += sizeof(BDataType) * depth * cols;
    gemm_bytes += sizeof(CDataType) * rows * cols;
    gemm_bytes += sizeof(DDataType) * rows;
    gemm_bytes += sizeof(DDataType) * rows;

    // Bytes touched by the 5-ary elementwise (normalize) kernel.
    std::size_t normalize_bytes = sizeof(CDataType) * rows * cols;
    normalize_bytes += sizeof(DDataType) * rows;
    normalize_bytes += sizeof(DDataType) * rows;
    normalize_bytes += sizeof(GammaDataType) * cols;
    normalize_bytes += sizeof(BetaDataType) * cols;
    normalize_bytes += sizeof(NormalizeDataType) * rows * cols;

    const float tflops               = static_cast<float>(flop_total) / 1.E9 / gemm_reduce_time;
    const float gemm_gb_per_sec      = gemm_bytes / 1.E6 / gemm_reduce_time;
    const float normalize_gb_per_sec = normalize_bytes / 1.E6 / normalize_time;

    std::cout << "gemm + reduce_mean + reduce_square_mean Perf: " << gemm_reduce_time << " ms, "
              << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl;

    std::cout << "5-ary elementwise Perf: " << normalize_time << " ms, " << normalize_gb_per_sec
              << " GB/s, " << std::endl;
}
|
||||
|
||||
int main()
|
||||
{
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 1024;
|
||||
|
||||
ck::index_t StrideA = 1024;
|
||||
ck::index_t StrideB = 1024;
|
||||
ck::index_t StrideC = 1024;
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<DDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
|
||||
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
|
||||
Tensor<LayerNormOutDataType> layerNorm_m_n(
|
||||
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
|
||||
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
|
||||
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
|
||||
DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace());
|
||||
DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) *
|
||||
reduceMeanSquare_m.mDesc.GetElementSpace());
|
||||
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
|
||||
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
|
||||
DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
|
||||
layerNorm_m_n.mDesc.GetElementSpace());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
gamma_device_buf.ToDevice(gamma_n.mData.data());
|
||||
beta_device_buf.ToDevice(beta_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto dxs_global =
|
||||
ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));
|
||||
|
||||
auto dxs_in_element_op = DxsInElementOps{};
|
||||
auto dxs_out_element_op = DxsOutElementOps{N, N};
|
||||
|
||||
// Prepare GEMM, reduce_mean, reduce_mean_square
|
||||
auto gemmReduce = DeviceGemmReduceInstance{};
|
||||
auto gemmReduce_invoker = gemmReduce.MakeInvoker();
|
||||
auto gemmReduce_argument =
|
||||
gemmReduce.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
dxs_global,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op);
|
||||
|
||||
if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
reduceMean_device_buf.SetZero();
|
||||
reduceMeanSquare_device_buf.SetZero();
|
||||
|
||||
// Prepare LayerNorm
|
||||
auto normalize = DeviceNormalizeInstance{};
|
||||
auto normalize_invoker = normalize.MakeInvoker();
|
||||
auto normalize_argument = normalize.MakeArgument(
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()),
|
||||
static_cast<GammaDataType*>(gamma_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BetaDataType*>(beta_device_buf.GetDeviceBuffer()),
|
||||
static_cast<LayerNormOutDataType*>(layerNorm_device_buf.GetDeviceBuffer()),
|
||||
{M, N},
|
||||
{StrideC, 1},
|
||||
{1, 0},
|
||||
{1, 0},
|
||||
{0, 1},
|
||||
{0, 1},
|
||||
{StrideC, 1},
|
||||
NormalizeFunctor{});
|
||||
|
||||
if(!normalize.IsSupportedArgument(normalize_argument))
|
||||
{
|
||||
throw std::runtime_error("The runtime parameters seems not supported by the "
|
||||
"Device5AryElementwise instance, exiting!");
|
||||
}
|
||||
|
||||
// run kernel
|
||||
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false});
|
||||
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false});
|
||||
|
||||
bool pass = true;
|
||||
{
|
||||
// verification
|
||||
Tensor<LayerNormOutDataType> host_layerNorm_m_n(
|
||||
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
|
||||
host_gemm_layernorm<CDataType, DDataType>(host_layerNorm_m_n,
|
||||
a_m_k,
|
||||
b_k_n,
|
||||
gamma_n,
|
||||
beta_n,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
M,
|
||||
N);
|
||||
|
||||
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
|
||||
pass &= ck::utils::check_err(layerNorm_m_n.mData,
|
||||
host_layerNorm_m_n.mData,
|
||||
"Error: Incorrect results d1",
|
||||
1e-3,
|
||||
1e-3);
|
||||
}
|
||||
|
||||
{
|
||||
// evaluate kernel perf
|
||||
bool time_kernel = true;
|
||||
|
||||
float gemm_reduce_mean_reduce_square_mean_ave_time =
|
||||
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
|
||||
float normalize_ave_time =
|
||||
normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
if(time_kernel)
|
||||
DumpGemmLayerNormPerf<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
DDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
LayerNormOutDataType>(
|
||||
gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
1
example/22_cgemm/CMakeLists.txt
Normal file
1
example/22_cgemm/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
|
||||
302
example/22_cgemm/cgemm_xdl_fp16.cpp
Normal file
302
example/22_cgemm/cgemm_xdl_fp16.cpp
Normal file
@@ -0,0 +1,302 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_cgemm_4gemm_xdl_cshuffle.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "reference_cgemm.hpp"
|
||||
#include "gemm_specialization.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
using AccDataType = F32;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
<ALayout, // typename ALayout
|
||||
BLayout, // typename BLayout
|
||||
CLayout, // typename CLayout
|
||||
ADataType, // typename ADataType
|
||||
BDataType, // typename BDataType
|
||||
CDataType, // typename CDataType
|
||||
AccDataType, // typename GemmAccDataType
|
||||
CDataType, // typename CShuffleDataType
|
||||
PassThrough, // typename AElementwiseOperation
|
||||
PassThrough, // typename BElementwiseOperation
|
||||
PassThrough, // typename CElementwiseOperation
|
||||
GemmDefault, // GemmSpecialization GemmSpec
|
||||
1, // index_t NumGemmKPrefetchStage
|
||||
256, // index_t BlockSize
|
||||
256, // index_t MPerBlock
|
||||
128, // index_t NPerBlock
|
||||
32, // index_t KPerBlock
|
||||
8, // index_t AK1
|
||||
8, // index_t BK1
|
||||
32, // index_t MPerXDL
|
||||
32, // index_t NPerXDL
|
||||
4, // index_t MXdlPerWave
|
||||
2, // index_t NXdlPerWave
|
||||
S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder
|
||||
2, // index_t ABlockTransferSrcVectorDim
|
||||
8, // index_t ABlockTransferSrcScalarPerVector
|
||||
8, // index_t ABlockTransferDstScalarPerVector_AK1
|
||||
1, // index_t ABlockLdsExtraM
|
||||
S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder
|
||||
2, // index_t BBlockTransferSrcVectorDim
|
||||
8, // index_t BBlockTransferSrcScalarPerVector
|
||||
8, // index_t BBlockTransferDstScalarPerVector_BK1
|
||||
1, // index_t BBlockLdsExtraN
|
||||
1, // index_t CShuffleMXdlPerWavePerShuffle
|
||||
1, // index_t CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
// clang-format on
|
||||
|
||||
using ReferenceCGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceCGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// CGEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideC = std::stoi(argv[9]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: run kernel # of times (>1)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
|
||||
std::vector<std::size_t>({1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k_real(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<ADataType> a_m_k_imag(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n_real(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BDataType> b_k_n_imag(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_real_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_imag_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl;
|
||||
std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl;
|
||||
std::cout << "b_k_n_real: " << b_k_n_real.mDesc << std::endl;
|
||||
std::cout << "b_k_n_imag: " << b_k_n_imag.mDesc << std::endl;
|
||||
std::cout << "c_m_n_real: " << c_m_n_real_device_result.mDesc << std::endl;
|
||||
std::cout << "c_m_n_imag: " << c_m_n_imag_device_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k_real.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
a_m_k_imag.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n_real.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b_k_n_imag.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_m_k_real.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
|
||||
a_m_k_imag.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
|
||||
b_k_n_real.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
b_k_n_imag.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
auto cgemm = DeviceCGemmInstance{};
|
||||
|
||||
DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * a_m_k_real.mDesc.GetElementSpace());
|
||||
DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * a_m_k_imag.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * b_k_n_real.mDesc.GetElementSpace());
|
||||
DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * b_k_n_imag.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_real_device_buf(sizeof(CDataType) *
|
||||
c_m_n_real_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) *
|
||||
c_m_n_imag_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem workspace_device_buf(cgemm.GetWorkspaceSize(M, N, K, StrideA, StrideB, StrideC));
|
||||
|
||||
a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data());
|
||||
a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data());
|
||||
b_k_n_real_device_buf.ToDevice(b_k_n_real.mData.data());
|
||||
b_k_n_imag_device_buf.ToDevice(b_k_n_imag.mData.data());
|
||||
|
||||
auto a_element_op = PassThrough{};
|
||||
auto b_element_op = PassThrough{};
|
||||
auto c_element_op = PassThrough{};
|
||||
|
||||
// do GEMM
|
||||
auto invoker = cgemm.MakeInvoker();
|
||||
auto argument =
|
||||
cgemm.MakeArgument(static_cast<ADataType*>(a_m_k_real_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ADataType*>(a_m_k_imag_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_real_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_imag_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(workspace_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!cgemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_cgemm with the specified compilation parameters does "
|
||||
"not support this CGEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(8) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
std::size_t(2) *
|
||||
(sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< cgemm.GetTypeString() << std::endl;
|
||||
|
||||
c_m_n_real_device_buf.FromDevice(c_m_n_real_device_result.mData.data());
|
||||
c_m_n_imag_device_buf.FromDevice(c_m_n_imag_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<CDataType> c_m_n_real_host_result(
|
||||
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_imag_host_result(
|
||||
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
auto ref_cgemm = ReferenceCGemmInstance{};
|
||||
auto ref_invoker = ref_cgemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_cgemm.MakeArgument(a_m_k_real,
|
||||
a_m_k_imag,
|
||||
b_k_n_real,
|
||||
b_k_n_imag,
|
||||
c_m_n_real_host_result,
|
||||
c_m_n_imag_host_result,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ck::utils::check_err(c_m_n_real_device_result.mData,
|
||||
c_m_n_real_host_result.mData,
|
||||
"Verification error: incorrect results in real part!",
|
||||
1e-2f,
|
||||
1e-1f);
|
||||
ck::utils::check_err(c_m_n_imag_device_result.mData,
|
||||
c_m_n_imag_host_result.mData,
|
||||
"Verification error: incorrect results in imaginary part!",
|
||||
1e-2f,
|
||||
1e-1f);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
1
example/23_softmax/CMakeLists.txt
Normal file
1
example/23_softmax/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_softmax_blockwise softmax_blockwise.cpp)
|
||||
18
example/23_softmax/README.md
Normal file
18
example/23_softmax/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# Instructions for ```example_softmax_blockwise```
|
||||
|
||||
## Run ```example_softmax_blockwise```
|
||||
```bash
|
||||
# -D <xxx> : input 3-d tensor lengths
|
||||
# -v <x> : verification (0=no, 1=yes)
|
||||
#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)
|
||||
#arg2: time kernel (0=no, 1=yes)
|
||||
example_softmax_blockwise -D 4,128,2048 -v 1 1 1
|
||||
```
|
||||
|
||||
Result
|
||||
```
|
||||
launch_and_time_kernel: grid_dim {64, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up 1 time
|
||||
Start running 10 times...
|
||||
Perf: 0.0242877 ms, 259.039 GB/s, DeviceReduceSoftmax<256,M_C8_S1,K_C32_S8,InSrcVectorDim_1_InSrcVectorSize_8_OutDstVectorSize_8>
|
||||
```
|
||||
255
example/23_softmax/softmax_blockwise.cpp
Normal file
255
example/23_softmax/softmax_blockwise.cpp
Normal file
@@ -0,0 +1,255 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <getopt.h>
|
||||
|
||||
#include "check_err.hpp"
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_softmax.hpp"
|
||||
#include "host_common_util.hpp"
|
||||
#include "reference_softmax.hpp"
|
||||
|
||||
#include "reduction_enums.hpp"
|
||||
#include "reduction_operator_mapping.hpp"
|
||||
|
||||
using namespace ck;
|
||||
using namespace ck::tensor_operation::device;
|
||||
|
||||
using InDataType = ck::half_t;
|
||||
using OutDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
|
||||
constexpr int Rank = 3;
|
||||
constexpr int NumReduceDim = 1;
|
||||
|
||||
using DeviceInstance = DeviceSoftmax<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
256, // BlockSize
|
||||
8, // ClusterM
|
||||
32, // ClusterK
|
||||
1, // SliceM
|
||||
8, // SliceK
|
||||
1, // SrcVecDim (0=M, 1=K)
|
||||
8, // SrcScalarPerVector
|
||||
8>; // OutScalarPerVector
|
||||
|
||||
static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
|
||||
{"verify", required_argument, nullptr, 'v'},
|
||||
{"help", no_argument, nullptr, '?'},
|
||||
{nullptr, 0, nullptr, 0}};
|
||||
|
||||
class SimpleAppArgs
|
||||
{
|
||||
private:
|
||||
int option_index = 0;
|
||||
|
||||
public:
|
||||
std::vector<size_t> inLengths = {8, 128, 2048};
|
||||
std::vector<AccDataType> scales = {2.0f, 2.0f};
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 2;
|
||||
bool time_kernel = true;
|
||||
|
||||
public:
|
||||
void show_usage(const char* cmd)
|
||||
{
|
||||
std::cout << "Usage of " << cmd << std::endl;
|
||||
std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths"
|
||||
<< std::endl;
|
||||
std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by "
|
||||
"comparing with the host-based reduction"
|
||||
<< std::endl;
|
||||
std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer "
|
||||
"value, 3=decimal value)"
|
||||
<< std::endl;
|
||||
std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl;
|
||||
};
|
||||
|
||||
int processArgs(int argc, char* argv[])
|
||||
{
|
||||
using ck::host_common::getTypeValuesFromString;
|
||||
|
||||
int ch;
|
||||
|
||||
while(1)
|
||||
{
|
||||
ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index);
|
||||
if(ch == -1)
|
||||
break;
|
||||
switch(ch)
|
||||
{
|
||||
case 'D':
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
|
||||
inLengths = getTypeValuesFromString<size_t>(optarg);
|
||||
break;
|
||||
case 'v':
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
|
||||
do_verification = static_cast<bool>(std::atoi(optarg));
|
||||
break;
|
||||
case '?':
|
||||
if(std::string(long_options[option_index].name) == "help")
|
||||
{
|
||||
show_usage(argv[0]);
|
||||
return (-1);
|
||||
};
|
||||
break;
|
||||
default: show_usage(argv[0]); return (-1);
|
||||
};
|
||||
};
|
||||
|
||||
if(optind + 2 > argc)
|
||||
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
|
||||
|
||||
init_method = std::atoi(argv[optind++]);
|
||||
time_kernel = static_cast<bool>(std::atoi(argv[optind]));
|
||||
|
||||
if(scales.empty())
|
||||
{
|
||||
scales.push_back(1.0f);
|
||||
scales.push_back(0.0f);
|
||||
};
|
||||
|
||||
return (0);
|
||||
};
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// Example: batched gemm C[G, M, N] applies max/sum reduction along N internally
|
||||
const std::vector<int> invariantDims{0, 1};
|
||||
const std::vector<int> reduceDims{2};
|
||||
|
||||
SimpleAppArgs args;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
if(args.processArgs(argc, argv) < 0)
|
||||
return (-1);
|
||||
};
|
||||
|
||||
Tensor<InDataType> in(args.inLengths);
|
||||
Tensor<OutDataType> out_ref(args.inLengths);
|
||||
Tensor<OutDataType> out(args.inLengths);
|
||||
|
||||
auto inStrides = in.mDesc.GetStrides();
|
||||
auto outStrides = out.mDesc.GetStrides();
|
||||
|
||||
AccDataType alpha = args.scales[0];
|
||||
AccDataType beta = args.scales[1];
|
||||
|
||||
std::size_t num_thread = 1;
|
||||
|
||||
if(args.do_verification)
|
||||
{
|
||||
switch(args.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_1<InDataType>{1}, num_thread);
|
||||
if(beta != 0.0f)
|
||||
out_ref.GenerateTensorValue(GeneratorTensor_1<OutDataType>{1}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5}, num_thread);
|
||||
if(beta != 0.0f)
|
||||
out_ref.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0}, num_thread);
|
||||
if(beta != 0.0f)
|
||||
out_ref.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-5.0, 5.0}, num_thread);
|
||||
}
|
||||
|
||||
if(beta != 0.0f)
|
||||
for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++)
|
||||
out.mData[i] = out_ref.mData[i];
|
||||
};
|
||||
// std::cout << "beta = " << beta << std::endl;
|
||||
// LogRangeAsType<float>(std::cout << "tensor in: " , in.mData, ",") << std::endl;
|
||||
// LogRangeAsType<float>(std::cout << "tensor prior out: " , out.mData, ",") << std::endl;
|
||||
|
||||
// these buffers are usually provided by the user application
|
||||
DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace());
|
||||
DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace());
|
||||
|
||||
in_dev.ToDevice(in.mData.data());
|
||||
|
||||
if(beta != 0.0f)
|
||||
out_dev.ToDevice(out.mData.data());
|
||||
|
||||
if(args.do_verification)
|
||||
{
|
||||
using ReferenceInstance =
|
||||
tensor_operation::host::ReferenceSoftmax<InDataType, OutDataType, AccDataType>;
|
||||
ReferenceInstance ref;
|
||||
auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, Rank, reduceDims);
|
||||
auto invoker = ref.MakeInvoker();
|
||||
invoker.Run(ref_arg);
|
||||
// LogRangeAsType<float>(std::cout << "tensor out_ref: ", out_ref.mData, ",") << std::endl;
|
||||
};
|
||||
|
||||
std::vector<ck::index_t> i_inLengths;
|
||||
std::vector<ck::index_t> i_inStrides;
|
||||
|
||||
i_inLengths.assign(args.inLengths.begin(), args.inLengths.end());
|
||||
i_inStrides.assign(inStrides.begin(), inStrides.end());
|
||||
|
||||
auto device_instance = DeviceInstance{};
|
||||
|
||||
auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths,
|
||||
i_inStrides,
|
||||
reduceDims,
|
||||
alpha,
|
||||
beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer());
|
||||
|
||||
if(!device_instance.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::cout
|
||||
<< "The runtime parameters seems not supported by the DeviceReduce instance, exiting!"
|
||||
<< std::endl;
|
||||
return 1;
|
||||
};
|
||||
|
||||
std::string instance_name = device_instance.GetTypeString();
|
||||
|
||||
auto invoker_ptr = device_instance.MakeInvokerPointer();
|
||||
|
||||
bool pass = true;
|
||||
if(args.do_verification)
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
out_dev.FromDevice(out.mData.data());
|
||||
// LogRangeAsType<float>(std::cout << "tensor out: " , out.mData, ",") << std::endl;
|
||||
pass = pass && ck::utils::check_err(out.mData, out_ref.mData);
|
||||
};
|
||||
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel});
|
||||
|
||||
std::size_t num_bytes =
|
||||
in.mDesc.GetElementSize() * sizeof(InDataType) +
|
||||
(beta == 0.0f ? 1 : 2) * out.mDesc.GetElementSize() * sizeof(OutDataType);
|
||||
|
||||
float gb_per_sec = num_bytes / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << instance_name
|
||||
<< std::endl;
|
||||
|
||||
return (pass ? 0 : 1);
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/include/ck
|
||||
${PROJECT_SOURCE_DIR}/include/ck/utility
|
||||
${PROJECT_SOURCE_DIR}/include/ck/host_utility
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor_description
|
||||
${PROJECT_SOURCE_DIR}/include/ck/tensor
|
||||
${PROJECT_SOURCE_DIR}/include/ck/problem_transform
|
||||
@@ -19,17 +20,26 @@ include_directories(BEFORE
|
||||
|
||||
add_custom_target(examples)
|
||||
|
||||
function(add_example_executable EXAMPLE_NAME)
|
||||
function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
message("adding example ${EXAMPLE_NAME}")
|
||||
add_executable(${EXAMPLE_NAME} ${ARGN})
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor)
|
||||
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
add_dependencies(check ${EXAMPLE_NAME})
|
||||
endfunction(add_example_executable EXAMPLE_NAME)
|
||||
|
||||
function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
message("adding example ${EXAMPLE_NAME}")
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor)
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
endfunction(add_example_executable EXAMPLE_NAME)
|
||||
endfunction(add_example_executable_no_testing EXAMPLE_NAME)
|
||||
|
||||
add_subdirectory(01_gemm)
|
||||
add_subdirectory(02_gemm_alpha_beta)
|
||||
add_subdirectory(03_gemm_bias_relu)
|
||||
add_subdirectory(04_gemm_bias_relu_add)
|
||||
add_subdirectory(04_gemm_add_add_fastgelu)
|
||||
add_subdirectory(06_conv2d_fwd_bias_relu)
|
||||
add_subdirectory(07_conv2d_fwd_bias_relu_add)
|
||||
add_subdirectory(09_convnd_fwd)
|
||||
@@ -38,7 +48,12 @@ add_subdirectory(11_conv2d_bwd_weight)
|
||||
add_subdirectory(12_reduce)
|
||||
add_subdirectory(13_pool2d_fwd)
|
||||
add_subdirectory(14_gemm_xdl_requant_relu_requant)
|
||||
add_subdirectory(17_convnd_bwd_data_xdl)
|
||||
add_subdirectory(15_grouped_gemm)
|
||||
add_subdirectory(16_gemm_reduce)
|
||||
add_subdirectory(17_convnd_bwd_data_xdl)
|
||||
add_subdirectory(18_batched_gemm_reduce)
|
||||
add_subdirectory(19_binary_elementwise)
|
||||
add_subdirectory(20_convnd_bwd_weight_xdl)
|
||||
add_subdirectory(21_gemm_layernorm)
|
||||
add_subdirectory(22_cgemm)
|
||||
add_subdirectory(23_softmax)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_CONFIG_AMD_HPP
|
||||
#define CK_CONFIG_AMD_HPP
|
||||
|
||||
@@ -76,6 +79,12 @@
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
|
||||
#endif
|
||||
|
||||
#if defined(__gfx90a__) // for GPU code
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
|
||||
#else
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
|
||||
#endif
|
||||
|
||||
// inline asm
|
||||
#define CK_USE_AMD_INLINE_ASM 1
|
||||
|
||||
@@ -91,10 +100,11 @@
|
||||
// experimental feature: static tensor descriptor
|
||||
#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0
|
||||
|
||||
// experimental feature: buffer load/store/atomic-add OOB trick
|
||||
// experimental feature: buffer load/store/atomic-add/atomic-max OOB trick
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
|
||||
|
||||
// experimental feature: in-register sub-dword transpose
|
||||
#define CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE 1
|
||||
@@ -109,6 +119,10 @@
|
||||
// experimental feature: use __builtin_memcpy instead of union to do bit_cast
|
||||
#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1
|
||||
|
||||
// experimental feature: optimize for inter-wave scheduling policy
|
||||
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING 0
|
||||
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS 1
|
||||
|
||||
// hack: relies on underlying assumptions that must be satisfied, otherwise it's a bug
|
||||
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
|
||||
// thread-invariant, otherwise it's a bug
|
||||
@@ -128,19 +142,29 @@
|
||||
// tuning parameter
|
||||
#define CK_WORKAROUND_SWDEV_325164 1
|
||||
|
||||
// workaround for verification failure ConvNd forward
|
||||
// https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/135
|
||||
#define CK_WORKAROUND_GITHUB_135 1
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum struct InMemoryDataOperationEnum
|
||||
{
|
||||
Set,
|
||||
AtomicAdd,
|
||||
AtomicMax,
|
||||
Add
|
||||
};
|
||||
|
||||
template <InMemoryDataOperationEnum... Is>
|
||||
struct InMemoryDataOperationEnumSequence
|
||||
{
|
||||
static constexpr int mSize = sizeof...(Is);
|
||||
|
||||
__host__ __device__ static constexpr InMemoryDataOperationEnum At(int I)
|
||||
{
|
||||
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
|
||||
const InMemoryDataOperationEnum mData[mSize + 1] = {Is..., InMemoryDataOperationEnum::Set};
|
||||
return mData[I];
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: no longer needed, remove this
|
||||
enum struct ActivTypeEnum
|
||||
{
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
// "_PACKAGE_" is used to avoid name collisions: macros like
|
||||
// HIP_VERSION_MAJOR are defined in HIP_VERSION.h.
|
||||
// clang-format off
|
||||
#define CK_HIP_PACKAGE_VERSION_MAJOR @CK_HIP_VERSION_MAJOR@
|
||||
#define CK_HIP_PACKAGE_VERSION_MINOR @CK_HIP_VERSION_MINOR@
|
||||
#define CK_HIP_PACKAGE_VERSION_PATCH @CK_HIP_VERSION_PATCH@
|
||||
// clang-format on
|
||||
|
||||
#ifndef CK_HIP_PACKAGE_VERSION_MAJOR
|
||||
#define CK_HIP_PACKAGE_VERSION_MAJOR 0
|
||||
#endif
|
||||
#ifndef CK_HIP_PACKAGE_VERSION_MINOR
|
||||
#define CK_HIP_PACKAGE_VERSION_MINOR 0
|
||||
#endif
|
||||
#ifndef CK_HIP_PACKAGE_VERSION_PATCH
|
||||
#define CK_HIP_PACKAGE_VERSION_PATCH 0
|
||||
#endif
|
||||
// 3 decimal digits for major and minor, 6 digits for patch number.
|
||||
// Max number is 999,999,999999 == 0xE8,D4A5,0FFF that fits into 64-bit math.
|
||||
// Reject components that do not fit their decimal field in CK_HIP_PACKAGE_VERSION_FLAT:
// major and minor get 3 digits each, patch gets 6. The minor component is checked here as well
// (the previous condition tested the major component twice and never looked at minor).
#if CK_HIP_PACKAGE_VERSION_MAJOR > 999 || CK_HIP_PACKAGE_VERSION_MINOR > 999 || \
    CK_HIP_PACKAGE_VERSION_PATCH > 999999
#error "Too big HIP version number(s)"
#endif
|
||||
#define CK_HIP_PACKAGE_VERSION_FLAT \
|
||||
((CK_HIP_PACKAGE_VERSION_MAJOR * 1000ULL + CK_HIP_PACKAGE_VERSION_MINOR) * 1000000 + \
|
||||
CK_HIP_PACKAGE_VERSION_PATCH)
|
||||
50
include/ck/host_utility/device_prop.hpp
Normal file
50
include/ck/host_utility/device_prop.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
namespace ck {
|
||||
|
||||
inline std::string get_device_name()
|
||||
{
|
||||
hipDeviceProp_t props{};
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return std::string();
|
||||
}
|
||||
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return std::string();
|
||||
}
|
||||
const std::string raw_name(props.gcnArchName);
|
||||
|
||||
// https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
|
||||
static std::map<std::string, std::string> device_name_map = {
|
||||
{"Ellesmere", "gfx803"},
|
||||
{"Baffin", "gfx803"},
|
||||
{"RacerX", "gfx803"},
|
||||
{"Polaris10", "gfx803"},
|
||||
{"Polaris11", "gfx803"},
|
||||
{"Tonga", "gfx803"},
|
||||
{"Fiji", "gfx803"},
|
||||
{"gfx800", "gfx803"},
|
||||
{"gfx802", "gfx803"},
|
||||
{"gfx804", "gfx803"},
|
||||
{"Vega10", "gfx900"},
|
||||
{"gfx901", "gfx900"},
|
||||
{"10.3.0 Sienna_Cichlid 18", "gfx1030"},
|
||||
};
|
||||
|
||||
const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
|
||||
|
||||
auto match = device_name_map.find(name);
|
||||
if(match != device_name_map.end())
|
||||
return match->second;
|
||||
return name;
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
3
include/ck/options.hpp
Normal file
3
include/ck/options.hpp
Normal file
@@ -0,0 +1,3 @@
|
||||
#pragma once
|
||||
|
||||
#define CK_TIME_KERNEL 1
|
||||
10
include/ck/stream_config.hpp
Normal file
10
include/ck/stream_config.hpp
Normal file
@@ -0,0 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
// Plain aggregate that bundles a HIP stream handle with a timing switch.
// Both members carry default initializers, so a default-constructed StreamConfig holds a
// nullptr stream and has timing turned off.
struct StreamConfig
|
||||
{
|
||||
// HIP stream handle; defaults to nullptr.
hipStream_t stream_id_ = nullptr;
|
||||
// Timing switch, off by default. NOTE(review): presumably tells the launch helper to time the
// kernel - confirm against the code that consumes this struct.
bool time_kernel_ = false;
|
||||
};
|
||||
@@ -136,7 +136,11 @@ struct TensorAdaptor
|
||||
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
|
||||
|
||||
public:
|
||||
#if 0 // workaround compiler complaint about constexpr
|
||||
__host__ __device__ constexpr TensorAdaptor() = default;
|
||||
#else
|
||||
__host__ __device__ constexpr TensorAdaptor() : transforms_{}, element_size_{} {}
|
||||
#endif
|
||||
|
||||
__host__ __device__ constexpr TensorAdaptor(const Transforms& transforms)
|
||||
: transforms_{transforms}, element_size_{InitializeElementSize(transforms)}
|
||||
|
||||
@@ -111,7 +111,14 @@ struct TensorDescriptor
|
||||
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
|
||||
|
||||
public:
|
||||
#if 0 // workaround compiler complaint about constexpr
|
||||
__host__ __device__ constexpr TensorDescriptor() = default;
|
||||
#else
|
||||
__host__ __device__ constexpr TensorDescriptor()
|
||||
: transforms_{}, element_size_{}, element_space_size_{}
|
||||
{
|
||||
}
|
||||
#endif
|
||||
|
||||
__host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
|
||||
ElementSpaceSize element_space_size)
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
|
||||
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
|
||||
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
@@ -35,6 +33,12 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
|
||||
}
|
||||
#endif
|
||||
|
||||
// Lengths..., Strides... could be:
|
||||
// 1) index_t, which is known at run-time, or
|
||||
// 2) Number<>, which is known at compile-time
|
||||
// element_space_size could be:
|
||||
// 1) long_index_t, or
|
||||
// 2) LongNumber<>
|
||||
template <typename... Lengths,
|
||||
typename... Strides,
|
||||
typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
@@ -68,10 +72,10 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
|
||||
}
|
||||
};
|
||||
|
||||
const auto element_space_size = f(f, Number<0>{}, Number<1>{});
|
||||
const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{});
|
||||
#else
|
||||
const auto element_space_size =
|
||||
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
|
||||
calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{});
|
||||
#endif
|
||||
|
||||
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
|
||||
@@ -82,9 +86,12 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Leng
|
||||
element_space_size};
|
||||
}
|
||||
|
||||
// Lengths... can be:
|
||||
// 1) index_t, which is known at run-time
|
||||
// Lengths... could be:
|
||||
// 1) index_t, which is known at run-time, or
|
||||
// 2) Number<>, which is known at compile-time
|
||||
// element_space_size could be:
|
||||
// 1) long_index_t, or
|
||||
// 2) LongNumber<>
|
||||
template <typename... Lengths>
|
||||
__host__ __device__ constexpr auto
|
||||
make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
|
||||
@@ -100,7 +107,7 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
|
||||
|
||||
const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{});
|
||||
const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{});
|
||||
|
||||
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
|
||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||
@@ -110,6 +117,12 @@ make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
|
||||
element_space_size};
|
||||
}
|
||||
|
||||
// Lengths... could be:
|
||||
// 1) index_t, which is known at run-time, or
|
||||
// 2) Number<>, which is known at compile-time
|
||||
// align could be:
|
||||
// 1) index_t, or
|
||||
// 2) Number<>
|
||||
template <typename... Lengths, typename Align>
|
||||
__host__ __device__ constexpr auto
|
||||
make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align align)
|
||||
@@ -146,4 +159,3 @@ make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align ali
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
|
||||
#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
|
||||
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_adaptor.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v2.hpp"
|
||||
#include "threadwise_contraction_dlops.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v4r1.hpp"
|
||||
#include "threadwise_contraction_dl.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -41,7 +39,7 @@ template <index_t BlockSize,
|
||||
typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
|
||||
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
|
||||
struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
|
||||
{
|
||||
using AIndex = MultiIndex<3>;
|
||||
using BIndex = MultiIndex<3>;
|
||||
@@ -148,7 +146,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
|
||||
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
|
||||
|
||||
public:
|
||||
__device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
|
||||
__device__ BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
|
||||
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
|
||||
get_thread_local_1d_id())},
|
||||
a_thread_copy_{
|
||||
@@ -175,6 +173,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
|
||||
"wrong!");
|
||||
|
||||
// TODO: remove this restriction
|
||||
static_assert(BM0 == 2, "wrong");
|
||||
static_assert(BM0 == 2 && BN0 == 2, "wrong");
|
||||
}
|
||||
|
||||
@@ -226,7 +225,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
|
||||
|
||||
constexpr auto threadwise_contraction =
|
||||
ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
|
||||
ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
|
||||
FloatA,
|
||||
FloatB,
|
||||
FloatC,
|
||||
@@ -407,4 +406,3 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -3,10 +3,26 @@
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "xdlops_gemm.hpp"
|
||||
#include "tensor_adaptor.hpp"
|
||||
#include "thread_group.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename ThreadGroup,
|
||||
// Loop scheduling strategies a blockwise GEMM can be built with.
enum struct LoopScheduler
{
    Default,
    Interwave,
};

// Returns the scheduler selected by the CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING build switch:
// Interwave when the switch is non-zero, Default otherwise.
constexpr LoopScheduler make_default_loop_scheduler()
{
#if !CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
    constexpr auto scheduler = LoopScheduler::Default;
#else
    constexpr auto scheduler = LoopScheduler::Interwave;
#endif // if !CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING

    return scheduler;
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename AK0MK1BlockDesc,
|
||||
@@ -23,6 +39,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
|
||||
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
|
||||
@@ -53,7 +71,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
__device__ static auto GetWaveIdx()
|
||||
{
|
||||
const index_t thread_id = ThreadGroup::GetThreadId();
|
||||
const index_t thread_id = ThisThreadBlock::GetThreadId();
|
||||
|
||||
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
|
||||
@@ -120,8 +138,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
BK0NK1BlockDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(ThreadGroup::GetNumOfThread() == MWaves * NWaves * WaveSize,
|
||||
"ThreadGroup::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
|
||||
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
|
||||
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
|
||||
|
||||
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
|
||||
"wrong!");
|
||||
@@ -299,7 +317,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
protected:
|
||||
// A[M0, M1, M2, KPerThread]
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
|
||||
@@ -336,4 +354,232 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
|
||||
};
|
||||
|
||||
// Note: To facilitate the inter-wave loop scheduler, we need to explicitly set the macro
|
||||
// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 as a few intrinsics are not yet available in
|
||||
// the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the
|
||||
// default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename AK0MK1BlockDesc,
|
||||
typename BK0NK1BlockDesc,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
|
||||
struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
: public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
AK0MK1BlockDesc,
|
||||
BK0NK1BlockDesc,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
AK0MK1BlockDesc,
|
||||
BK0NK1BlockDesc,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
|
||||
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::A_K1;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
using Base::B_K1;
|
||||
using Base::c_thread_buf_;
|
||||
using Base::c_thread_desc_;
|
||||
using Base::CalculateAThreadOriginDataIndex;
|
||||
using Base::CalculateBThreadOriginDataIndex;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KPerThread;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
|
||||
|
||||
// 2-wave optimized blockwise gemm
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, k),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, k),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, but except
|
||||
// the first, as we can shorten non-MAC cluster a bit and there's no observable negative
|
||||
// impact. The desired effect is waves in a workgroup executing MAC in sync. This avoids
|
||||
// some out-of-sync waves hijacking MAC resource from other workgroups and reducing the
|
||||
// chance of latency hiding by waiting for the rest of the workgroup at the eventual
|
||||
// sync point.
|
||||
if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
|
||||
{
|
||||
asm volatile("s_barrier" ::);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<FloatAB, KPack> a_thread_vec;
|
||||
vector_type<FloatAB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatAB>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, 0, 0, k_ + i))>{}];
|
||||
b_thread_vec.template AsType<FloatAB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, 0, 0, k_ + i))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// The block_sync_lds() here performs double duty:
|
||||
// A) safeguard against data hazard because barrier from blockwise_gemm is
|
||||
// moved here B) reduce VMEM FIFO congestion by applying small delays to
|
||||
// different wavefronts It is performed near the end of MAC cluster to
|
||||
// minimize lgkmcnt penalty
|
||||
if constexpr(k.value == KPerThread - KPerInnerLoop &&
|
||||
k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 &&
|
||||
n0.value == NRepeat - 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
block_sync_lds();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// TODO: insert setprio in more precise manner since we
|
||||
// could have more than >1 MFMA instructions in single call
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(1);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
|
||||
protected:
|
||||
// A[M0, M1, M2, KPerInnerLoop]
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
|
||||
|
||||
// B[N0, N1, N2, KPerInnerLoop]
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
|
||||
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
|
||||
|
||||
#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename AK0MK1BlockDesc,
|
||||
typename BK0NK1BlockDesc,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
LoopScheduler LoopSched>
|
||||
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
|
||||
{
|
||||
if constexpr(LoopSched == LoopScheduler::Default)
|
||||
{
|
||||
return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
AK0MK1BlockDesc,
|
||||
BK0NK1BlockDesc,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else if constexpr(LoopSched == LoopScheduler::Interwave)
|
||||
{
|
||||
return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
AK0MK1BlockDesc,
|
||||
BK0NK1BlockDesc,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "cluster_descriptor.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v3r1.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// this version does the following things to avoid scratch memory issues
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
template <index_t BlockSize,
|
||||
typename SrcElementwiseOperation,
|
||||
typename DstElementwiseOperation,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
typename BlockSliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct BlockwiseTensorSliceTransfer_v4r1
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
|
||||
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseTensorSliceTransfer_v4r1(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const SrcElementwiseOperation& src_element_op,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const DstElementwiseOperation& dst_element_op)
|
||||
: threadwise_transfer_(src_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
src_element_op,
|
||||
dst_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
dst_element_op)
|
||||
|
||||
{
|
||||
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
|
||||
nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(
|
||||
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! BlockSize too small");
|
||||
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
|
||||
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
|
||||
src_block_slice_origin + thread_data_idx_begin);
|
||||
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
|
||||
dst_block_slice_origin + thread_data_idx_begin);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstBuffer, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf,
|
||||
Number<ThreadScratchId> thread_scratch_id)
|
||||
{
|
||||
RunRead(src_desc, src_buf, thread_scratch_id);
|
||||
RunWrite(dst_desc, dst_buf, thread_scratch_id);
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths),
|
||||
SrcElementwiseOperation,
|
||||
DstElementwiseOperation,
|
||||
DstInMemOp,
|
||||
SrcData,
|
||||
DstData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorDim,
|
||||
DstVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstScalarPerVector,
|
||||
SrcScalarStrideInVector,
|
||||
DstScalarStrideInVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun,
|
||||
NumThreadScratch>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -75,14 +75,13 @@ struct BlockwiseTensorSliceTransfer_v5r1
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename SrcStepHacks>
|
||||
__device__ void
|
||||
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||
template <typename SrcBuffer>
|
||||
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "cluster_descriptor.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v6r1.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// this version does the following things to avoid scratch memory issues
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
template <index_t BlockSize,
|
||||
typename ElementwiseOperation,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
typename BlockSliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename DimAccessOrder,
|
||||
index_t VectorDim,
|
||||
index_t ScalarPerVector,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
struct BlockwiseTensorSliceTransfer_v6r1
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
|
||||
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const ElementwiseOperation& element_op)
|
||||
: threadwise_transfer_(src_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
dst_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
element_op)
|
||||
|
||||
{
|
||||
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
|
||||
nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == DimAccessOrder::Size(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(
|
||||
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! BlockSize too small");
|
||||
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
|
||||
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
|
||||
src_block_slice_origin + thread_data_idx_begin);
|
||||
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
|
||||
dst_block_slice_origin + thread_data_idx_begin);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename DstBuffer>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseTensorSliceTransfer_v6r1<SrcData,
|
||||
DstData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
ElementwiseOperation,
|
||||
decltype(thread_slice_lengths),
|
||||
DimAccessOrder,
|
||||
VectorDim,
|
||||
ScalarPerVector,
|
||||
DstInMemOp,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -45,7 +45,9 @@ template <typename AccDataType,
|
||||
typename ThreadClusterLengths_M_K,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename OpReduce,
|
||||
bool PropagateNan>
|
||||
bool PropagateNan,
|
||||
typename Accumulation =
|
||||
detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
|
||||
struct PartitionedBlockwiseReduction
|
||||
{
|
||||
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
|
||||
@@ -62,8 +64,6 @@ struct PartitionedBlockwiseReduction
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
|
||||
|
||||
template <typename BufferType>
|
||||
__device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
|
||||
{
|
||||
@@ -113,13 +113,16 @@ struct PartitionedBlockwiseReduction
|
||||
// 3) in_out_value/in_out_index is the input data in vgpr from each thread
|
||||
// 4) in_out_value/in_out_index is the over-written reduced output in vgpr for each thread
|
||||
// clang-format on
|
||||
template <typename AccDataType,
|
||||
typename IndexDataType,
|
||||
index_t BlockSize,
|
||||
typename ThreadClusterLengths_M_K,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename OpReduce,
|
||||
bool PropagateNan>
|
||||
template <
|
||||
typename AccDataType,
|
||||
typename IndexDataType,
|
||||
index_t BlockSize,
|
||||
typename ThreadClusterLengths_M_K,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename OpReduce,
|
||||
bool PropagateNan,
|
||||
typename Accumulation =
|
||||
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
|
||||
struct PartitionedBlockwiseReductionWithIndex
|
||||
{
|
||||
static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
|
||||
@@ -136,9 +139,6 @@ struct PartitionedBlockwiseReductionWithIndex
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using Accumulation =
|
||||
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
|
||||
|
||||
// This interface accumulates on both data values and indices
|
||||
template <typename BufferType, typename IdxBufferType>
|
||||
__device__ static void Reduce(BufferType& work_val_buffer,
|
||||
|
||||
@@ -14,7 +14,7 @@ namespace ck {
|
||||
template <typename ThreadGroup,
|
||||
typename ElementwiseOperation,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
typename BlockSliceLengths,
|
||||
typename SliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcData,
|
||||
@@ -30,7 +30,7 @@ struct ThreadGroupTensorSliceTransfer_v6r1
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
|
||||
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
|
||||
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
@@ -54,7 +54,7 @@ struct ThreadGroupTensorSliceTransfer_v6r1
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(
|
||||
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
|
||||
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
@@ -13,10 +11,10 @@ namespace ck {
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. It does not keep reference to tensor descriptor
|
||||
// 3. Run() does not construct new tensor coordinate
|
||||
template <index_t BlockSize,
|
||||
template <typename ThreadGroup,
|
||||
typename ElementwiseOperation,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
typename BlockSliceLengths,
|
||||
typename SliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename Src0Data,
|
||||
@@ -31,21 +29,21 @@ template <index_t BlockSize,
|
||||
bool ThreadTransferSrc0ResetCoordinateAfterRun,
|
||||
bool ThreadTransferSrc1ResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
struct BlockwiseTensorSliceTransfer_v6r2
|
||||
struct ThreadGroupTensorSliceTransfer_v6r2
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
|
||||
|
||||
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
|
||||
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
|
||||
const Index& src0_block_slice_origin,
|
||||
const Src1Desc& src1_desc,
|
||||
const Index& src1_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const ElementwiseOperation& element_op)
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
|
||||
const Index& src0_block_slice_origin,
|
||||
const Src1Desc& src1_desc,
|
||||
const Index& src1_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const ElementwiseOperation& element_op)
|
||||
: threadwise_transfer_(src0_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
src1_desc,
|
||||
@@ -64,17 +62,17 @@ struct BlockwiseTensorSliceTransfer_v6r2
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(
|
||||
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! BlockSize too small");
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! ThreadGroup::GetNumOfThread() too small");
|
||||
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
|
||||
@@ -95,8 +93,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
|
||||
}
|
||||
@@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
|
||||
|
||||
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
|
||||
}
|
||||
@@ -113,8 +111,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
|
||||
|
||||
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
|
||||
}
|
||||
@@ -122,8 +120,8 @@ struct BlockwiseTensorSliceTransfer_v6r2
|
||||
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
|
||||
}
|
||||
@@ -154,4 +152,3 @@ struct BlockwiseTensorSliceTransfer_v6r2
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,6 +1,4 @@
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
|
||||
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
@@ -13,10 +11,10 @@ namespace ck {
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
template <index_t BlockSize,
|
||||
template <typename ThreadGroup,
|
||||
typename ElementwiseOperation,
|
||||
InMemoryDataOperationEnum DstInMemOp,
|
||||
typename BlockSliceLengths,
|
||||
typename SliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename Src0Data,
|
||||
@@ -34,23 +32,23 @@ template <index_t BlockSize,
|
||||
bool ThreadTransferSrc1ResetCoordinateAfterRun,
|
||||
bool ThreadTransferSrc2ResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
struct BlockwiseTensorSliceTransfer_v6r3
|
||||
struct ThreadGroupTensorSliceTransfer_v6r3
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
|
||||
|
||||
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
|
||||
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
|
||||
const Index& src0_block_slice_origin,
|
||||
const Src1Desc& src1_desc,
|
||||
const Index& src1_block_slice_origin,
|
||||
const Src2Desc& src2_desc,
|
||||
const Index& src2_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const ElementwiseOperation& element_op)
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
|
||||
const Index& src0_block_slice_origin,
|
||||
const Src1Desc& src1_desc,
|
||||
const Index& src1_block_slice_origin,
|
||||
const Src2Desc& src2_desc,
|
||||
const Index& src2_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const ElementwiseOperation& element_op)
|
||||
: threadwise_transfer_(src0_desc,
|
||||
make_zero_multi_index<nDim>(),
|
||||
src1_desc,
|
||||
@@ -72,14 +70,14 @@ struct BlockwiseTensorSliceTransfer_v6r3
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(
|
||||
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! BlockSize too small");
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! ThreadGroup::GetNumOfThread() too small");
|
||||
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
@@ -107,8 +105,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.Run(
|
||||
src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
|
||||
@@ -117,8 +115,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
|
||||
|
||||
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
|
||||
}
|
||||
@@ -126,8 +124,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
|
||||
|
||||
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
|
||||
}
|
||||
@@ -135,8 +133,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
|
||||
|
||||
__device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
|
||||
}
|
||||
@@ -144,8 +142,8 @@ struct BlockwiseTensorSliceTransfer_v6r3
|
||||
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
|
||||
}
|
||||
@@ -179,4 +177,3 @@ struct BlockwiseTensorSliceTransfer_v6r3
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,169 @@
|
||||
#pragma once
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "cluster_descriptor.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v7.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Thread-group level multi-source, multi-destination tensor slice data movement
|
||||
// Assume:
|
||||
// 1. All sources and destinations are DynamicBuffer
|
||||
// 2. Same VectorDim and ScalerPerVector for all sources and destinations
|
||||
// 3. DstInMemOps are per destination tensor
|
||||
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
|
||||
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
|
||||
//
|
||||
// Does following things to avoid scratch memory issue
|
||||
// 1. Pass tensor descritpors by reference (or tuple of references)
|
||||
// 2. Does not keep reference to tensor descriptor
|
||||
// 3. Does not construct new tensor coordinate when call Run()
|
||||
template <typename ThreadGroup,
|
||||
typename SrcDatas,
|
||||
typename DstDatas,
|
||||
typename SrcDescs,
|
||||
typename DstDescs,
|
||||
typename ElementwiseOperation,
|
||||
typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
|
||||
typename SliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename DimAccessOrder,
|
||||
index_t VectorDim,
|
||||
index_t ScalarPerVector,
|
||||
typename ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
typename ThreadTransferDstResetCoordinateAfterRunFlags>
|
||||
struct ThreadGroupTensorSliceTransfer_v7
|
||||
{
|
||||
static constexpr index_t nDim =
|
||||
remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
|
||||
|
||||
static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
|
||||
static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
|
||||
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v7(
|
||||
const SrcDescs& src_descs,
|
||||
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
|
||||
const DstDescs& dst_descs,
|
||||
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
|
||||
const ElementwiseOperation& element_op)
|
||||
: threadwise_transfer_(src_descs,
|
||||
StaticallyIndexedArray<Index, nSrc>{},
|
||||
dst_descs,
|
||||
StaticallyIndexedArray<Index, nDst>{},
|
||||
element_op)
|
||||
{
|
||||
static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
|
||||
nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
|
||||
nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
|
||||
nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
|
||||
"wrong!");
|
||||
|
||||
static_for<0, nSrc, 1>{}([&](auto i) {
|
||||
static_assert(
|
||||
nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
|
||||
"wrong!");
|
||||
});
|
||||
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
static_assert(
|
||||
nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
|
||||
"wrong!");
|
||||
});
|
||||
|
||||
static_assert(nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == DimAccessOrder::Size(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(
|
||||
is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! ThreadGroup::GetNumOfThread() too small");
|
||||
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
|
||||
|
||||
const auto src_thread_slice_origins = generate_tuple(
|
||||
[&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
|
||||
Number<nSrc>{});
|
||||
|
||||
const auto dst_thread_slice_origins = generate_tuple(
|
||||
[&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
|
||||
Number<nDst>{});
|
||||
|
||||
threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
|
||||
threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffers, typename DstBuffers>
|
||||
__device__ void Run(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
const DstDescs& dst_descs,
|
||||
DstBuffers dst_bufs)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.Run(src_descs, src_bufs, dst_descs, dst_bufs);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t ISrc>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t IDst>
|
||||
__device__ void
|
||||
MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
|
||||
{
|
||||
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
|
||||
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseTensorSliceTransfer_v7<SrcDatas,
|
||||
DstDatas,
|
||||
SrcDescs,
|
||||
DstDescs,
|
||||
ElementwiseOperation,
|
||||
DstInMemOps,
|
||||
decltype(thread_slice_lengths),
|
||||
DimAccessOrder,
|
||||
VectorDim,
|
||||
ScalarPerVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRunFlags,
|
||||
ThreadTransferDstResetCoordinateAfterRunFlags>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,17 @@
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Specialization tags for convolution backward-weight.
// NOTE(review): what each tag selects is not visible here; the enumerator names suggest
// filter-size / stride / padding / channel-count special cases -- confirm against the
// device operations that consume this enum.
enum struct ConvolutionBackwardWeightSpecialization
|
||||
{
|
||||
Default,
|
||||
Filter1x1Stride1Pad0,
|
||||
Filter1x1Pad0,
|
||||
OddC,
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,332 @@
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "gridwise_5ary_Elementwise_1d.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename DDataType,
|
||||
typename EDataType,
|
||||
typename FDataType,
|
||||
typename ComputeDataType,
|
||||
typename ElementwiseFunctor,
|
||||
index_t NDim,
|
||||
index_t MPerThread,
|
||||
index_t AScalarPerVector,
|
||||
index_t BScalarPerVector,
|
||||
index_t CScalarPerVector,
|
||||
index_t DScalarPerVector,
|
||||
index_t EScalarPerVector,
|
||||
index_t FScalarPerVector>
|
||||
struct Device5AryElementwise : public BaseOperator
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
template <typename Desc_M>
|
||||
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
|
||||
{
|
||||
const auto m = desc_m.GetLength(I0);
|
||||
const index_t loop_step = gridSize * blockSize * MPerThread;
|
||||
const auto pad = math::integer_least_multiple(m, loop_step) - m;
|
||||
const auto desc_m_pad =
|
||||
transform_tensor_descriptor(desc_m,
|
||||
make_tuple(make_right_pad_transform(m, pad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return desc_m_pad;
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
|
||||
const std::vector<index_t>& stride,
|
||||
index_t gridSize,
|
||||
index_t blockSize)
|
||||
{
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NDim>{});
|
||||
|
||||
// nd desc - [s0, s1, s2, ...]
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
|
||||
// merge nd to 1d desc - [s0 * s1 * ...]
|
||||
if constexpr(NDim > 1)
|
||||
{
|
||||
const auto desc_m = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(tupleOfShape)),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NDim>{})),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
|
||||
}
|
||||
else
|
||||
return PadDescriptor_M_1d(desc, gridSize, blockSize);
|
||||
}
|
||||
|
||||
using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
|
||||
using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
|
||||
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
|
||||
using DGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
|
||||
using EGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
|
||||
using FGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
|
||||
|
||||
using Gridwise5AryEltwise = Gridwise5AryElementwise_1D<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
DDataType,
|
||||
EDataType,
|
||||
FDataType,
|
||||
ComputeDataType,
|
||||
AGridDesc_M,
|
||||
BGridDesc_M,
|
||||
CGridDesc_M,
|
||||
DGridDesc_M,
|
||||
EGridDesc_M,
|
||||
FGridDesc_M,
|
||||
ElementwiseFunctor,
|
||||
MPerThread,
|
||||
AScalarPerVector,
|
||||
BScalarPerVector,
|
||||
CScalarPerVector,
|
||||
DScalarPerVector,
|
||||
EScalarPerVector,
|
||||
FScalarPerVector>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
const CDataType* p_c,
|
||||
const DDataType* p_d,
|
||||
const EDataType* p_e,
|
||||
FDataType* p_f,
|
||||
const std::vector<index_t>& lengths,
|
||||
const std::vector<index_t>& a_strides,
|
||||
const std::vector<index_t>& b_strides,
|
||||
const std::vector<index_t>& c_strides,
|
||||
const std::vector<index_t>& d_strides,
|
||||
const std::vector<index_t>& e_strides,
|
||||
const std::vector<index_t>& f_strides,
|
||||
ElementwiseFunctor functor)
|
||||
: p_a_(p_a),
|
||||
p_b_(p_b),
|
||||
p_c_(p_c),
|
||||
p_d_(p_d),
|
||||
p_e_(p_e),
|
||||
p_f_(p_f),
|
||||
lengths_(lengths),
|
||||
a_strides_(a_strides),
|
||||
b_strides_(b_strides),
|
||||
c_strides_(c_strides),
|
||||
d_strides_(d_strides),
|
||||
e_strides_(e_strides),
|
||||
f_strides_(f_strides),
|
||||
functor_(functor),
|
||||
blockSize_(256),
|
||||
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
|
||||
{
|
||||
a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
|
||||
b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
|
||||
c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
|
||||
d_grid_desc_m_ = MakeDescriptor_M(lengths, d_strides, gridSize_, blockSize_);
|
||||
e_grid_desc_m_ = MakeDescriptor_M(lengths, e_strides, gridSize_, blockSize_);
|
||||
f_grid_desc_m_ = MakeDescriptor_M(lengths, f_strides, gridSize_, blockSize_);
|
||||
}
|
||||
|
||||
const ADataType* p_a_;
|
||||
const BDataType* p_b_;
|
||||
const CDataType* p_c_;
|
||||
const DDataType* p_d_;
|
||||
const EDataType* p_e_;
|
||||
FDataType* p_f_;
|
||||
std::vector<index_t> lengths_;
|
||||
AGridDesc_M a_grid_desc_m_;
|
||||
BGridDesc_M b_grid_desc_m_;
|
||||
CGridDesc_M c_grid_desc_m_;
|
||||
DGridDesc_M d_grid_desc_m_;
|
||||
EGridDesc_M e_grid_desc_m_;
|
||||
FGridDesc_M f_grid_desc_m_;
|
||||
std::vector<index_t> a_strides_;
|
||||
std::vector<index_t> b_strides_;
|
||||
std::vector<index_t> c_strides_;
|
||||
std::vector<index_t> d_strides_;
|
||||
std::vector<index_t> e_strides_;
|
||||
std::vector<index_t> f_strides_;
|
||||
ElementwiseFunctor functor_;
|
||||
index_t blockSize_;
|
||||
index_t gridSize_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto kernel = kernel_5ary_elementwise_1d<Gridwise5AryEltwise,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
DDataType,
|
||||
EDataType,
|
||||
FDataType,
|
||||
AGridDesc_M,
|
||||
BGridDesc_M,
|
||||
CGridDesc_M,
|
||||
DGridDesc_M,
|
||||
EGridDesc_M,
|
||||
FGridDesc_M,
|
||||
ElementwiseFunctor>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(arg.blockSize_),
|
||||
0,
|
||||
arg.p_a_,
|
||||
arg.p_b_,
|
||||
arg.p_c_,
|
||||
arg.p_d_,
|
||||
arg.p_e_,
|
||||
arg.p_f_,
|
||||
arg.a_grid_desc_m_,
|
||||
arg.b_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.d_grid_desc_m_,
|
||||
arg.e_grid_desc_m_,
|
||||
arg.f_grid_desc_m_,
|
||||
arg.functor_);
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument& p_arg) { return IsSupportedArgument(&p_arg); }
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if(pArg == nullptr)
|
||||
return false;
|
||||
|
||||
if(pArg->lengths_.size() != NDim)
|
||||
return false;
|
||||
|
||||
if(pArg->lengths_.back() % MPerThread != 0)
|
||||
return false;
|
||||
|
||||
auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
|
||||
bool ret = true;
|
||||
|
||||
if(!isLastDimensionCoalesced)
|
||||
ret = scalarPerVector == 1;
|
||||
else
|
||||
ret = MPerThread % scalarPerVector == 0;
|
||||
|
||||
return ret;
|
||||
};
|
||||
|
||||
if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
|
||||
return false;
|
||||
|
||||
if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
|
||||
return false;
|
||||
|
||||
if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
|
||||
return false;
|
||||
|
||||
if(!IsScalarPerVectorValid(pArg->d_strides_.back() == 1, DScalarPerVector))
|
||||
return false;
|
||||
|
||||
if(!IsScalarPerVectorValid(pArg->e_strides_.back() == 1, EScalarPerVector))
|
||||
return false;
|
||||
|
||||
if(!IsScalarPerVectorValid(pArg->f_strides_.back() == 1, FScalarPerVector))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
const CDataType* p_c,
|
||||
const DDataType* p_d,
|
||||
const EDataType* p_e,
|
||||
FDataType* p_f,
|
||||
std::vector<index_t> lengths,
|
||||
std::vector<index_t> a_strides,
|
||||
std::vector<index_t> b_strides,
|
||||
std::vector<index_t> c_strides,
|
||||
std::vector<index_t> d_strides,
|
||||
std::vector<index_t> e_strides,
|
||||
std::vector<index_t> f_strides,
|
||||
ElementwiseFunctor functor)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_d,
|
||||
p_e,
|
||||
p_f,
|
||||
lengths,
|
||||
a_strides,
|
||||
b_strides,
|
||||
c_strides,
|
||||
d_strides,
|
||||
e_strides,
|
||||
f_strides,
|
||||
functor};
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
const void* p_c,
|
||||
const void* p_d,
|
||||
const void* p_e,
|
||||
void* p_f,
|
||||
std::vector<index_t> lengths,
|
||||
std::vector<index_t> a_strides,
|
||||
std::vector<index_t> b_strides,
|
||||
std::vector<index_t> c_strides,
|
||||
std::vector<index_t> d_strides,
|
||||
std::vector<index_t> e_strides,
|
||||
std::vector<index_t> f_strides,
|
||||
ElementwiseFunctor functor)
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<const CDataType*>(p_c),
|
||||
static_cast<const DDataType*>(p_d),
|
||||
static_cast<const EDataType*>(p_e),
|
||||
static_cast<FDataType*>(p_f),
|
||||
lengths,
|
||||
a_strides,
|
||||
b_strides,
|
||||
c_strides,
|
||||
d_strides,
|
||||
e_strides,
|
||||
f_strides,
|
||||
functor);
|
||||
}
|
||||
|
||||
// Returns a value-initialized Invoker by value.
static auto MakeInvoker()
{
    return Invoker();
}
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
|
||||
}; // end of the enclosing device operation struct
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,8 +1,9 @@
|
||||
#ifndef DEVICE_BASE_HPP
|
||||
#define DEVICE_BASE_HPP
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "stream_config.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -14,6 +15,8 @@ struct BaseArgument
|
||||
BaseArgument& operator=(const BaseArgument&) = default;
|
||||
|
||||
virtual ~BaseArgument() {}
|
||||
|
||||
void* p_workspace_ = nullptr;
|
||||
};
|
||||
|
||||
struct BaseInvoker
|
||||
@@ -22,7 +25,10 @@ struct BaseInvoker
|
||||
BaseInvoker(const BaseInvoker&) = default;
|
||||
BaseInvoker& operator=(const BaseInvoker&) = default;
|
||||
|
||||
virtual float Run(const BaseArgument*, int = 1) = 0;
|
||||
virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
|
||||
{
|
||||
return float{0};
|
||||
}
|
||||
|
||||
virtual ~BaseInvoker() {}
|
||||
};
|
||||
@@ -33,8 +39,16 @@ struct BaseOperator
|
||||
BaseOperator(const BaseOperator&) = default;
|
||||
BaseOperator& operator=(const BaseOperator&) = default;
|
||||
|
||||
virtual bool IsSupportedArgument(const BaseArgument*) = 0;
|
||||
virtual std::string GetTypeString() const = 0;
|
||||
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
|
||||
virtual std::string GetTypeString() const { return ""; }
|
||||
|
||||
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
|
||||
|
||||
virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
|
||||
{
|
||||
assert(p_arg);
|
||||
p_arg->p_workspace_ = p_workspace;
|
||||
}
|
||||
|
||||
virtual ~BaseOperator() {}
|
||||
};
|
||||
@@ -42,4 +56,3 @@ struct BaseOperator
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -17,11 +17,12 @@ namespace device {
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename FloatD,
|
||||
typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -37,13 +38,13 @@ __global__ void
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatD* __restrict__ p_d0_grid,
|
||||
FloatD* __restrict__ p_d1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const D1ElementwiseOperation d1_element_op,
|
||||
const DxsInElementwiseOperation dxs_in_element_op,
|
||||
const DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -64,23 +65,24 @@ __global__ void
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
|
||||
|
||||
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
|
||||
const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetD1BasePtr(g_idx)));
|
||||
static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
|
||||
const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In)));
|
||||
p_ds_grid(In) = p_ds_grid(In) + d_batch_offset;
|
||||
});
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_c_grid + c_batch_offset,
|
||||
p_d0_grid + d0_batch_offset,
|
||||
p_d1_grid + d1_batch_offset,
|
||||
p_ds_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
@@ -90,13 +92,13 @@ __global__ void
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = p_d0_grid;
|
||||
ignore = p_d1_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = batch_count;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = d1_element_op;
|
||||
ignore = dxs_in_element_op;
|
||||
ignore = dxs_out_element_op;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
@@ -106,6 +108,9 @@ __global__ void
|
||||
#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Because non c-shuffle
|
||||
// version currently has compiler issues with register spill which further causes validation
|
||||
// failures.
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
@@ -115,13 +120,14 @@ template <typename ALayout,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename ReduceAccDataType,
|
||||
typename DDataType,
|
||||
typename DPtrsGlobal,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename D0ReduceOperation,
|
||||
typename D1ReduceOperation,
|
||||
typename D1ElementwiseOperation,
|
||||
typename DxsReduceOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation,
|
||||
typename DGlobalMemoryDataOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
@@ -154,11 +160,14 @@ template <typename ALayout,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
|
||||
index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
|
||||
struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation>
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
: public DeviceGemmReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
|
||||
|
||||
@@ -461,56 +470,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
|
||||
|
||||
static constexpr auto MakeBlock2CTileMap(index_t batch_count,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
index_t M01,
|
||||
index_t N01)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto M00 = M0 / M01;
|
||||
const auto N00 = N0 / N01;
|
||||
|
||||
const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_insert_transform(batch_count),
|
||||
make_unmerge_transform(make_tuple(M00, M01)),
|
||||
make_unmerge_transform(make_tuple(N00, N01))),
|
||||
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
|
||||
|
||||
const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(batch_count, M00, N00, M01, N01))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto globalblockid_to_m0_n0_block_cluster_adaptor =
|
||||
chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
|
||||
globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
|
||||
|
||||
return globalblockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
index_t BatchStrideC,
|
||||
index_t BatchStrideD0,
|
||||
index_t BatchStrideD1)
|
||||
index_t BatchStrideD)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB_(BatchStrideB),
|
||||
BatchStrideC_(BatchStrideC),
|
||||
BatchStrideD0_(BatchStrideD0),
|
||||
BatchStrideD1_(BatchStrideD1)
|
||||
BatchStrideD_(BatchStrideD)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -529,22 +498,20 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx,
|
||||
Number<I> reduction_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD0_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD1_);
|
||||
// TODO - Support sequence of StrideD in MakeArgument()
|
||||
(void)reduction_idx;
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD_);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
index_t BatchStrideC_;
|
||||
index_t BatchStrideD0_;
|
||||
index_t BatchStrideD1_;
|
||||
index_t BatchStrideD_;
|
||||
};
|
||||
|
||||
// GridwiseGemm
|
||||
@@ -554,15 +521,15 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
ReduceAccDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D0ReduceOperation,
|
||||
D1ReduceOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsReduceOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
DGlobalMemoryDataOperation,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
@@ -600,9 +567,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
CReduceThreadClusterLengths_MPerBlock_NPerBlock,
|
||||
CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
|
||||
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>;
|
||||
|
||||
using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1));
|
||||
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -610,8 +576,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
Argument(const ADataType* p_a_grid,
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
DDataType* p_d0_grid,
|
||||
DDataType* p_d1_grid,
|
||||
DPtrsGlobal p_ds_grid,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -621,13 +586,13 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
index_t BatchCount)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
p_d0_grid_{p_d0_grid},
|
||||
p_d1_grid_{p_d1_grid},
|
||||
p_ds_grid_{p_ds_grid},
|
||||
BatchCount_(BatchCount),
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
|
||||
@@ -635,19 +600,22 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
d_grid_desc_mblock_mperblock_{},
|
||||
compute_base_ptr_of_batch_{a_grid_desc_ak0_m_ak1_.GetElementSpaceSize(),
|
||||
b_grid_desc_bk0_n_bk1_.GetElementSpaceSize(),
|
||||
c_grid_desc_m_n_.GetElementSpaceSize(),
|
||||
d_grid_desc_m_.GetElementSpaceSize(),
|
||||
d_grid_desc_m_.GetElementSpaceSize()},
|
||||
block_2_ctile_map_{},
|
||||
compute_base_ptr_of_batch_{
|
||||
type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
|
||||
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
d1_element_op_{d1_element_op}
|
||||
dxs_in_element_op_{dxs_in_element_op},
|
||||
dxs_out_element_op_{dxs_out_element_op}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
b_grid_desc_bk0_n_bk1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -655,8 +623,6 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
|
||||
d_grid_desc_mblock_mperblock_ =
|
||||
GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
|
||||
|
||||
block_2_ctile_map_ = MakeBlock2CTileMap(BatchCount, c_grid_desc_m_n_, 1, 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -664,8 +630,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
DDataType* p_d0_grid_;
|
||||
DDataType* p_d1_grid_;
|
||||
DPtrsGlobal p_ds_grid_;
|
||||
index_t BatchCount_;
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
@@ -675,11 +640,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
D1ElementwiseOperation d1_element_op_;
|
||||
DxsInElementwiseOperation dxs_in_element_op_;
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -687,7 +653,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int /* nrepeat */ = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
@@ -711,57 +677,63 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
}
|
||||
#endif
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
float elapsed_time = 0.0f;
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
CDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
|
||||
ComputeBasePtrOfStridedBatch,
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
true>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.compute_base_ptr_of_batch_,
|
||||
arg.block_2_ctile_map_);
|
||||
elapsed_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.compute_base_ptr_of_batch_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -769,48 +741,52 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
CDataType,
|
||||
DDataType,
|
||||
DPtrsGlobal,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
D1ElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
|
||||
ComputeBasePtrOfStridedBatch,
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
false>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_d0_grid_,
|
||||
arg.p_d1_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.d1_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.compute_base_ptr_of_batch_,
|
||||
arg.block_2_ctile_map_);
|
||||
elapsed_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.dxs_in_element_op_,
|
||||
arg.dxs_out_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.d_grid_desc_mblock_mperblock_,
|
||||
arg.compute_base_ptr_of_batch_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return 0;
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -822,8 +798,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_);
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -843,8 +821,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
DDataType* p_d0,
|
||||
DDataType* p_d1,
|
||||
DPtrsGlobal p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
@@ -854,14 +831,14 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
index_t BatchCount)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_d0,
|
||||
p_d1,
|
||||
p_dxs,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
@@ -871,35 +848,37 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
BatchCount};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
void* p_d0,
|
||||
void* p_d1,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
D1ElementwiseOperation d1_element_op,
|
||||
index_t BatchCount) override
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
void* p_dxs,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
index_t BatchCount) override
|
||||
{
|
||||
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
static_cast<DDataType*>(p_d0),
|
||||
static_cast<DDataType*>(p_d1),
|
||||
dxs_tuple,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
@@ -909,7 +888,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
BatchCount);
|
||||
}
|
||||
|
||||
|
||||
@@ -243,44 +243,6 @@ struct DeviceBatchedGemmXdl
|
||||
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
|
||||
static constexpr auto MakeBlock2CTileMap(index_t batch_count,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n,
|
||||
index_t M01,
|
||||
index_t N01)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto M00 = M0 / M01;
|
||||
const auto N00 = N0 / N01;
|
||||
|
||||
const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_insert_transform(batch_count),
|
||||
make_unmerge_transform(make_tuple(M00, M01)),
|
||||
make_unmerge_transform(make_tuple(N00, N01))),
|
||||
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}));
|
||||
|
||||
const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(batch_count, M00, N00, M01, N01))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto globalblockid_to_m0_n0_block_cluster_adaptor =
|
||||
chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
|
||||
globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
|
||||
|
||||
return globalblockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
|
||||
@@ -354,7 +316,7 @@ struct DeviceBatchedGemmXdl
|
||||
|
||||
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
|
||||
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
|
||||
using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1));
|
||||
using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -384,23 +346,25 @@ struct DeviceBatchedGemmXdl
|
||||
DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
|
||||
c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
|
||||
compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
|
||||
b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
|
||||
c_grid_desc_m_n_.GetElementSpaceSize()},
|
||||
block_2_ctile_map_{},
|
||||
compute_ptr_offset_of_batch_{
|
||||
type_convert<index_t>(a_grid_desc_k0_m_k1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(b_grid_desc_k0_n_k1_.GetElementSpaceSize()),
|
||||
type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize())},
|
||||
block_2_ctile_map_{
|
||||
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
|
||||
b_grid_desc_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
|
||||
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
|
||||
|
||||
block_2_ctile_map_ = MakeBlock2CTileMap(BatchCount, c_grid_desc_m_n_, M01, N01);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -427,7 +391,7 @@ struct DeviceBatchedGemmXdl
|
||||
{
|
||||
using Argument = DeviceBatchedGemmXdl::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
@@ -445,15 +409,14 @@ struct DeviceBatchedGemmXdl
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_))
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
@@ -476,8 +439,8 @@ struct DeviceBatchedGemmXdl
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -510,8 +473,8 @@ struct DeviceBatchedGemmXdl
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
ave_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
@@ -533,9 +496,10 @@ struct DeviceBatchedGemmXdl
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -550,8 +514,7 @@ struct DeviceBatchedGemmXdl
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_);
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "gridwise_binary_elementwise_1d.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename ComputeDataType,
|
||||
typename ElementwiseFunctor,
|
||||
index_t NDim,
|
||||
index_t MPerThread,
|
||||
index_t AScalarPerVector,
|
||||
index_t BScalarPerVector,
|
||||
index_t CScalarPerVector>
|
||||
struct DeviceBinaryElementwise : public BaseOperator
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
template <typename Desc_M>
|
||||
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
|
||||
{
|
||||
const auto M = desc_m.GetLength(I0);
|
||||
const index_t loop_step = gridSize * blockSize * MPerThread;
|
||||
const auto pad = math::integer_least_multiple(M, loop_step) - M;
|
||||
const auto desc_m_pad =
|
||||
transform_tensor_descriptor(desc_m,
|
||||
make_tuple(make_right_pad_transform(M, pad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return desc_m_pad;
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
|
||||
const std::vector<index_t>& strides,
|
||||
index_t gridSize,
|
||||
index_t blockSize)
|
||||
{
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<NDim>{});
|
||||
|
||||
// nd desc - [s0, s1, s2, ...]
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
|
||||
// merge nd to 1d desc - [s0 * s1 * ...]
|
||||
if constexpr(NDim > 1)
|
||||
{
|
||||
const auto desc_m = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(tupleOfShape)),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NDim>{})),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
|
||||
}
|
||||
else
|
||||
return PadDescriptor_M_1d(desc, gridSize, blockSize);
|
||||
}
|
||||
|
||||
using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
|
||||
using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
|
||||
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
|
||||
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
ComputeDataType,
|
||||
AGridDesc_M,
|
||||
BGridDesc_M,
|
||||
CGridDesc_M,
|
||||
ElementwiseFunctor,
|
||||
MPerThread,
|
||||
AScalarPerVector,
|
||||
BScalarPerVector,
|
||||
CScalarPerVector>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
const std::vector<index_t>& lengths,
|
||||
const std::vector<index_t>& a_strides,
|
||||
const std::vector<index_t>& b_strides,
|
||||
const std::vector<index_t>& c_strides,
|
||||
ElementwiseFunctor functor)
|
||||
: p_a_(p_a),
|
||||
p_b_(p_b),
|
||||
p_c_(p_c),
|
||||
lengths_(lengths),
|
||||
a_strides_(a_strides),
|
||||
b_strides_(b_strides),
|
||||
c_strides_(c_strides),
|
||||
functor_(functor),
|
||||
blockSize_(256),
|
||||
gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
|
||||
{
|
||||
a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
|
||||
b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
|
||||
c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
|
||||
}
|
||||
|
||||
const ADataType* p_a_;
|
||||
const BDataType* p_b_;
|
||||
CDataType* p_c_;
|
||||
std::vector<int> lengths_;
|
||||
AGridDesc_M a_grid_desc_m_;
|
||||
BGridDesc_M b_grid_desc_m_;
|
||||
CGridDesc_M c_grid_desc_m_;
|
||||
std::vector<index_t> a_strides_;
|
||||
std::vector<index_t> b_strides_;
|
||||
std::vector<index_t> c_strides_;
|
||||
ElementwiseFunctor functor_;
|
||||
index_t blockSize_;
|
||||
index_t gridSize_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto kernel = kernel_binary_elementwise_1d<GridwiseBinEltwise,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AGridDesc_M,
|
||||
BGridDesc_M,
|
||||
CGridDesc_M,
|
||||
ElementwiseFunctor>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.gridSize_),
|
||||
dim3(arg.blockSize_),
|
||||
0,
|
||||
arg.p_a_,
|
||||
arg.p_b_,
|
||||
arg.p_c_,
|
||||
arg.a_grid_desc_m_,
|
||||
arg.b_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.functor_);
|
||||
return elapsed_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if(pArg == nullptr)
|
||||
return false;
|
||||
|
||||
if(pArg->lengths_.size() != NDim)
|
||||
return false;
|
||||
|
||||
if(pArg->lengths_.back() % MPerThread != 0)
|
||||
return false;
|
||||
|
||||
auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
|
||||
bool ret = true;
|
||||
|
||||
if(!isLastDimensionCoalesced)
|
||||
ret = scalarPerVector == 1;
|
||||
else
|
||||
ret = MPerThread % scalarPerVector == 0;
|
||||
|
||||
return ret;
|
||||
};
|
||||
|
||||
if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
|
||||
return false;
|
||||
|
||||
if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
|
||||
return false;
|
||||
|
||||
if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
std::vector<index_t> lengths,
|
||||
std::vector<index_t> a_strides,
|
||||
std::vector<index_t> b_strides,
|
||||
std::vector<index_t> c_strides,
|
||||
ElementwiseFunctor functor)
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
lengths,
|
||||
a_strides,
|
||||
b_strides,
|
||||
c_strides,
|
||||
functor);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() { return std::make_unique<Invoker>(); }
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBinaryElementwise"
|
||||
<< "<"
|
||||
<< "MPerThread = " << MPerThread
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
73
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
Normal file
73
include/ck/tensor_operation/gpu/device/device_cgemm.hpp
Normal file
@@ -0,0 +1,73 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#pragma once
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceCGemm : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a_real,
|
||||
const void* p_a_imag,
|
||||
const void* p_b_real,
|
||||
const void* p_b_imag,
|
||||
void* p_c_real,
|
||||
void* p_c_imag,
|
||||
void* p_workspace,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t KBatch = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
virtual std::size_t GetWorkspaceSize(index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC) = 0;
|
||||
};
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
using DeviceCGemmPtr = std::unique_ptr<
|
||||
DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,972 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2022 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_gemm.hpp"
|
||||
#include "device_cgemm.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdl_cshuffle_v1.hpp"
|
||||
#include "binary_element_wise_operation.hpp"
|
||||
#include "gridwise_binary_elementwise_1d.hpp"
|
||||
#include "tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t NumGemmKPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
enable_if_t<
|
||||
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
|
||||
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
|
||||
bool> = false>
|
||||
struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
: public DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
{
|
||||
// Shorthand for this operator type.
using DeviceOp = DeviceCGemm_4Gemm_Xdl_CShuffle;
|
||||
|
||||
// Compile-time indices; I0 selects dimension 0 in GetLength(), I1 is used as a
// unit stride when building the naive descriptors below.
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
// Elements per thread used by PadDescriptor_M_1d() to compute the padding step
// (gridSize * blockSize * MPerThread).
static constexpr auto MPerThread = Number<4>{};
|
||||
// NOTE(review): presumably the vector widths of the element-wise pass that
// combines the partial GEMM results -- their use is not visible here; confirm.
static constexpr auto AScalarPerVector = Number<4>{};
|
||||
static constexpr auto BScalarPerVector = Number<4>{};
|
||||
static constexpr auto CScalarPerVector = Number<4>{};
|
||||
|
||||
template <typename Desc_M>
|
||||
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
|
||||
{
|
||||
const auto M = desc_m.GetLength(I0);
|
||||
const index_t loop_step = gridSize * blockSize * MPerThread;
|
||||
const auto pad = math::integer_least_multiple(M, loop_step) - M;
|
||||
const auto desc_m_pad =
|
||||
transform_tensor_descriptor(desc_m,
|
||||
make_tuple(make_right_pad_transform(M, pad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
return desc_m_pad;
|
||||
}
|
||||
|
||||
static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
|
||||
const std::vector<index_t>& strides,
|
||||
index_t gridSize,
|
||||
index_t blockSize)
|
||||
{
|
||||
auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<2>{});
|
||||
auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<2>{});
|
||||
|
||||
// nd desc - [s0, s1, s2, ...]
|
||||
const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
|
||||
const auto desc_m = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_merge_transform(tupleOfShape)),
|
||||
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<2>{})),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
|
||||
}
|
||||
|
||||
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both M and K
|
||||
assert(K % AK1 == 0);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
const auto a_grid_desc_m_k =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad M, but not K
|
||||
assert(KRaw % AK1 == 0);
|
||||
|
||||
const auto AK0 = KRaw / AK1;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_right_pad_transform(MRaw, MPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad K, but not M
|
||||
assert(K % AK1 == 0);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
const auto a_grid_desc_m_k = transform_tensor_descriptor(
|
||||
a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(MRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or K
|
||||
assert(KRaw % AK1 == 0);
|
||||
|
||||
const auto AK0 = KRaw / AK1;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(MRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;
|
||||
|
||||
const auto NPad = N - NRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad both N and K
|
||||
assert(K % BK1 == 0);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
const auto b_grid_desc_n_k =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(NRaw, NPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
// pad N, but not K
|
||||
assert(KRaw % BK1 == 0);
|
||||
|
||||
const auto BK0 = KRaw / BK1;
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad K, but not N
|
||||
assert(K % BK1 == 0);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
const auto b_grid_desc_n_k = transform_tensor_descriptor(
|
||||
b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad N or K
|
||||
assert(KRaw % BK1 == 0);
|
||||
|
||||
const auto BK0 = KRaw / BK1;
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
|
||||
{
|
||||
const auto c_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(StrideC, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(I1, StrideC));
|
||||
}
|
||||
}();
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
// pad M and N
|
||||
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
|
||||
GemmSpec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
// pad M, but not N
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
// pad N, but not M
|
||||
return transform_tensor_descriptor(
|
||||
c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or N
|
||||
return c_grid_desc_mraw_nraw;
|
||||
}
|
||||
}
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
|
||||
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
|
||||
|
||||
// GridwiseGemm
// Real-valued grid-wise GEMM that this operator launches several times per complex GEMM
// (see Invoker::Run). The template arguments are positional; they forward this device
// op's own template parameters together with the descriptor aliases declared above.
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false, // NOTE(review): presumably AThreadTransferSrcResetCoordinateAfterRun - confirm against the gridwise GEMM's parameter list
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false, // NOTE(review): presumably BThreadTransferSrcResetCoordinateAfterRun - confirm against the gridwise GEMM's parameter list
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const ADataType* p_a_grid_real,
|
||||
const ADataType* p_a_grid_imag,
|
||||
const BDataType* p_b_grid_real,
|
||||
const BDataType* p_b_grid_imag,
|
||||
CDataType* p_c_grid_real,
|
||||
CDataType* p_c_grid_imag,
|
||||
CDataType* p_workspace,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: p_a_grid_real_{p_a_grid_real},
|
||||
p_a_grid_imag_{p_a_grid_imag},
|
||||
p_b_grid_real_{p_b_grid_real},
|
||||
p_b_grid_imag_{p_b_grid_imag},
|
||||
p_c_grid_real_{p_c_grid_real},
|
||||
p_c_grid_imag_{p_c_grid_imag},
|
||||
p_aux_grid_{p_workspace},
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
|
||||
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
b_grid_desc_bk0_n_bk1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n_);
|
||||
}
|
||||
|
||||
const index_t grid_size = block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_);
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
c_grid_desc_m_ =
|
||||
DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize);
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
c_grid_desc_m_ =
|
||||
DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize);
|
||||
}
|
||||
|
||||
p_aux_2_grid_ = p_workspace + c_grid_desc_m_n_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_real_;
|
||||
const ADataType* p_a_grid_imag_;
|
||||
const BDataType* p_b_grid_real_;
|
||||
const BDataType* p_b_grid_imag_;
|
||||
CDataType* p_c_grid_real_;
|
||||
CDataType* p_c_grid_imag_;
|
||||
CDataType* p_aux_grid_;
|
||||
CDataType* p_aux_2_grid_;
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
CGridDesc_M c_grid_desc_m_;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
using Subtract = ck::tensor_operation::element_wise::Subtract;
|
||||
using GridwiseBinAdd = GridwiseBinaryElementwise_1D<CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
Add,
|
||||
MPerThread,
|
||||
AScalarPerVector,
|
||||
BScalarPerVector,
|
||||
CScalarPerVector>;
|
||||
using GridwiseBinSubtract = GridwiseBinaryElementwise_1D<CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
Subtract,
|
||||
MPerThread,
|
||||
AScalarPerVector,
|
||||
BScalarPerVector,
|
||||
CScalarPerVector>;
|
||||
const auto add_kernel = kernel_binary_elementwise_1d<GridwiseBinAdd,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
Add>;
|
||||
const auto subtract_kernel = kernel_binary_elementwise_1d<GridwiseBinSubtract,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CDataType,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
CGridDesc_M,
|
||||
Subtract>;
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
true>;
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_real = aux - aux_2
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
subtract_kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_aux_grid_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.p_c_grid_real_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
Subtract{});
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_imag = aux + aux_2
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
add_kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_aux_grid_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.p_c_grid_imag_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
Add{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
false>;
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_real = aux - aux_2
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
subtract_kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_aux_grid_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.p_c_grid_real_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
Subtract{});
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_real_,
|
||||
arg.p_b_grid_imag_,
|
||||
arg.p_aux_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_imag_,
|
||||
arg.p_b_grid_real_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_);
|
||||
|
||||
// c_imag = aux + aux_2
|
||||
ave_time += launch_and_time_kernel(stream_config,
|
||||
add_kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_aux_grid_,
|
||||
arg.p_aux_2_grid_,
|
||||
arg.p_c_grid_imag_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
arg.c_grid_desc_m_,
|
||||
Add{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
// Compile-time hook for validating this operator's template parameters.
// Currently a stub: it accepts every parameter combination (see the TODO below).
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
// Typed factory: forwards every parameter, unchanged and in the same order, to the
// Argument constructor and returns the resulting Argument by value.
// p_workspace is the scratch memory that Argument splits into its two auxiliary buffers.
static auto MakeArgument(const ADataType* p_a_real,
|
||||
const ADataType* p_a_imag,
|
||||
const BDataType* p_b_real,
|
||||
const BDataType* p_b_imag,
|
||||
CDataType* p_c_real,
|
||||
CDataType* p_c_imag,
|
||||
CDataType* p_workspace,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{p_a_real,
|
||||
p_a_imag,
|
||||
p_b_real,
|
||||
p_b_imag,
|
||||
p_c_real,
|
||||
p_c_imag,
|
||||
p_workspace,
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
// Typed factory: returns a value-initialized Invoker by value.
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a_real,
|
||||
const void* p_a_imag,
|
||||
const void* p_b_real,
|
||||
const void* p_b_imag,
|
||||
void* p_c_real,
|
||||
void* p_c_imag,
|
||||
void* p_workspace,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t /* KBatch */ = 1) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a_real),
|
||||
static_cast<const ADataType*>(p_a_imag),
|
||||
static_cast<const BDataType*>(p_b_real),
|
||||
static_cast<const BDataType*>(p_b_imag),
|
||||
static_cast<CDataType*>(p_c_real),
|
||||
static_cast<CDataType*>(p_c_imag),
|
||||
static_cast<CDataType*>(p_workspace),
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceCGemm_4Gemm_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSize(index_t MRaw,
|
||||
index_t NRaw,
|
||||
[[maybe_unused]] index_t KRaw,
|
||||
[[maybe_unused]] index_t StrideA,
|
||||
[[maybe_unused]] index_t StrideB,
|
||||
index_t StrideC) override
|
||||
{
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC);
|
||||
|
||||
return 2 * sizeof(CDataType) * c_grid_desc_m_n.GetElementSpaceSize();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r4r2.hpp"
|
||||
#include "gridwise_gemm_xdlops_bwd_weight.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -81,6 +81,22 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
static constexpr auto GemmK1Number = K1Number;
|
||||
|
||||
static constexpr auto N1Number = K1Number;
|
||||
|
||||
// Bytes per 32 lds bank: 32 * 4 bytes
|
||||
static constexpr auto BankLength = 128;
|
||||
static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
|
||||
|
||||
// M1 & M0
|
||||
static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1;
|
||||
static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock;
|
||||
static constexpr auto ABlockLdsM1Padding = 4;
|
||||
|
||||
// N1 & N0
|
||||
static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1;
|
||||
static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
|
||||
static constexpr auto BBlockLdsN1Padding = 4;
|
||||
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
@@ -125,27 +141,51 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
|
||||
const auto out_gemmktotal_gemmm_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const index_t N0 = N / N1Number;
|
||||
const index_t GemmK0Total = N0 * Ho * Wo;
|
||||
|
||||
const index_t GemmK0S =
|
||||
math::integer_divide_ceil(GemmK0Total, K0PerBlock * GemmKBatch) * K0PerBlock;
|
||||
const index_t GemmK0Pad = GemmKBatch * GemmK0S;
|
||||
const auto out_n_ho_wo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Ho * Wo, K));
|
||||
|
||||
const auto out_n0_ho_wo_k_n1_grid_desc =
|
||||
transform_tensor_descriptor(out_n_ho_wo_k_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)),
|
||||
make_pass_through_transform(Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto out_gemmk0total_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(out_n0_ho_wo_k_n1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N0, Ho * Wo)),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(N1Number)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto out_gemmk0pad_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmk0total_gemmm_gemmk1_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
|
||||
make_pass_through_transform(GemmM),
|
||||
make_pass_through_transform(N1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
out_gemmk0pad_gemmm_gemmk1_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
|
||||
make_pass_through_transform(GemmM),
|
||||
make_pass_through_transform(N1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
@@ -167,26 +207,50 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmktotal_gemmn_grid_desc =
|
||||
const auto in_n0_y_ho_x_wo_c_n1_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)),
|
||||
make_pass_through_transform(Y),
|
||||
make_pass_through_transform(Ho),
|
||||
make_pass_through_transform(X),
|
||||
make_pass_through_transform(Wo),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0, 6>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto in_gemmk0total_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_n0_y_ho_x_wo_c_n1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N0, Ho, Wo)),
|
||||
make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_pass_through_transform(N1Number)),
|
||||
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmk0pad_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0total_gemmn_gemmk1_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total),
|
||||
make_pass_through_transform(GemmN),
|
||||
make_pass_through_transform(N1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
in_gemmk0pad_gemmn_gemmk1_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)),
|
||||
make_pass_through_transform(GemmN),
|
||||
make_pass_through_transform(N1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
@@ -205,7 +269,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
|
||||
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
@@ -233,6 +297,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
ABlockLdsM1PerBlock,
|
||||
ABlockLdsM0PerBlock,
|
||||
ABlockLdsM1Padding,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -241,12 +308,17 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
BBlockLdsN1PerBlock,
|
||||
BBlockLdsN0PerBlock,
|
||||
BBlockLdsN1Padding,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
true>;
|
||||
|
||||
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
|
||||
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
|
||||
BlockSize,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
@@ -274,6 +346,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
ABlockLdsM1PerBlock,
|
||||
ABlockLdsM0PerBlock,
|
||||
ABlockLdsM1Padding,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -282,10 +357,15 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
BBlockLdsN1PerBlock,
|
||||
BBlockLdsN0PerBlock,
|
||||
BBlockLdsN1Padding,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl,
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
true,
|
||||
true>;
|
||||
// Argument
|
||||
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
|
||||
@@ -353,17 +433,16 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
|
||||
b_grid_desc_kbatch_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
M01_,
|
||||
N01_))
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
|
||||
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -415,20 +494,21 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
ShowInfo(arg);
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_))
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
|
||||
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight has invalid setting");
|
||||
}
|
||||
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch);
|
||||
const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
|
||||
@@ -437,56 +517,35 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
float ave_time = 0;
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(nrepeat > 0)
|
||||
{
|
||||
ave_time =
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
hipGetErrorString(hipMemset(
|
||||
arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||
sizeof(CDataType)));
|
||||
|
||||
if(kbatch > 1 || nrepeat <= 0)
|
||||
{
|
||||
hipGetErrorString(hipMemset(
|
||||
arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||
sizeof(CDataType)));
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4r2<
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
@@ -503,7 +562,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4r2<
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemmAtomicAdd,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
@@ -523,7 +582,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
{
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4r2<
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
@@ -540,7 +599,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v2r4r2<
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemmAtomicAdd,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
@@ -560,9 +619,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -582,6 +642,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
return false;
|
||||
}
|
||||
|
||||
// unmerge N to N0 and N1, where N1 equals to K1
|
||||
if(!(arg.Conv_N_ % K1 == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
|
||||
{
|
||||
@@ -592,8 +658,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_);
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user