mirror of https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00

Commit: fix merge from upstream

Jenkinsfile (vendored): 213 lines changed
@@ -1,5 +1,5 @@
 def rocmnode(name) {
-    return '(rocmtest || miopen) && ' + name
+    return '(rocmtest || miopen) && (' + name + ')'
 }

 def show_node_info() {
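The parenthesization fix above matters because, as in C-family languages, `&&` binds tighter than `||` in Jenkins label expressions, so a `name` such as `"gfx908 || gfx90a"` would widen the match instead of narrowing it. A minimal C++ analogue of the precedence pitfall (illustrative values only, not from the commit):

```cpp
#include <iostream>

int main()
{
    bool rocmtest = false, miopen = false; // node lacks the pool labels
    bool gfx908 = false, gfx90a = true;    // node advertises gfx90a

    // Old form: '(rocmtest || miopen) && ' + name, with name = "gfx908 || gfx90a"
    bool old_expr = ((rocmtest || miopen) && gfx908) || gfx90a; // true (!)

    // New form: '(rocmtest || miopen) && (' + name + ')'
    bool new_expr = (rocmtest || miopen) && (gfx908 || gfx90a); // false

    std::cout << old_expr << " " << new_expr << "\n"; // prints "1 0"
    return 0;
}
```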
@@ -7,6 +7,7 @@ def show_node_info() {
         echo "NODE_NAME = \$NODE_NAME"
         lsb_release -sd
         uname -r
+        cat /sys/module/amdgpu/version
         ls /opt/ -la
         """
 }
@@ -33,7 +34,11 @@ def runShell(String command){

 def getDockerImageName(){
     def img
-    if (params.ROCMVERSION != "6.0.1"){
+    if (params.USE_CUSTOM_DOCKER != ""){
+        img = "${params.USE_CUSTOM_DOCKER}"
+    }
+    else{
+        if (params.ROCMVERSION != "6.1"){
         if (params.COMPILER_VERSION == "") {
             img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}"
         }
@@ -61,6 +66,7 @@ def getDockerImageName(){
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return img
|
||||
}
|
||||
|
||||
@@ -98,7 +104,7 @@ def getDockerImage(Map conf=[:]){
     env.DOCKER_BUILDKIT=1
     def prefixpath = conf.get("prefixpath", "/opt/rocm")
     def no_cache = conf.get("no_cache", false)
-    def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
+    def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
     if(no_cache)
     {
         dockerArgs = dockerArgs + " --no-cache "
@@ -111,7 +117,9 @@ def getDockerImage(Map conf=[:]){
     {
         echo "Pulling down image: ${image}"
         retimage = docker.image("${image}")
-        retimage.pull()
+        withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) {
+            retimage.pull()
+        }
     }
     catch(Exception ex)
     {
@@ -126,7 +134,7 @@ def buildDocker(install_prefix){
     checkout scm
     def image_name = getDockerImageName()
     echo "Building Docker for ${image_name}"
-    def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
+    def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "

     echo "Build Args: ${dockerArgs}"
     try{
@@ -134,7 +142,9 @@ def buildDocker(install_prefix){
         //force building the new docker if that parameter is true
             echo "Building image: ${image_name}"
             retimage = docker.build("${image_name}", dockerArgs + ' .')
-            retimage.push()
+            withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) {
+                retimage.push()
+            }
             sh 'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi'
         }
         else{
@@ -146,7 +156,9 @@ def buildDocker(install_prefix){
     catch(Exception ex){
         echo "Unable to locate image: ${image_name}. Building image now"
         retimage = docker.build("${image_name}", dockerArgs + ' .')
-        retimage.push()
+        withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) {
+            retimage.push()
+        }
     }
 }

@@ -254,18 +266,24 @@ def cmake_build(Map conf=[:]){
         """)
         sh cmd3
     }

-    def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
     // reduce parallelism when compiling, clang uses too much memory
     def nt = nthreads()
-    def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}")
+    def cmd
     def execute_cmd = conf.get("execute_cmd", "")

-    def cmd = conf.get("cmd", """
+    if(!setup_args.contains("NO_CK_BUILD")){
+        def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
+        def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}")
+        cmd = conf.get("cmd", """
         ${setup_cmd}
         ${build_cmd}
         ${execute_cmd}
         """)
+    }
+    else{
+        cmd = conf.get("cmd", """
+        ${execute_cmd}
+        """)
+    }

     echo cmd

@@ -293,7 +311,7 @@ def buildHipClangJob(Map conf=[:]){
     if (conf.get("enforce_xnack_on", false)) {
         dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
     }
-    def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
+    def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
     if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
         dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
     }
@@ -303,7 +321,7 @@ def buildHipClangJob(Map conf=[:]){
     def retimage
     (retimage, image) = getDockerImage(conf)

-    gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel-internal') {
+    gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
         withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
             timeout(time: 48, unit: 'HOURS')
             {
@@ -349,20 +367,17 @@ def runCKProfiler(Map conf=[:]){
         dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
     }
     def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
     if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
         dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
     }

     def variant = env.STAGE_NAME
     def retimage

-    gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel-internal') {
+    gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
         try {
             (retimage, image) = getDockerImage(conf)
             withDockerContainer(image: image, args: dockerOpts) {
                 timeout(time: 5, unit: 'MINUTES'){
                     sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log'
                     if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){
+                        sh 'rocminfo | tee rocminfo.log'
+                        if ( !runShell('grep -n "gfx" rocminfo.log') ){
                             throw new Exception ("GPU not found")
                         }
                         else{
@@ -375,20 +390,6 @@ def runCKProfiler(Map conf=[:]){
             echo "The job was cancelled or aborted"
             throw e
         }
-        catch(Exception ex) {
-            retimage = docker.build("${image}", dockerArgs + " --no-cache .")
-            withDockerContainer(image: image, args: dockerOpts) {
-                timeout(time: 5, unit: 'MINUTES'){
-                    sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log'
-                    if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){
-                        throw new Exception ("GPU not found")
-                    }
-                    else{
-                        echo "GPU is OK"
-                    }
-                }
-            }
-        }

         withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
             timeout(time: 24, unit: 'HOURS')
@@ -404,7 +405,7 @@ def runCKProfiler(Map conf=[:]){

                 dir("script"){
                     if (params.RUN_FULL_QA){
-                        sh "./run_full_performance_tests.sh 1 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
+                        sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
                         archiveArtifacts "perf_gemm.log"
                         archiveArtifacts "perf_resnet50_N256.log"
                         archiveArtifacts "perf_resnet50_N4.log"
@@ -414,9 +415,9 @@ def runCKProfiler(Map conf=[:]){
                         archiveArtifacts "perf_conv_bwd_data.log"
                         archiveArtifacts "perf_gemm_bilinear.log"
                         archiveArtifacts "perf_reduction.log"
-                        archiveArtifacts "perf_splitK_gemm_verify.log"
+                        archiveArtifacts "perf_splitK_gemm.log"
                         archiveArtifacts "perf_onnx_gemm.log"
+                        archiveArtifacts "perf_mixed_gemm.log"
                         // stash perf files to master
                         stash name: "perf_gemm.log"
                         stash name: "perf_resnet50_N256.log"
@@ -429,6 +430,7 @@ def runCKProfiler(Map conf=[:]){
                         stash name: "perf_reduction.log"
                         stash name: "perf_splitK_gemm.log"
                         stash name: "perf_onnx_gemm.log"
+                        stash name: "perf_mixed_gemm.log"
                         //we will process results on the master node
                     }
                     else{
@@ -469,6 +471,7 @@ def Build_CK(Map conf=[:]){
     show_node_info()

     env.HSA_ENABLE_SDMA=0
+    env.DOCKER_BUILDKIT=1
     checkout scm

     def image = getDockerImageName()
@@ -483,26 +486,25 @@ def Build_CK(Map conf=[:]){
     if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
         dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
     }
+    def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
+    def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3')
+    dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} "
+    echo "Docker flags: ${dockerOpts}"

     def variant = env.STAGE_NAME
     def retimage
     def navi_node = 0

-    gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel-internal') {
+    gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
         try {
             (retimage, image) = getDockerImage(conf)
             withDockerContainer(image: image, args: dockerOpts) {
                 timeout(time: 5, unit: 'MINUTES'){
                     sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log'
                     if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){
                         sh 'rocminfo | tee rocminfo.log'
                         if ( !runShell('grep -n "gfx" rocminfo.log') ){
                             throw new Exception ("GPU not found")
                         }
                         else{
                             echo "GPU is OK"
                         }
                         if ( runShell('grep -n "gfx1030" clinfo.log') || runShell('grep -n "gfx1101" clinfo.log') ){
                             navi_node = 1
                         }
                     }
                 }
             }
         }
@@ -510,43 +512,38 @@ def Build_CK(Map conf=[:]){
             echo "The job was cancelled or aborted"
             throw e
         }
-        catch(Exception ex) {
-            retimage = docker.build("${image}", dockerArgs + " --no-cache .")
-            withDockerContainer(image: image, args: dockerOpts) {
-                timeout(time: 5, unit: 'MINUTES'){
-                    sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo |tee clinfo.log'
-                    if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){
-                        throw new Exception ("GPU not found")
-                    }
-                    else{
-                        echo "GPU is OK"
-                    }
-                    if ( runShell('grep -n "gfx1030" clinfo.log') || runShell('grep -n "gfx1101" clinfo.log') ){
-                        navi_node = 1
-                    }
-                }
-            }
-        }
         withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
             timeout(time: 24, unit: 'HOURS')
             {
+                //check whether running on Navi or MI300 node
+                def navi_node = 0
+                def mi300_node = 0
+                sh 'rocminfo | tee rocminfo.log'
+                if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') ){
+                    navi_node = 1
+                    echo "This is a Navi node"
+                }
+                if ( runShell('grep -n "gfx942" rocminfo.log') ){
+                    mi300_node = 1
+                    echo "This is MI300 node"
+                }
                 cmake_build(conf)
                 dir("build"){
                     //run tests and examples
                     sh 'make -j check'
-                    if (navi_node == 0 ){
+                    if (params.RUN_PERFORMANCE_TESTS && navi_node == 0 && mi300_node == 0 ){
                         //we only need the ckProfiler to run the performance tests, so we pack and stash it
-                        //do not stash profiler on Navi nodes
+                        //do not stash profiler on Navi or MI300 nodes
                         sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler'
-                        stash "ckProfiler.tar.gz"
+                        stash name: "ckProfiler.tar.gz"
                     }
-                    if (params.RUN_FULL_QA){
-                        // build deb packages
+                    if (params.RUN_FULL_QA && mi300_node == 0 ){
+                        // build deb packages for all MI100/200/300 targets and prepare to export
                         sh 'make -j package'
                         archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb'
                         archiveArtifacts artifacts: 'composablekernel-tests_*.deb'
                         sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb'
-                        stash "ckprofiler_0.2.0_amd64.deb"
+                        stash name: "ckprofiler_0.2.0_amd64.deb"
                     }
                 }
                 if (params.hipTensor_test && navi_node == 0 ){
@@ -606,7 +603,7 @@ def process_results(Map conf=[:]){
     def variant = env.STAGE_NAME
     def retimage

-    gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel-internal') {
+    gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
         try {
             (retimage, image) = getDockerImage(conf)
         }
@@ -622,6 +619,8 @@ def process_results(Map conf=[:]){
             dir("script"){
                 if (params.RUN_FULL_QA){
                     // unstash perf files to master
+                    unstash "ckprofiler_0.2.0_amd64.deb"
+                    sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/"
                     unstash "perf_gemm.log"
                     unstash "perf_resnet50_N256.log"
                     unstash "perf_resnet50_N4.log"
@@ -633,9 +632,8 @@ def process_results(Map conf=[:]){
                     unstash "perf_reduction.log"
                     unstash "perf_splitK_gemm.log"
                     unstash "perf_onnx_gemm.log"
+                    unstash "perf_mixed_gemm.log"
                     sh "./process_qa_data.sh"
-                    unstash "ckprofiler_0.2.0_amd64.deb"
-                    sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/"
                 }
                 else{
                     // unstash perf files to master
@@ -647,16 +645,28 @@ def process_results(Map conf=[:]){
             }
         }
         catch(e){
-            echo "throwing error exception while processing performance test results"
+            echo "Throwing error exception while processing performance test results"
             echo 'Exception occurred: ' + e.toString()
             throw e
         }
         finally{
             echo "Finished processing performance test results"
         }
     }
 }

+//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
+CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.0;COMPILER_VERSION=
+0 21 * * * % ROCMVERSION=6.0;COMPILER_VERSION=;COMPILER_COMMIT=
+0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;COMPILER_COMMIT=;USE_SCCACHE=false
+0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : ""

 pipeline {
     agent none
     triggers {
         parameterizedCron(CRON_SETTINGS)
     }
     options {
         parallelsAlwaysFailFast()
     }
@@ -665,6 +675,10 @@ pipeline {
             name: "BUILD_DOCKER",
             defaultValue: false,
             description: "Force building docker image (default: false), set to true if docker image needs to be updated.")
+        string(
+            name: 'USE_CUSTOM_DOCKER',
+            defaultValue: '',
+            description: 'If you want to use a custom docker image, please specify it here (default: leave blank).')
         string(
             name: 'ROCMVERSION',
             defaultValue: '6.0',
@@ -707,8 +721,12 @@ pipeline {
             description: "Run the cppcheck static analysis (default: OFF)")
         booleanParam(
             name: "RUN_PERFORMANCE_TESTS",
-            defaultValue: false,
-            description: "Run the performance tests (default: OFF)")
+            defaultValue: true,
+            description: "Run the performance tests (default: ON)")
+        booleanParam(
+            name: "RUN_CODEGEN_TESTS",
+            defaultValue: true,
+            description: "Run the codegen tests (default: ON)")
     }
     environment{
         dbuser = "${dbuser}"
@@ -787,7 +805,34 @@ pipeline {
                 }
             }
         }

+        stage("Run Codegen Tests")
+        {
+            parallel
+            {
+                stage("Run Codegen Tests on MI100/MI200")
+                {
+                    when {
+                        beforeAgent true
+                        expression { params.RUN_CODEGEN_TESTS.toBoolean() }
+                    }
+                    options { retry(2) }
+                    agent{ label rocmnode("gfx908 || gfx90a")}
+                    environment{
+                        setup_args = "NO_CK_BUILD"
+                        execute_args = """ cd ../codegen && rm -rf build && mkdir build && cd build && \
+                                       cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
+                                       -D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
+                                       -D CMAKE_BUILD_TYPE=Release \
+                                       -D GPU_TARGETS="gfx908;gfx90a" \
+                                       -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j check"""
+                    }
+                    steps{
+                        buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
+                        cleanWs()
+                    }
+                }
+            }
+        }
         stage("Build CK and run Tests")
         {
             parallel
@@ -815,6 +860,26 @@ pipeline {
                         cleanWs()
                     }
                 }
+                stage("Build CK and run Tests on MI300")
+                {
+                    when {
+                        beforeAgent true
+                        expression { params.RUN_FULL_QA.toBoolean() }
+                    }
+                    agent{ label rocmnode("gfx942") }
+                    environment{
+                        setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -O3 " """
+                        execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
+                                       cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
+                                       -DGPU_TARGETS="gfx942" \
+                                       -DCMAKE_CXX_COMPILER="${build_compiler()}" \
+                                       -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
+                    }
+                    steps{
+                        Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
+                        cleanWs()
+                    }
+                }
                 stage("Build CK and run Tests on MI100/MI200")
                 {
                     when {
@@ -52,7 +52,7 @@ importlib-metadata==6.8.0
     # via
     #   sphinx
     #   sphinxcontrib-bibtex
-importlib-resources==6.1.1
+importlib-resources==6.1.0
     # via rocm-docs-core
 jinja2==3.1.2
     # via
@@ -96,9 +96,7 @@ pygments==2.15.0
     #   pydata-sphinx-theme
     #   sphinx
 pyjwt[crypto]==2.6.0
-    # via
-    #   pygithub
-    #   pyjwt
+    # via pygithub
 pynacl==1.5.0
     # via pygithub
 pytz==2023.3.post1
@@ -113,7 +111,7 @@ requests==2.31.0
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core==0.33.2
+rocm-docs-core==0.37.1
     # via -r requirements.in
 six==1.16.0
     # via
@@ -27,7 +27,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)

 add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
-if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
+if(GPU_TARGETS MATCHES "gfx11")
     add_custom_target(example_gemm_wmma)
     add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
     add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
@@ -53,13 +53,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)

 add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)

-add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
-add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
-
-add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
-add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
-
-list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
@@ -72,5 +66,12 @@ foreach(gpu IN LISTS GPU_TARGETS)
     endif()
 endforeach()

+add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
+
+add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
+
+add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
+add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
@@ -1,5 +1,5 @@
 list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102)
-list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list1 AND target EQUAL 0)

@@ -1,4 +1,4 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,4 +1,4 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
@@ -33,4 +33,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
         add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32)
         set(target 1)
     endif()
-endforeach()
+endforeach()
@@ -1,4 +1,4 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
@@ -6,6 +6,8 @@ foreach(gpu IN LISTS GPU_TARGETS)
         add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
         add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
         add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
+        add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
+        add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
         # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
         add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
         set(target 1)

@@ -1,4 +1,4 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,7 +1,7 @@
 # dlops
 add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
 # xdlops
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
@@ -1,4 +1,4 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,4 +1,4 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,4 +1,4 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,4 +1,4 @@
-list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
 list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
@@ -1,4 +1,4 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,4 +1,4 @@
-list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
 list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102)

 set(target 0)

@@ -1,4 +1,4 @@
-list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)

 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
@@ -13,6 +13,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
     endif()
 endforeach()

-if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
+if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
     add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
 endif()
@@ -1,4 +1,4 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)
@@ -10,6 +10,9 @@ foreach(gpu IN LISTS GPU_TARGETS)
         add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
         add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)

+        add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp)
+        add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8)
+
         add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp)
         add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16)

@@ -1,4 +1,4 @@
-list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
 list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)

@@ -1,4 +1,4 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,4 +1,4 @@
-list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
 list(APPEND gpu_list2 gfx908 gfx90a)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
@@ -13,6 +13,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
     endif()
 endforeach()

-if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
+if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
     add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
 endif()
@@ -5,6 +5,6 @@ add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permu
 add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permute_4D_fp32_col.cpp)
 add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp)
 add_example_executable(example_elementwise_permute elementwise_permute.cpp)
-if((NOT GPU_TARGETS MATCHES "gfx940") AND (NOT GPU_TARGETS MATCHES "gfx941") AND (NOT GPU_TARGETS MATCHES "gfx942") AND (NOT GPU_TARGETS MATCHES "gfx950"))
+if((NOT GPU_TARGETS MATCHES "gfx940") AND (NOT GPU_TARGETS MATCHES "gfx941") AND (NOT GPU_TARGETS MATCHES "gfx942"))
     add_example_executable(example_elementwise_permute_3d elementwise_permute_3d.cpp)
 endif()

@@ -1,4 +1,4 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,4 +1,4 @@
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,4 +1,4 @@
-list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list2 AND target EQUAL 0)

@@ -1,4 +1,4 @@
-list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
+list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
@@ -45,7 +45,7 @@
 #endif

 // define general macros for various architectures
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
 #define __gfx94__
 #endif
 #if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
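The `__gfx94__` umbrella macro lets device code target the whole gfx94x family (now gfx940/941/942, with gfx950 dropped by this commit) with one guard instead of three per-ISA checks. A hedged sketch of how such a guard is typically consumed (the kernel below is hypothetical, not from this commit):

```cpp
// Hypothetical HIP device function: a single __gfx94__ test selects the
// MI300-class path; every other target takes the generic fallback.
__device__ float dispatch_exp(float x)
{
#if defined(__gfx94__)
    return __expf(x); // gfx94x (MI300-class) fast path
#else
    return expf(x);   // generic fallback
#endif
}
```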
@@ -55,15 +55,14 @@ inline bool is_xdl_supported()
 {
     return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
            ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
-           ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
+           ck::get_device_name() == "gfx942";
 }

 inline bool is_lds_direct_load_supported()
 {
     // Check if direct loads from global memory to LDS are supported.
     return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
-           ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942" ||
-           ck::get_device_name() == "gfx950";
+           ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
 }

 inline bool is_navi1_supported()
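These helpers are the usual host-side gate for XDL/MFMA-based kernels, so removing "gfx950" makes such kernels report as unsupported there. A minimal sketch of how a test might consume them (the example-runner function is hypothetical; the header above is assumed to be included):

```cpp
#include <iostream>

bool run_xdl_gemm_example(); // hypothetical: launches an XDL-based GEMM

int main()
{
    if(!ck::is_xdl_supported())
    {
        std::cout << "Skipping: this GPU has no XDL/MFMA support\n";
        return 0; // skip rather than fail on unsupported hardware
    }
    return run_xdl_gemm_example() ? 0 : 1;
}
```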
(File diff suppressed because it is too large.)
@@ -37,7 +37,9 @@ template <index_t BlockSize,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack>
+          index_t KPack,
+          typename ComputeTypeA = FloatA,
+          typename ComputeTypeB = FloatB>
 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
 {
     static constexpr auto I0 = Number<0>{};
@@ -59,7 +61,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
     static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

-    static constexpr auto xdlops_gemm = XdlopsGemm<FloatA, MPerXDL, NPerXDL, KPack, FloatB>{};
+    static constexpr auto xdlops_gemm =
+        XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB>{};

     static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
@@ -295,9 +298,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                  const BBlockBuffer& b_block_buf,
                  CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
             a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
             b_thread_desc_.GetElementSpaceSize());

         static_for<0, MRepeat, 1>{}([&](auto m0) {
@@ -319,20 +322,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                     b_thread_buf);

                 static_for<0, KPerThread, KPack>{}([&](auto k) {
-                    vector_type<FloatA, KPack> a_thread_vec;
-                    vector_type<FloatB, KPack> b_thread_vec;
+                    vector_type<ComputeTypeA, KPack> a_thread_vec;
+                    vector_type<ComputeTypeB, KPack> b_thread_vec;

                     static_for<0, KPack, 1>{}([&](auto i) {
-                        a_thread_vec.template AsType<FloatA>()(i) = a_thread_buf
+                        a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf
                             [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
-                        b_thread_vec.template AsType<FloatB>()(i) = b_thread_buf
+                        b_thread_vec.template AsType<ComputeTypeB>()(i) = b_thread_buf
                             [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
                     });

                     using mfma_input_type_a =
-                        typename vector_type<FloatA, xdlops_gemm.K1PerXdlops>::type;
+                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
                     using mfma_input_type_b =
-                        typename vector_type<FloatB, xdlops_gemm.K1PerXdlops>::type;
+                        typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;

                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -360,7 +363,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));

     using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
-                                                         FloatA,
+                                                         ComputeTypeA,
                                                          decltype(a_block_desc_m0_m1_m2_k),
                                                          decltype(a_thread_desc_),
                                                          Sequence<1, 1, 1, KPerThread>,
@@ -370,7 +373,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                                                          A_K1>;

     using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
-                                                         FloatB,
+                                                         ComputeTypeB,
                                                          decltype(b_block_desc_n0_n1_n2_k),
                                                          decltype(b_thread_desc_),
                                                          Sequence<1, 1, 1, KPerThread>,
@@ -398,6 +401,8 @@ template <index_t BlockSize,
           index_t MRepeat,
           index_t NRepeat,
           index_t KPack,
+          typename ComputeTypeA = FloatA,
+          typename ComputeTypeB = FloatB,
           index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
 struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
@@ -410,7 +415,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                                                                  NPerXDL,
                                                                  MRepeat,
                                                                  NRepeat,
-                                                                 KPack>
+                                                                 KPack,
+                                                                 ComputeTypeA,
+                                                                 ComputeTypeB>
 {
     using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                      FloatA,
@@ -422,7 +429,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                                                                      NPerXDL,
                                                                      MRepeat,
                                                                      NRepeat,
-                                                                     KPack>;
+                                                                     KPack,
+                                                                     ComputeTypeA,
+                                                                     ComputeTypeB>;

 #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
     using Base::a_block_desc_m0_m1_m2_k;
@@ -446,9 +455,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                  const BBlockBuffer& b_block_buf,
                  CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
             a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
             b_thread_desc_.GetElementSpaceSize());

         static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
@@ -485,22 +494,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
             static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
                 static_for<0, MRepeat, 1>{}([&](auto m0) {
                     static_for<0, NRepeat, 1>{}([&](auto n0) {
-                        vector_type<FloatA, KPack> a_thread_vec;
-                        vector_type<FloatB, KPack> b_thread_vec;
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;

                         static_for<0, KPack, 1>{}([&](auto i) {
-                            a_thread_vec.template AsType<FloatA>()(i) =
+                            a_thread_vec.template AsType<ComputeTypeA>()(i) =
                                 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                     make_tuple(m0, 0, 0, k_ + i))>{}];
-                            b_thread_vec.template AsType<FloatB>()(i) =
+                            b_thread_vec.template AsType<ComputeTypeB>()(i) =
                                 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                     make_tuple(n0, 0, 0, k_ + i))>{}];
                         });

                         using mfma_input_type_a =
-                            typename vector_type<FloatA, xdlops_gemm.K1PerXdlops>::type;
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
                         using mfma_input_type_b =
-                            typename vector_type<FloatB, xdlops_gemm.K1PerXdlops>::type;
+                            typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;

                         constexpr index_t c_offset =
                             c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -550,7 +559,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));

     using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
-                                                         FloatA,
+                                                         ComputeTypeA,
                                                          decltype(a_block_desc_m0_m1_m2_k),
                                                          decltype(a_thread_desc_),
                                                          Sequence<1, 1, 1, KPerInnerLoop>,
@@ -560,7 +569,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                                                          A_K1>;

     using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
-                                                         FloatB,
+                                                         ComputeTypeB,
                                                          decltype(b_block_desc_n0_n1_n2_k),
                                                          decltype(b_thread_desc_),
                                                          Sequence<1, 1, 1, KPerInnerLoop>,
@@ -586,7 +595,9 @@ template <index_t BlockSize,
           index_t MRepeat,
           index_t NRepeat,
           index_t KPack,
-          LoopScheduler LoopSched>
+          LoopScheduler LoopSched,
+          typename ComputeTypeA = FloatA,
+          typename ComputeTypeB = FloatB>
 constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
 {
     if constexpr(LoopSched == LoopScheduler::Default)
@@ -601,7 +612,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
                                                                NPerXDL,
                                                                MRepeat,
                                                                NRepeat,
-                                                               KPack>{};
+                                                               KPack,
+                                                               ComputeTypeA,
+                                                               ComputeTypeB>{};
     }
     else if constexpr(LoopSched == LoopScheduler::Interwave)
     {
@@ -615,7 +628,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
                                                                         NPerXDL,
                                                                         MRepeat,
                                                                         NRepeat,
-                                                                        KPack>{};
+                                                                        KPack,
+                                                                        ComputeTypeA,
+                                                                        ComputeTypeB>{};
     }
 };

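The net effect of this file's changes: every blockwise XDL GEMM variant now carries `ComputeTypeA`/`ComputeTypeB` parameters that default to `FloatA`/`FloatB`, so existing instantiations compile unchanged while mixed-precision pipelines can stage data in one type and feed the MFMA in another. A hedged sketch follows; only the parameter-list tail shown in the hunks is confirmed, and the leading parameters are an assumption from the surrounding CK source:

```cpp
// Hedged sketch: data is staged as half_t, but the xdlops_gemm instance
// consumes f8_t on both sides. Descriptor types are placeholders assumed
// to be defined elsewhere; extents are illustrative only.
using BlockGemm = ck::BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
    256,             // BlockSize
    ck::half_t,      // FloatA: element type held in LDS
    ck::half_t,      // FloatB
    float,           // FloatAcc
    AK0MK1BlockDesc, // A block descriptor (hypothetical placeholder)
    BK0NK1BlockDesc, // B block descriptor (hypothetical placeholder)
    32,              // MPerXDL
    32,              // NPerXDL
    2,               // MRepeat
    2,               // NRepeat
    8,               // KPack
    ck::f8_t,        // ComputeTypeA (new parameter)
    ck::f8_t>;       // ComputeTypeB (new parameter)
```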
@@ -0,0 +1,223 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp"

namespace ck {

/**
 * @brief Blockwise data transfer with dequantization
 *
 * RunRead loads the low-precision data and the scale data.
 * RunWrite performs the dequantization.
 * Scale is assumed to be identical along the K dimension.
 *
 * This version does the following to avoid the scratch-memory issue:
 * 1. Use StaticallyIndexedArray instead of a C array for the thread buffer.
 * 2. ThreadwiseTensorSliceTransfer_v3 does not keep a reference to the tensor descriptor.
 * 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct a new tensor coordinate.
 */
template <typename ThreadGroup,
          typename SrcElementwiseOperation,
          typename ScaleElementwiseOperation,
          typename DstElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
          typename BlockSliceLengths,
          typename BlockScaleSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename ScaleData,
          typename DstData,
          typename SrcDesc,
          typename ScaleDesc,
          typename DstDesc,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t ScaleScalarPerVector,
          index_t DstScalarPerVector,
          index_t SrcScalarStrideInVector,
          index_t ScaleScalarStrideInVector,
          index_t DstScalarStrideInVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun,
          index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1_dequant
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
    static constexpr auto scale_thread_slice_lengths =
        BlockScaleSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_dequant(
        const SrcDesc& src_desc,
        const Index& src_block_slice_origin,
        const SrcElementwiseOperation& src_element_op,
        const ScaleDesc& scale_desc,
        const Index& scale_block_slice_origin,
        const ScaleElementwiseOperation& scale_element_op,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin,
        const DstElementwiseOperation& dst_element_op)
        : threadwise_transfer_(src_desc,
                               make_zero_multi_index<nDim>(),
                               src_element_op,
                               scale_desc,
                               make_zero_multi_index<nDim>(),
                               scale_element_op,
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               dst_element_op)
    {
        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<ScaleDesc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{} &&
                is_same<BlockScaleSliceLengths,
                        decltype(scale_thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetScaleSliceOrigin(
                scale_desc, scale_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename SrcBuffer, index_t ThreadScratchId = 0>
    __device__ void RunRead(const SrcDesc& src_desc,
                            const SrcBuffer& src_buf,
                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
        }
    }

    // Under the scale assumption above, the scale scratch count is always one.
    template <typename ScaleBuffer>
    __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunScaleRead(scale_desc, scale_buf);
        }
    }

    template <typename DstBuffer, index_t ThreadScratchId = 0>
    __device__ void RunWrite(const DstDesc& dst_desc,
                             DstBuffer& dst_buf,
                             Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
        }
    }

    // Using this API directly is not preferred.
    /*
    template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf,
                        Number<ThreadScratchId> thread_scratch_id)
    {
        RunRead(src_desc, src_buf, thread_scratch_id);
        RunWrite(dst_desc, dst_buf, thread_scratch_id);
    }
    */

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    // Under the scale assumption above, the scale buffer needs no move-slice-window method.

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v3r1_dequant<decltype(thread_slice_lengths),
                                                   decltype(scale_thread_slice_lengths),
                                                   SrcElementwiseOperation,
                                                   ScaleElementwiseOperation,
                                                   DstElementwiseOperation,
                                                   DstInMemOp,
                                                   SrcData,
                                                   ScaleData,
                                                   DstData,
                                                   SrcDesc,
                                                   ScaleDesc,
                                                   DstDesc,
                                                   SrcDimAccessOrder,
                                                   DstDimAccessOrder,
                                                   SrcVectorDim,
                                                   DstVectorDim,
                                                   SrcScalarPerVector,
                                                   ScaleScalarPerVector,
                                                   DstScalarPerVector,
                                                   SrcScalarStrideInVector,
                                                   ScaleScalarStrideInVector,
                                                   DstScalarStrideInVector,
                                                   ThreadTransferSrcResetCoordinateAfterRun,
                                                   ThreadTransferDstResetCoordinateAfterRun,
                                                   NumThreadScratch>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
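A hedged sketch of the intended call order for this new dequantizing copy, grounded in the methods it exposes (the wrapper name is hypothetical): scales are read once per tile, and the dequantization happens on the write side.

```cpp
// Hypothetical wrapper showing the RunRead / RunScaleRead / RunWrite order.
template <typename Transfer,
          typename SrcDesc, typename SrcBuf,
          typename ScaleDesc, typename ScaleBuf,
          typename DstDesc, typename DstBuf>
__device__ void stage_dequantized_tile(Transfer& blockwise_copy,
                                       const SrcDesc& src_desc, const SrcBuf& src_buf,
                                       const ScaleDesc& scale_desc, const ScaleBuf& scale_buf,
                                       const DstDesc& dst_desc, DstBuf& dst_buf)
{
    blockwise_copy.RunRead(src_desc, src_buf);          // load low-precision data
    blockwise_copy.RunScaleRead(scale_desc, scale_buf); // load the scales
    blockwise_copy.RunWrite(dst_desc, dst_buf);         // dequantize while storing
}
```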
@@ -0,0 +1,193 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
/**
|
||||
* @brief Blockwise data transfer
|
||||
*
|
||||
* This version does following things to avoid scratch memory issue
|
||||
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
*
|
||||
*/
|
||||
template <typename ThreadGroup,
|
||||
typename ElementwiseOperation,
|
||||
typename DstInMemOps, // Sequence
|
||||
typename BlockSliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcDatas,
|
||||
typename DstDatas,
|
||||
typename SrcDescs,
|
||||
typename DstDescs,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
typename SrcsScalarPerVector, // Sequence
|
||||
typename DstsScalarPerVector, // Sequence
|
||||
typename SrcsScalarStrideInVector, // Sequence
|
||||
typename DstsScalarStrideInVector, // Sequence
|
||||
typename ThreadTransferSrcsResetCoordinateAfterRun, // Sequence
|
||||
typename ThreadTransferDstsResetCoordinateAfterRun, // Sequence
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadGroupTensorSliceTransfer_v4r2
|
||||
{
|
||||
static constexpr index_t nDim =
|
||||
remove_reference_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
|
||||
static constexpr index_t nSrc = SrcDescs::Size();
|
||||
static constexpr index_t nDst = DstDescs::Size();
|
||||
|
||||
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r2(
|
||||
const SrcDescs& src_descs,
|
||||
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
|
||||
const DstDescs& dst_descs,
|
||||
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
|
||||
const ElementwiseOperation& element_op)
|
||||
: threadwise_transfer_(src_descs,
|
||||
StaticallyIndexedArray<Index, nSrc>{},
|
||||
dst_descs,
|
||||
StaticallyIndexedArray<Index, nDst>{},
|
||||
element_op)
|
||||
|
||||
{
|
||||
static_assert(nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");

static_for<0, nSrc, 1>{}([&](auto src_i) {
static_assert(nDim ==
remove_cvref_t<tuple_element_t<src_i, SrcDescs>>::GetNumOfDimension(),
"wrong! nDim not consistent");
});

static_for<0, nDst, 1>{}([&](auto dst_i) {
static_assert(nDim ==
remove_cvref_t<tuple_element_t<dst_i, DstDescs>>::GetNumOfDimension(),
"wrong! nDim not consistent");
});

static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");

static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");

if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));

const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

const auto src_thread_slice_origins = generate_tuple(
[&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
Number<nSrc>{});

const auto dst_thread_slice_origins = generate_tuple(
[&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
Number<nDst>{});

threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
}
}

template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
}
}

template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers& dst_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
}
}

template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffer& src_bufs,
const DstDescs& dst_descs,
DstBuffer& dst_bufs,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_descs, src_bufs, thread_scratch_id);
RunWrite(dst_descs, dst_bufs, thread_scratch_id);
}

__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_descs, step);
}
}

__device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_descs, step);
}
}

private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r2<decltype(thread_slice_lengths),
ElementwiseOperation,
DstInMemOps,
SrcDatas,
DstDatas,
SrcDescs,
DstDescs,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcsScalarPerVector,
DstsScalarPerVector,
SrcsScalarStrideInVector,
DstsScalarStrideInVector,
ThreadTransferSrcsResetCoordinateAfterRun,
ThreadTransferDstsResetCoordinateAfterRun,
NumThreadScratch>;

ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
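The slice-origin computation above can be made concrete with a small worked example (the lengths below are illustrative placeholders, not values from this diff):

// Assuming ThreadClusterLengths = Sequence<4, 16>,
// thread_slice_lengths = Sequence<2, 8>, row-major arrange order:
// thread id 17 -> thread_cluster_idx = (1, 1)
//             -> thread_data_idx_begin = (1 * 2, 1 * 8) = (2, 8),
// i.e. thread 17 owns the [2..3, 8..15] window of the block slice.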
@@ -0,0 +1,46 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Dequantization of the input tensor cannot be decoupled from the gridwise GEMM pipeline,
// because the input tensor's thread buffer is declared inside the blockwise GEMM pipeline.

template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemm_dequantB : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_scale,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;

virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
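A caller drives this interface through the usual BaseOperator pattern. A minimal host-side sketch, assuming some concrete op derived from DeviceGemm_dequantB (every pointer and operator object below is a placeholder):

// Sketch only: `gemm` stands for any concrete DeviceGemm_dequantB implementation.
auto arg_ptr = gemm.MakeArgumentPointer(p_a_dev, p_b_dev, p_scale_dev, p_c_dev,
                                        M, N, K, StrideA, StrideB, StrideC,
                                        a_element_op, b_element_op, c_element_op);
auto invoker_ptr = gemm.MakeInvokerPointer();
if(gemm.IsSupportedArgument(arg_ptr.get()))
{
    invoker_ptr->Run(arg_ptr.get(), StreamConfig{});
}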
@@ -62,10 +62,10 @@ template <index_t NumDimG,
index_t NumDimK,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
@@ -73,13 +73,14 @@ template <index_t NumDimG,
TensorSpecialization ASpec,
TensorSpecialization BSpec,
TensorSpecialization DESpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t MPerWMMA,
ck::index_t NPerWMMA,
ck::index_t MPerWmma,
ck::index_t NPerWmma,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
@@ -100,7 +101,6 @@ template <index_t NumDimG,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
@@ -123,15 +123,32 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
// K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{};

static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;

static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;

// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
static constexpr auto BEnableLds_manu = false;

static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);

static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock * K1};
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

// Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
static auto MakeAGridDescriptor(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
const std::vector<index_t>& a_gs_ms_ks_strides_vec)
{
assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK &&
a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK);
@@ -158,36 +175,72 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// lengths for K0, K1, ...
const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);

if constexpr(ASpec == TensorSpecialization::Packed)
const auto a_grid_desc_m_k = [&]() {
if constexpr(ASpec == TensorSpecialization::Packed)
{
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
make_tuple(M, K),
make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
else
{
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const auto a_grid_desc_ms_ks =
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);

// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
a_grid_desc_ms_ks,
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
make_tuple(mDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));

return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
}();

const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
assert(K % K1 == 0);

if constexpr(AEnableLds)
{
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
make_tuple(M, K),
make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const index_t K0 = K / K1;

return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const auto a_grid_desc_ms_ks =
make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
constexpr auto A_KRow = 2;
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
const auto A_KWmma = K / WmmaK;

// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
a_grid_desc_ms_ks,
make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
make_tuple(mDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));

return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto M0 = M / MPerBlock;
// 0   1     0         1                2        3             4        5          6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
}
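// Illustrative numbers for the LDS path above (not taken from the diff):
// K = 64 with K1 = 8 gives K0 = 8, so the (M, K) view becomes a
// (K0, M, K1) = (8, M, 8) descriptor; element (m, k) maps to
// (k / 8, m, k % 8), which is exactly what
// make_unmerge_transform(make_tuple(K0, K1Number)) encodes.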

// Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
static auto MakeBGridDescriptor(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
const std::vector<index_t>& b_gs_ns_ks_strides_vec)
{
assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK);
@@ -214,30 +267,66 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// lengths for N0, N1, ...
const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);

if constexpr(BSpec == TensorSpecialization::Packed)
const auto b_grid_desc_n_k = [&]() {
if constexpr(BSpec == TensorSpecialization::Packed)
{
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
make_tuple(N, K),
make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
else
{
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const auto b_grid_desc_ns_ks =
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);

// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
b_grid_desc_ns_ks,
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
make_tuple(nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));

return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
}();

const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
assert(K % K1 == 0);

if constexpr(BEnableLds)
{
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
make_tuple(N, K),
make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const index_t K0 = K / K1;

return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
// naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
const auto b_grid_desc_ns_ks =
make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
constexpr auto B_KRow = 2;
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
const auto B_KWmma = K / WmmaK;

// transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
b_grid_desc_ns_ks,
make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
make_tuple(nDimIds, kDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));

return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto N0 = N / NPerBlock;
// 0   1     0         1                2        3             4        5          6
// N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
}

@@ -393,8 +482,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
}

// Gridwise descriptor, mapping to the whole given problem.
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {}));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {}));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));

@@ -449,45 +536,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
EGridDesc_G_M_N e_grid_desc_g_m_n_;
};

// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);

const auto AK0 = K / K1;

return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}

// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);

const auto BK0 = K / K1;

return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}

using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{}));
using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{}));
using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({}, {}));
using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({}, {}));

// GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
using GridwiseOp = GridwiseGemmMultipleD_Wmma<
// DataType Family
ADataType,
BDataType,
@@ -496,8 +549,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
DsDataType,
EDataType,
// InMemory Data Descriptor
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
AGridDesc,
BGridDesc,
DsGridDesc_M_N,
EGridDesc_M_N,
// ElementwiseOp Family
@@ -508,9 +561,9 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
// Tiling Family
MPerBlock,
NPerBlock,
K0PerBlock,
MPerWMMA,
NPerWMMA,
KPerBlock,
MPerWmma,
NPerWmma,
K1,
MRepeat,
NRepeat,
@@ -523,6 +576,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
AEnableLds,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
@@ -531,6 +585,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BEnableLds,
BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
@@ -564,16 +619,14 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_m_k_{},
b_grid_desc_n_k_{},
a_grid_desc_{},
b_grid_desc_{},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{},
ds_grid_desc_g_m_n_{
DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)},
e_grid_desc_g_m_n_{
DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
a_grid_desc_k0_m_k1_{},
b_grid_desc_k0_n_k1_{},
ds_grid_desc_mblock_mperblock_nblock_nperblock{},
e_grid_desc_mblock_mperblock_nblock_nperblock{},
block_2_ctile_map_{},
@@ -600,10 +653,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
});

a_grid_desc_m_k_ =
DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
b_grid_desc_n_k_ =
DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
a_grid_desc_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
b_grid_desc_ = DeviceOp::MakeBGridDescriptor(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);

ds_grid_desc_m_n_ =
DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides);
@@ -611,9 +662,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
e_grid_desc_m_n_ =
DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);

a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_);
b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_);

block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);

ds_grid_desc_mblock_mperblock_nblock_nperblock =
@@ -644,16 +692,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
EDataType* p_e_grid_;

// Tensor Descriptors
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
AGridDesc a_grid_desc_;
BGridDesc b_grid_desc_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
EGridDesc_G_M_N e_grid_desc_g_m_n_;

AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;

typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock;
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -686,6 +731,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle

// Batch Offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;

// for checking vector load/store
// index_t MRaw_;
// index_t NRaw_;
// index_t KRaw_;
};

// Invoker
@@ -700,8 +750,17 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;

const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
const auto K = [&]() {
if constexpr(AEnableLds)
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
}
else
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
}
}();
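// With LDS enabled, a_grid_desc_ is (K0, M, K1), so K = K0 * K1 is the
// product of the lengths at I0 and I2; in the LDS-bypass layout
// (A_KWmma, MBlock*MRepeat, MWaves, A_K0PerWmma, A_KRow, MPerWmma, A_K1)
// the K extent is the product of the lengths at I0, I3, I4 and I6.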

auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
@@ -712,8 +771,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
BDataType,
typename GridwiseOp::DsGridPointer,
EDataType,
DeviceOp::AGridDesc_K0_M_K1,
DeviceOp::BGridDesc_K0_N_K1,
DeviceOp::AGridDesc,
DeviceOp::BGridDesc,
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
AElementwiseOperation,
@@ -733,8 +792,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
arg.p_ds_grid_,
arg.p_e_grid_,
G,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
@@ -774,6 +833,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
printf("DeviceOp: Arch check failure\n");
return false;
}
}
@@ -782,12 +842,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
return false;
}

if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
if(!GridwiseOp::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_ctile_map_))
{
printf("GridwiseOp: Validity check failure\n");
return false;
}

@@ -800,16 +861,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if constexpr(ABlockTransferSrcVectorDim == 1)
{
if(!(arg.a_mz_stride_ == 1 &&
arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
arg.a_grid_desc_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
{
printf("DeviceOp: Vector Access A-m check failure\n");
return false;
}
}
else
{
if(!(arg.a_kz_stride_ == 1 &&
arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
{
printf("DeviceOp: Vector Access A-k check failure\n");
return false;
}
}
@@ -818,16 +881,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
if constexpr(BBlockTransferSrcVectorDim == 1)
{
if(!(arg.b_nz_stride_ == 1 &&
arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
arg.b_grid_desc_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
{
printf("DeviceOp: Vector Access B-n check failure\n");
return false;
}
}
else
{
if(!(arg.b_kz_stride_ == 1 &&
arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
arg.b_grid_desc_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
{
printf("DeviceOp: Vector Access B-k check failure\n");
return false;
}
}
@@ -841,6 +906,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
CDEShuffleBlockTransferScalarPerVector_NPerBlock ==
0))
{
printf("DeviceOp: Vector Access D-n check failure\n");
valid_d_access = false;
}
});
@@ -857,6 +923,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
0) ||
CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
{
printf("DeviceOp: Vector Access E-n check failure\n");
return false;
}

@@ -967,14 +1034,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< KPerBlock << ", "
<< K1 << ", "
<< MPerWMMA << ", "
<< NPerWMMA << ", "
<< MPerWmma << ", "
<< NPerWmma << ", "
<< MRepeat << ", "
<< NRepeat
<< ">"
<< " NumPrefetch: "
<< " AEnableLds: "
<< AEnableLds << ", "
<< "BEnableLds: "
<< BEnableLds << ", "
<< "NumPrefetch: "
<< NumPrefetch << ", "
<< "LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
File diff suppressed because it is too large
@@ -602,7 +602,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return false;
}

if(!ck::is_lds_direct_load_supported() && std::is_same<ADataType, double>::value)
if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" &&
ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" &&
std::is_same<ADataType, double>::value)
{
return false;
}
@@ -294,7 +294,7 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
if((ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
ck::get_device_name() == "gfx942"))
{
return false;
}
@@ -0,0 +1,422 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"

#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim,
index_t BlockSize,
index_t M0PerBlock,
index_t M1PerBlock,
index_t M0PerThread,
index_t M1PerThread,
typename ThreadClusterArrangeOrder,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwiseImpl
: public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
{
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();

static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};

static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");

static auto GenerateInDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;

return static_cast<const DataType*>(nullptr);
},
Number<NumInput>{});
};

static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;

return static_cast<DataType*>(nullptr);
},
Number<NumOutput>{});
};

using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());

static index_t GetLowestStrideDim(const std::array<index_t, NumDim>& strides)
{
index_t most_continous_dim = NumDim - 1;
index_t most_continous_dim_stride = strides[most_continous_dim];
for(index_t dim = 0; dim < NumDim; dim++)
{
if(strides[dim] < most_continous_dim_stride)
{
most_continous_dim_stride = strides[dim];
most_continous_dim = dim;
}
}
return most_continous_dim;
}
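// Example (illustrative strides): strides {20, 1, 4} give
// GetLowestStrideDim -> 1; that dimension is the contiguous one, so it is
// the only candidate for vectorized access.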

template <typename InOutDescriptor>
static auto PadInputOutputDescriptor(const InOutDescriptor& desc)
{
const auto M0 = desc.GetLength(I0);
const auto M1 = desc.GetLength(I1);
const auto pad_M0 = math::integer_divide_ceil(M0, M0PerThread) * M0PerThread - M0;
const auto pad_M1 = math::integer_divide_ceil(M1, M1PerThread) * M1PerThread - M1;

const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_right_pad_transform(M0, pad_M0), make_right_pad_transform(M1, pad_M1)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

return padded_desc;
}
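// Padding example (illustrative values): M0 = 30 with M0PerThread = 8 gives
// pad_M0 = ceil(30 / 8) * 8 - 30 = 2, so the padded M0 extent is 32 and each
// thread covers a full M0PerThread-wide strip.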

static auto GenerateBatchDimsLenghtsTuple(const std::array<index_t, NumDim>& lengths,
const index_t M0_dim,
const index_t M1_dim)
{
// Generate the batch dims; they will be merged into M0.
// Allocate one more dim than needed because when M0_dim equals M1_dim
// there is one extra batch dim.
std::array<index_t, NumDim - 1> batch_dims;
index_t batch_dim = 0;
for(index_t i = 0; i < NumDim; i++)
{
if(i != M0_dim && i != M1_dim)
{
batch_dims[batch_dim] = lengths[i];
batch_dim++;
}
}
// Add a dummy dim if M0_dim is not equal to M1_dim
if(M0_dim != M1_dim && NumDim >= 2)
batch_dims[NumDim - 2] = 1;
return generate_tuple([&](auto I) { return batch_dims[I]; }, Number<NumDim - 1>{});
}
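// Example (illustrative values): NumDim = 4, lengths {8, 5, 6, 7},
// M0_dim = 2, M1_dim = 3 -> batch_dims = {8, 5, 1} (trailing dummy 1);
// MakeDescriptor later merges these batch dims together with M0.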

static auto MakeDescriptor(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& in_strides,
const std::array<index_t, NumDim>& out_strides,
const std::array<index_t, NumDim>& desc_strides)
{
const auto M0_dim = GetLowestStrideDim(out_strides);
const auto M1_dim = GetLowestStrideDim(in_strides);

// If M0_dim equals M1_dim, make M0_dim a dummy dim
const auto M0 = M0_dim == M1_dim ? I1 : lengths[M0_dim];
const auto M1 = lengths[M1_dim];
const auto M0_stride = M0_dim == M1_dim ? I1 : desc_strides[M0_dim];
const auto M1_stride = desc_strides[M1_dim];

const auto batch_dims_lenghts = GenerateBatchDimsLenghtsTuple(lengths, M0_dim, M1_dim);
const auto batch_dims_strides = GenerateBatchDimsLenghtsTuple(desc_strides, M0_dim, M1_dim);

const auto desc = make_naive_tensor_descriptor(
concat_tuple(batch_dims_lenghts, make_tuple(M0), make_tuple(M1)),
concat_tuple(batch_dims_strides, make_tuple(M0_stride), make_tuple(M1_stride)));
// Merge the batch dims with M0
const auto transforms =
make_tuple(make_merge_transform(concat_tuple(batch_dims_lenghts, make_tuple(M0))),
make_pass_through_transform(M1));
using BatchElemsSequence =
typename arithmetic_sequence_gen<0, decltype(batch_dims_lenghts)::Size() + 1, 1>::type;
const auto lower_dims = make_tuple(BatchElemsSequence{}, Sequence<NumDim>{});
const auto upper_dims = make_tuple(Sequence<0>{}, Sequence<1>{});
// desc: (merged_dims + M0, M1)
auto merged_desc = transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
return PadInputOutputDescriptor(merged_desc);
}
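// End-to-end example (illustrative values): NumDim = 3, lengths {5, 6, 7},
// in_strides {42, 7, 1} (so M1_dim = 2) and out_strides {42, 1, 6} (so
// M0_dim = 1). With desc_strides = in_strides this builds a naive
// (5, 1, 6, 7) view whose first three dims merge into M0 = 5 * 1 * 6 = 30,
// leaving a 2D (30, 7) descriptor that is then padded per thread tile.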

template <index_t NumTensors>
static auto GenerateInOutGridDescTuple()
{
std::array<index_t, NumDim> ones;
for(index_t d = 0; d < NumDim; d++)
{
ones[d] = 1;
}

return generate_tuple([&](auto) { return MakeDescriptor(ones, ones, ones, ones); },
Number<NumTensors>{});
};

using InGridDescTuple = decltype(GenerateInOutGridDescTuple<NumInput>());
using OutGridDescTuple = decltype(GenerateInOutGridDescTuple<NumOutput>());

using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<M0PerBlock, M1PerBlock>;

using GridwiseElementwiseOp = GridwiseElementwise<InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation,
BlockSize,
M0PerBlock,
M1PerBlock,
M0PerThread,
M1PerThread,
ThreadClusterArrangeOrder,
InScalarPerVectorSeq,
OutScalarPerVectorSeq,
false>;

using GridwiseElementwiseOpSameInOutVectorDim = GridwiseElementwise<InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation,
BlockSize,
M0PerBlock,
M1PerBlock,
M0PerThread,
M1PerThread,
ThreadClusterArrangeOrder,
InScalarPerVectorSeq,
OutScalarPerVectorSeq,
true>;

struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)

: lengths_(lengths),
inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op)
{
in_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(in_dev_buffers[I.value]);
},
Number<NumInput>{});

out_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(out_dev_buffers[I.value]);
},
Number<NumOutput>{});
}

InDataTypePointerTuple in_dev_buffers_;
OutDataTypePointerTuple out_dev_buffers_;

std::array<index_t, NumDim> lengths_;
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;

ElementwiseOperation elementwise_op_;
};

struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
auto in_grid_desc_tuple = generate_tuple(
[&](auto src_i) {
// Use the strides of the first tensor so that the M0 and
// M1 dims are the same for every tensor.
return MakeDescriptor(arg.lengths_,
arg.inStridesArray_[I0],
arg.outStridesArray_[I0],
arg.inStridesArray_[src_i]);
},
Number<NumInput>{});

auto out_grid_desc_tuple = generate_tuple(
[&](auto dst_i) {
return MakeDescriptor(arg.lengths_,
arg.inStridesArray_[I0],
arg.outStridesArray_[I0],
arg.outStridesArray_[dst_i]);
},
Number<NumOutput>{});

const index_t M0 = in_grid_desc_tuple.At(I0).GetLength(Number<I0>{});
const index_t M1 = in_grid_desc_tuple.At(I0).GetLength(Number<I1>{});

const auto block_2_tile_map = Block2TileMap(M0, M1);
const index_t grid_size = block_2_tile_map.CalculateGridSize(M0, M1);

const bool in_out_same_vector_dim = GetLowestStrideDim(arg.inStridesArray_[I0]) ==
GetLowestStrideDim(arg.outStridesArray_[I0]);

const auto kernel = in_out_same_vector_dim
? kernel_elementwise<GridwiseElementwiseOpSameInOutVectorDim,
InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation>
: kernel_elementwise<GridwiseElementwiseOp,
InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation>;

float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
in_grid_desc_tuple,
out_grid_desc_tuple,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
block_2_tile_map,
arg.elementwise_op_);
return elapsed_time;
}

// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};

static bool IsSupportedArgument(const Argument& arg)
{
const index_t M0_dim = GetLowestStrideDim(arg.inStridesArray_[I0]);
const index_t M1_dim = GetLowestStrideDim(arg.outStridesArray_[I0]);

auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector,
index_t M_dim) {
if(scalarPerVector == 1)
{
return true;
}
if(strides[M_dim] == 1 && lengths[M_dim] % scalarPerVector == 0)
{
return true;
}
return false;
};
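// For instance (illustrative values): lengths[M_dim] = 64 with stride 1 and
// scalarPerVector = 8 is valid (64 % 8 == 0); a non-unit stride or a length
// not divisible by 8 is only accepted with scalarPerVector = 1.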

bool is_valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
static_assert(M0PerThread % InScalarPerVectorSeq::At(I) == 0 &&
M1PerThread % InScalarPerVectorSeq::At(I) == 0);
is_valid &= IsScalarPerVectorValid(
arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I), M0_dim);
});

static_for<0, NumOutput, 1>{}([&](auto I) {
static_assert(M0PerThread % OutScalarPerVectorSeq::At(I) == 0 &&
M1PerThread % OutScalarPerVectorSeq::At(I) == 0);
is_valid &= IsScalarPerVectorValid(
arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I), M1_dim);
});

return is_valid;
};

bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}

static auto
MakeArgument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
{
return Argument{lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op};
}

std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op) override
{
return std::make_unique<Argument>(lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op);
}

static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};

std::string GetTypeString() const override
{
auto str = std::stringstream();

// clang-format off
str << "DeviceElementwiseImpl<";
str << NumDim << ", ";
str << BlockSize << ", ";
str << M0PerBlock << ", ";
str << M1PerBlock << ", ";
str << M0PerThread << ", ";
str << M1PerThread << ">";
// clang-format on

return str.str();
}
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
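For orientation, a host-side caller of this implementation looks roughly like the sketch below; every template parameter and buffer pointer is an illustrative placeholder, not a value taken from this diff:

// Sketch only: hypothetical float pass-through over a 1024 x 1024 tensor.
using InTuple  = ck::Tuple<float>;
using OutTuple = ck::Tuple<float>;
using Op       = ck::tensor_operation::element_wise::PassThrough;
using Device   = ck::tensor_operation::device::DeviceElementwiseImpl<
    InTuple, OutTuple, Op, /*NumDim*/ 2, /*BlockSize*/ 256,
    /*M0PerBlock*/ 32, /*M1PerBlock*/ 32, /*M0PerThread*/ 4, /*M1PerThread*/ 4,
    ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>;

Device op;
std::array<ck::index_t, 2> lengths{1024, 1024};
std::array<std::array<ck::index_t, 2>, 1> in_strides{{{1024, 1}}};
std::array<std::array<ck::index_t, 2>, 1> out_strides{{{1024, 1}}};
auto arg = op.MakeArgumentPointer(lengths, in_strides, out_strides,
                                  {p_in_dev}, {p_out_dev}, Op{});
if(op.IsSupportedArgument(arg.get()))
    op.MakeInvokerPointer()->Run(arg.get(), StreamConfig{});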
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -322,6 +322,19 @@ struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
{
return std::make_unique<Invoker>();
};

std::string GetTypeString() const override
{
auto str = std::stringstream();

// clang-format off
str << "DeviceElementwiseNormalizationImpl<";
str << NumDim << ", ";
str << MPerThread << ">";
// clang-format on

return str.str();
}
};

} // namespace device
@@ -0,0 +1,714 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// 1. DequantB(K, N) = int2fp(B(K, N)) * scale(1, N)
// 2. C(M, N) = A(M, K) * DequantB(K, N)
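The two steps above correspond to the following scalar host reference (a minimal sketch assuming row-major A and B, int8 weights, and one float scale per column of B; none of this is code from the diff):

// Minimal reference for: DequantB = int2fp(B) * scale, C = A * DequantB.
void reference_fpAintB_gemm(const float* a, const int8_t* b, const float* scale,
                            float* c, int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a[m * K + k] * (static_cast<float>(b[k * N + n]) * scale[n]);
            c[m * N + n] = acc;
        }
}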
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ScaleDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumPrefetch,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::weight_only>
|
||||
struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
// K1 = Max Vector Access Pixels
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
|
||||
static constexpr auto AEnableLds_auto =
|
||||
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
|
||||
static constexpr auto BEnableLds_auto =
|
||||
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
// LDS bypass feature not implemented for dequantization pipeline.
|
||||
static constexpr auto AEnableLds_manu = true;
|
||||
static constexpr auto BEnableLds_manu = true;
|
||||
|
||||
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
|
||||
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
using DeviceOp = DeviceFpAintBGemm_Wmma_CShuffle;
|
||||
|
||||
// Describe how data read from Global memory
|
||||
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_m_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto A_KRow = 2;
|
||||
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
|
||||
const auto A_KWmma = K / WmmaK;
|
||||
|
||||
const auto M0 = M / MPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b_grid_desc_n_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto B_KRow = 2;
|
||||
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
|
||||
const auto B_KWmma = K / WmmaK;
|
||||
|
||||
const auto N0 = N / NPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeScaleGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB = 0)
|
||||
{
|
||||
// assume Scale is [1, N]
|
||||
const auto scale_grid_desc_n_k = [&]() {
|
||||
const auto scale_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(scale_grid_desc_nraw_kraw);
|
||||
}();
|
||||
|
||||
const auto N = scale_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = scale_grid_desc_n_k.GetLength(I1);
|
||||
// When K = 1, it might be scale tensor.
|
||||
assert(K % K1 == 0 && K != 1);
|
||||
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
scale_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, 1)), // Reduce K1 = 1
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto B_KRow = 2;
|
||||
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
|
||||
const auto B_KWmma = K / WmmaK;
|
||||
|
||||
const auto N0 = N / NPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
scale_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
|
||||
{
|
||||
const auto c_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(StrideC, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(I1, StrideC));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
// Gridwise descriptor, mapping to whole given provblem.
|
||||
using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
|
||||
using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
|
||||
using ScaleGridDesc = decltype(MakeScaleGridDescriptor(1, 1, 0));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    // GridwiseGemm
    using GridwiseGemm = GridwiseFpAintBGemm_Wmma<
        BlockSize,
        ADataType,
        BDataType,
        ScaleDataType,
        AccDataType,
        CShuffleDataType,
        CDataType,
        InMemoryDataOperationEnum::Set,
        AGridDesc,
        BGridDesc,
        ScaleGridDesc,
        CGridDesc_M_N,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        MPerWmma,
        NPerWmma,
        K1,
        MRepeat,
        NRepeat,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        AEnableLds,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BEnableLds,
        BBlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        NumPrefetch,
        LoopSched,
        PipelineVer>;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 const ScaleDataType* p_scale_grid,
                 CDataType* p_c_grid,
                 index_t M,
                 index_t N,
                 index_t K,
                 index_t StrideA,
                 index_t StrideB,
                 index_t StrideC,
                 index_t M01,
                 index_t N01,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_scale_grid_{p_scale_grid},
              p_c_grid_{p_c_grid},
              a_grid_desc_{},
              b_grid_desc_{},
              scale_grid_desc_{},
              c_grid_desc_m_n_{},
              c_grid_desc_mblock_mperblock_nblock_nperblock{},
              block_2_ctile_map_{},
              M01_{M01},
              N01_{N01},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op},
              MRaw_{M},
              NRaw_{N},
              KRaw_{K}
        {
            a_grid_desc_     = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
            b_grid_desc_     = DeviceOp::MakeBGridDescriptor(K, N, StrideB);
            scale_grid_desc_ = DeviceOp::MakeScaleGridDescriptor(K, N, 0);
            c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC);

            block_2_ctile_map_ =
                GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);

            if(GridwiseGemm::CheckValidity(
                   a_grid_desc_, b_grid_desc_, c_grid_desc_m_n_, block_2_ctile_map_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock =
                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c_grid_desc_m_n_);
            }
        }

        // private:
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        const ScaleDataType* p_scale_grid_;
        CDataType* p_c_grid_;
        AGridDesc a_grid_desc_;
        BGridDesc b_grid_desc_;
        ScaleGridDesc scale_grid_desc_;
        CGridDesc_M_N c_grid_desc_m_n_;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
        index_t M01_;
        index_t N01_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
        // for checking vector load/store
        index_t MRaw_;
        index_t NRaw_;
        index_t KRaw_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_,
                                            arg.b_grid_desc_,
                                            arg.c_grid_desc_m_n_,
                                            arg.block_2_ctile_map_))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting");
            }

            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);

            const auto K = [&]() {
                if constexpr(AEnableLds)
                {
                    return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
                }
                else
                {
                    return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
                           arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
                }
            }();
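            // With LDS enabled the A descriptor is K0 x M x K1, so K = K0 * K1 (lengths I0
            // and I2). Without LDS it is the 7-d WMMA layout, so K is recovered from the
            // split K dimensions (KWmma, K0PerWmma, KRow, K1) at indices I0, I3, I4 and I6.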
            auto launch_kernel = [&](auto has_main_k_block_loop) {
                const auto kernel = kernel_fpAintB_gemm_wmma<
                    GridwiseGemm,
                    ADataType,
                    BDataType,
                    ScaleDataType,
                    CDataType,
                    remove_reference_t<DeviceOp::AGridDesc>,
                    remove_reference_t<DeviceOp::BGridDesc>,
                    remove_reference_t<DeviceOp::ScaleGridDesc>,
                    remove_reference_t<
                        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
                    remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                    has_main_k_block_loop>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_scale_grid_,
                                              arg.p_c_grid_,
                                              arg.a_grid_desc_,
                                              arg.b_grid_desc_,
                                              arg.scale_grid_desc_,
                                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.c_element_op_,
                                              arg.block_2_ctile_map_);
            };
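            // has_main_k_block_loop is passed as an integral_constant, so the main-K-loop
            // variant of the kernel is selected at compile time; both instantiations reuse
            // the launch code above.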

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(ck::is_navi3_supported())
        {
            if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
                           is_same_v<AccDataType, int32_t>))
            {
                printf("DeviceOp err: AccDataType");
                return false;
            }
        }
        else
        {
            printf("DeviceOp err: Arch");
            return false;
        }

        // check vector load/store
        {
            using Row = ck::tensor_layout::gemm::RowMajor;
            using Col = ck::tensor_layout::gemm::ColumnMajor;

            // check vector load of A
            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
            {
                // FIXME: not rigorous
                if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }

            // check vector load of B
            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
            {
                // FIXME: not rigorous
                if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }

            // check vector store of C
            // only support RowMajor for now
            if constexpr(is_same_v<CLayout, Row>)
            {
                if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }
        }
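        // The checks above all enforce the same rule: the raw extent along each tensor's
        // vectorized dimension must divide evenly by its ScalarPerVector, otherwise the
        // tail elements could not be moved with full-width vector transfers.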

        return GridwiseGemm::CheckValidity(
            arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_m_n_, arg.block_2_ctile_map_);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             const ScaleDataType* p_scale,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_scale,
                        p_c,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideC,
                        1,
                        1,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }
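    // Typical host-side usage (sketch, using only the factories defined above; the pointer
    // and size variables are illustrative):
    //
    //   auto op      = DeviceOp{};
    //   auto arg     = op.MakeArgument(p_a, p_b, p_scale, p_c, M, N, K,
    //                                  StrideA, StrideB, StrideC, a_op, b_op, c_op);
    //   auto invoker = op.MakeInvoker();
    //   if(op.IsSupportedArgument(arg))
    //   {
    //       const float ms = invoker.Run(arg, StreamConfig{});
    //   }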

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      const void* p_scale,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<const ScaleDataType*>(p_scale),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          1,
                                          1,
                                          a_element_op,
                                          b_element_op,
                                          c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{
            {PipelineVersion::v1, "v1"},
            {PipelineVersion::v2, "v2"},
            {PipelineVersion::weight_only, "weight_only"}};

        // clang-format off
        str << "DeviceFpAintBGemm_Wmma_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << K1 << ", "
            << MPerWmma << ", "
            << NPerWmma << ", "
            << MRepeat << ", "
            << NRepeat
            << ">"
            << " AEnableLds: "
            << AEnableLds << ", "
            << "BEnableLds: "
            << BEnableLds << ", "
            << "NumPrefetch: "
            << NumPrefetch << ", "
            << "LoopScheduler: "
            << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: "
            << PipelineVersionToString[PipelineVer];
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"

namespace ck {
namespace tensor_operation {
@@ -27,21 +28,22 @@ template <typename ALayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t NumPrefetch,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t KPerBlock,
          ck::index_t K1,
          ck::index_t MPerWMMA,
          ck::index_t NPerWMMA,
          ck::index_t MPerWmma,
          ck::index_t NPerWmma,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
@@ -62,7 +64,6 @@ template <typename ALayout,
          index_t CShuffleNRepeatPerShuffle,
          typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          ck::index_t NumPrefetch = 1,
          ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
          ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
@@ -83,68 +84,139 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    // K1 = Max Vector Access Pixels
    static constexpr auto K1Number = Number<K1>{};

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock* K1};
    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
    static constexpr auto WmmaK  = K1 == 16 ? 32 : 16;

    static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA)
    static constexpr auto AEnableLds_auto =
        (NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
    static constexpr auto BEnableLds_auto =
        (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;

    // If true, LDS is used unconditionally
    static constexpr auto AEnableLds_manu = false;
    static constexpr auto BEnableLds_manu = false;

    static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
    static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
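    // LDS staging is skipped only when the tile is not shared across waves in the opposite
    // dimension (NWaves == 1 for A, MWaves == 1 for B) and the global layout already matches
    // the register order (row-major A, column-major B); multi-buffer prefetch
    // (NumPrefetch > 1) always re-enables LDS.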

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

    // Describes how data is read from global memory
    static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
        const auto a_grid_desc_m_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(StrideA, I1));
                const auto a_grid_desc_mraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));

                return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(I1, StrideA));
                const auto a_grid_desc_mraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));

                return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
            }
        }();

        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
        const auto M               = a_grid_desc_m_k.GetLength(I0);
        const auto K               = a_grid_desc_m_k.GetLength(I1);
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);
        assert(K % K1 == 0);
        const index_t K0 = K / K1;

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        if constexpr(AEnableLds)
        {
            const index_t K0 = K / K1;

            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            constexpr auto A_KRow      = 2;
            constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
            const auto A_KWmma         = K / WmmaK;

            const auto M0 = M / MPerBlock;
            // 0 1 0 1 2 3 4 5 6
            // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
            return transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(
                               A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
                           make_unmerge_transform(
                               make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
        }
    }

    static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
    static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
        const auto b_grid_desc_n_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
                const auto b_grid_desc_nraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));

                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
                const auto b_grid_desc_nraw_kraw =
                    make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));

                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
            }
        }();

        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        const auto N               = b_grid_desc_n_k.GetLength(I0);
        const auto K               = b_grid_desc_n_k.GetLength(I1);
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);
        assert(K % K1 == 0);
        const index_t K0 = K / K1;

        return transform_tensor_descriptor(
            b_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        if constexpr(BEnableLds)
        {
            const index_t K0 = K / K1;

            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
        }
        else
        {
            constexpr auto B_KRow      = 2;
            constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
            const auto B_KWmma         = K / WmmaK;

            const auto N0 = N / NPerBlock;
            // 0 1 0 1 2 3 4 5 6
            // N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
            return transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(
                               B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
                           make_unmerge_transform(
                               make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
        }
    }

    template <typename ELayout_>
@@ -180,13 +252,13 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
    }

    // Gridwise descriptor, mapping to the whole given problem.
    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
    using DsGridDesc_M_N    = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
    using EGridDesc_M_N     = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
    using AGridDesc      = decltype(MakeAGridDescriptor(1, 1, 1));
    using BGridDesc      = decltype(MakeBGridDescriptor(1, 1, 1));
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
    using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));

    // GridwiseOp
    using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
    using GridwiseOp = GridwiseGemmMultipleD_Wmma<
        // DataType Family
        ADataType,
        BDataType,
@@ -195,8 +267,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
        DsDataType,
        EDataType,
        // InMemory Data Descriptor
        AGridDesc_K0_M_K1,
        BGridDesc_K0_N_K1,
        AGridDesc,
        BGridDesc,
        DsGridDesc_M_N,
        EGridDesc_M_N,
        // ElementwiseOp Family
@@ -207,9 +279,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
        // Tiling Family
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerWMMA,
        NPerWMMA,
        KPerBlock,
        MPerWmma,
        NPerWmma,
        K1,
        MRepeat,
        NRepeat,
@@ -222,6 +294,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        AEnableLds,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
@@ -230,6 +303,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BEnableLds,
        BBlockLdsAddExtraN,
        CShuffleMRepeatPerShuffle,
        CShuffleNRepeatPerShuffle,
@@ -262,8 +336,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
              p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e_grid)},
              a_grid_desc_k0_m_k1_{},
              b_grid_desc_k0_n_k1_{},
              a_grid_desc{},
              b_grid_desc{},
              ds_grid_desc_m_n_{},
              e_grid_desc_m_n_{},
              ds_grid_desc_mblock_mperblock_nblock_nperblock{},
@@ -278,8 +352,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
              NRaw_{N},
              KRaw_{K}
        {
            a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
            b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
            a_grid_desc = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
            b_grid_desc = DeviceOp::MakeBGridDescriptor(K, N, StrideB);
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DLayout   = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
@@ -295,8 +369,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,

            block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);

            if(GridwiseOp::CheckValidity(a_grid_desc_k0_m_k1_,
                                         b_grid_desc_k0_n_k1_,
            if(GridwiseOp::CheckValidity(a_grid_desc,
                                         b_grid_desc,
                                         ds_grid_desc_m_n_,
                                         e_grid_desc_m_n_,
                                         block_2_ctile_map_))
@@ -318,8 +392,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
        EDataType* p_e_grid_;

        // Tensor Descriptors
        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
        AGridDesc a_grid_desc;
        BGridDesc b_grid_desc;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;
        typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -352,24 +426,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
#if 0
            {
                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
                          << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
                          << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
                          << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
                          << arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
            }
#endif

            if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                          arg.b_grid_desc_k0_n_k1_,
            if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
                                          arg.b_grid_desc,
                                          arg.ds_grid_desc_m_n_,
                                          arg.e_grid_desc_m_n_,
                                          arg.block_2_ctile_map_))
@@ -381,91 +439,64 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);

            const auto K =
                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
            const auto K = [&]() {
                if constexpr(AEnableLds)
                {
                    return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
                }
                else
                {
                    return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
                           arg.a_grid_desc.GetLength(I4) * arg.a_grid_desc.GetLength(I6);
                }
            }();

            float ave_time    = 0;
            auto launch_kernel = [&](auto has_main_k_block_loop) {
                const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
                    GridwiseOp,
                    ADataType,
                    BDataType,
                    typename GridwiseOp::DsGridPointer,
                    EDataType,
                    remove_reference_t<typename DeviceOp::AGridDesc>,
                    remove_reference_t<typename DeviceOp::BGridDesc>,
                    remove_reference_t<
                        typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    remove_reference_t<
                        typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
                    has_main_k_block_loop>; // Last Option is W/O

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_ds_grid_,
                                              arg.p_e_grid_,
                                              arg.a_grid_desc,
                                              arg.b_grid_desc,
                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.cde_element_op_,
                                              arg.block_2_ctile_map_);
            };

            if(GridwiseOp::CalculateHasMainKBlockLoop(K))
            {
                const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
                    GridwiseOp,
                    ADataType,
                    BDataType,
                    typename GridwiseOp::DsGridPointer,
                    EDataType,
                    remove_reference_t<typename DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
                    remove_reference_t<
                        typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    remove_reference_t<
                        typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
                    true>; // Last Option is W/O

                ave_time =
                    launch_and_time_kernel(stream_config,
                                           kernel,
                                           dim3(grid_size),
                                           dim3(BlockSize),
                                           0,
                                           arg.p_a_grid_,
                                           arg.p_b_grid_,
                                           arg.p_ds_grid_,
                                           arg.p_e_grid_,
                                           arg.a_grid_desc_k0_m_k1_,
                                           arg.b_grid_desc_k0_n_k1_,
                                           arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                           arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                           arg.a_element_op_,
                                           arg.b_element_op_,
                                           arg.cde_element_op_,
                                           arg.block_2_ctile_map_);
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle<
                    GridwiseOp,
                    ADataType,
                    BDataType,
                    typename GridwiseOp::DsGridPointer,
                    EDataType,
                    remove_reference_t<typename DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
                    remove_reference_t<
                        typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    remove_reference_t<
                        typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
                    false>;

                ave_time =
                    launch_and_time_kernel(stream_config,
                                           kernel,
                                           dim3(grid_size),
                                           dim3(BlockSize),
                                           0,
                                           arg.p_a_grid_,
                                           arg.p_b_grid_,
                                           arg.p_ds_grid_,
                                           arg.p_e_grid_,
                                           arg.a_grid_desc_k0_m_k1_,
                                           arg.b_grid_desc_k0_n_k1_,
                                           arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                           arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                           arg.a_element_op_,
                                           arg.b_element_op_,
                                           arg.cde_element_op_,
                                           arg.block_2_ctile_map_);
                return launch_kernel(integral_constant<bool, false>{});
            }

            return ave_time;
        }

        // polymorphic
@@ -575,8 +606,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
            }
        }

        return GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                         arg.b_grid_desc_k0_n_k1_,
        return GridwiseOp::CheckValidity(arg.a_grid_desc,
                                         arg.b_grid_desc,
                                         arg.ds_grid_desc_m_n_,
                                         arg.e_grid_desc_m_n_,
                                         arg.block_2_ctile_map_);
@@ -681,14 +712,18 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << KPerBlock << ", "
            << K1 << ", "
            << MPerWMMA << ", "
            << NPerWMMA << ", "
            << MPerWmma << ", "
            << NPerWmma << ", "
            << MRepeat << ", "
            << NRepeat
            << ">"
            << " NumPrefetch: "
            << " AEnableLds: "
            << AEnableLds << ", "
            << "BEnableLds: "
            << BEnableLds << ", "
            << "NumPrefetch: "
            << NumPrefetch << ", "
            << "LoopScheduler: "
            << LoopSchedToString[LoopSched] << ", "

@@ -498,6 +498,86 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
    }
    };
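    // The vector load/store checks below are factored out of IsSupportedArgument so that
    // the constexpr Descriptor::IsValid() path added further down can share them.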

    static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
    {
        // check vector load/store
        using Row = ck::tensor_layout::gemm::RowMajor;
        using Col = ck::tensor_layout::gemm::ColumnMajor;
        // check vector load of A
        if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
        {
            if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
        {
            // FIXME: not rigorous
            if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }
        // check vector load of B
        if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
        {
            if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
        {
            // FIXME: not rigorous
            if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }

        // check vector load of Ds
        // only support RowMajor for now
        bool all_valid = true;

        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

            if constexpr(!is_same_v<DLayout, Row>)
            {
                all_valid = false;
            }
        });

        if(!all_valid)
        {
            return false;
        }

        // check vector store of E
        // only support RowMajor for now
        if constexpr(is_same_v<ELayout, Row>)
        {
            if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
            {
                return false;
            }
        }
        else
        {
            return false;
        }
        return true;
    }
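    // Note: returning true here only clears the vector-access requirements; full argument
    // validity additionally requires GridwiseGemm::CheckValidity on the actual descriptors.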

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
@@ -505,87 +585,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
            return false;
        }

        // check vector load/store
        {
            using Row = ck::tensor_layout::gemm::RowMajor;
            using Col = ck::tensor_layout::gemm::ColumnMajor;

            // check vector load of A
            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
            {
                // FIXME: not rigorous
                if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }

            // check vector load of B
            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
            {
                if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
            {
                // FIXME: not rigorous
                if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }

            // check vector load of Ds
            // only support RowMajor for now
            bool all_valid = true;

            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

                if constexpr(!is_same_v<DLayout, Row>)
                {
                    all_valid = false;
                }
            });

            if(!all_valid)
            {
                return false;
            }

            // check vector store of E
            // only support RowMajor for now
            if constexpr(is_same_v<ELayout, Row>)
            {
                if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
                {
                    return false;
                }
            }
            else
            {
                return false;
            }
        }

        return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
        return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and
               GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                           arg.b_grid_desc_n_k_,
                                           arg.ds_grid_desc_m_n_,
                                           arg.e_grid_desc_m_n_,
@@ -708,6 +709,178 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,

        return str.str();
    }

    template <class ADesc, class BDesc, class DsDesc, class EDesc>
    struct Descriptor
    {
        static constexpr auto ds_tuple()
        {
            return transform_tuples(
                [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
                DsDesc{});
        }
        using AGridDesc_M_K =
            remove_cvref_t<decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}))>;
        using BGridDesc_N_K =
            remove_cvref_t<decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}))>;
        using DsGridDesc_M_N = remove_cvref_t<decltype(ds_tuple())>;
        using EGridDesc_M_N =
            remove_cvref_t<decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))>;
        using AGridDesc_AK0_M_AK1 =
            remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
                DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
        using BGridDesc_BK0_N_BK1 =
            remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
                DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
            decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                ds_tuple()))>;
        using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
            decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
        using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(
            DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;

        // tensor descriptors for problem definition
        AGridDesc_M_K a_grid_desc_m_k;
        BGridDesc_N_K b_grid_desc_n_k;
        DsGridDesc_M_N ds_grid_desc_m_n;
        EGridDesc_M_N e_grid_desc_m_n;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;

        // block-to-e-tile map
        Block2ETileMap block_2_etile_map;

        // element-wise op
        AElementwiseOperation a_element_op;
        BElementwiseOperation b_element_op;
        CDEElementwiseOperation cde_element_op;

        // for checking vector load/store
        index_t MRaw;
        index_t NRaw;
        index_t KRaw;

        bool has_main_k_block_loop = true;

        constexpr Descriptor(ADesc a,
                             BDesc b,
                             DsDesc ds,
                             EDesc e,
                             AElementwiseOperation a_element_op_,
                             BElementwiseOperation b_element_op_,
                             CDEElementwiseOperation cde_element_op_)
            : a_grid_desc_m_k{DeviceOp::matrix_padder.PadADescriptor_M_K(a)},
              b_grid_desc_n_k{DeviceOp::matrix_padder.PadBDescriptor_N_K(b)},
              ds_grid_desc_m_n{transform_tuples(
                  [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
                  ds)},
              e_grid_desc_m_n{DeviceOp::matrix_padder.PadCDescriptor_M_N(e)},
              a_grid_desc_ak0_m_ak1{
                  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k)},
              b_grid_desc_bk0_n_bk1{
                  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k)},
              ds_grid_desc_mblock_mperblock_nblock_nperblock{
                  GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                      transform_tuples(
                          [&](auto d) constexpr {
                              return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
                          },
                          ds))},
              e_grid_desc_mblock_mperblock_nblock_nperblock{
                  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                      e_grid_desc_m_n)},
              block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n)},
              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
              a_element_op{a_element_op_},
              b_element_op{b_element_op_},
              cde_element_op{cde_element_op_},
              MRaw{e.GetLength(I0)},
              NRaw{e.GetLength(I1)},
              KRaw{a.GetLength(I1)}
        {
        }

        constexpr bool IsValid() const
        {
            return GridwiseGemm::CheckValidity(a_grid_desc_m_k,
                                               b_grid_desc_n_k,
                                               ds_grid_desc_m_n,
                                               e_grid_desc_m_n,
                                               block_2_etile_map) and
                   IsSupported(MRaw, NRaw, KRaw);
        }

        constexpr index_t GetBlockSize() const { return BlockSize; }

        constexpr index_t GetGridSize() const
        {
            return block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
        }
    };
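    // make_descriptor is a convenience factory that deduces Descriptor's template arguments
    // from the given descriptor values, so callers never spell out ADesc/BDesc/DsDesc/EDesc.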

    template <class ADesc, class BDesc, class DsDesc, class EDesc>
    static constexpr auto
    make_descriptor(ADesc a,
                    BDesc b,
                    DsDesc ds,
                    EDesc e,
                    AElementwiseOperation a_element_op     = AElementwiseOperation{},
                    BElementwiseOperation b_element_op     = BElementwiseOperation{},
                    CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{})
    {
        return Descriptor<ADesc, BDesc, DsDesc, EDesc>(
            a, b, ds, e, a_element_op, b_element_op, cde_element_op);
    }

    template <class Desc, class DsPointer>
    __device__ static void Run(const Desc& desc,
                               const ADataType* __restrict__ p_a_grid,
                               const BDataType* __restrict__ p_b_grid,
                               DsPointer p_ds_grid,
                               EDataType* __restrict__ p_e_grid)
    {
        __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
        assert(desc.IsValid());
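        // The device-side Run receives an already-validated Descriptor, so the only decision
        // left here is whether the K loop has a main block loop; it is dispatched to the
        // matching compile-time Run<true/false> instantiation below.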
|
||||
if(desc.has_main_k_block_loop)
|
||||
{
|
||||
GridwiseGemm::template Run<true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
desc.cde_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_etile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
desc.cde_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_etile_map);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -33,13 +34,14 @@ template <typename ALayout,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumPrefetch,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWMMA,
|
||||
ck::index_t NPerWMMA,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
@@ -60,7 +62,6 @@ template <typename ALayout,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
ck::index_t NumPrefetch = 1,
|
||||
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
|
||||
struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
@@ -76,68 +77,138 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
// K1 = Max Vector Access Pixels
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock* K1};
|
||||
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
|
||||
|
||||
static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
static constexpr auto AEnableLds_auto =
|
||||
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
|
||||
static constexpr auto BEnableLds_auto =
|
||||
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
|
||||
|
||||
// If true, LDS is used unconditionally
|
||||
static constexpr auto AEnableLds_manu = false;
|
||||
static constexpr auto BEnableLds_manu = false;
|
||||
|
||||
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
|
||||
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
// Describe how data read from Global memory
|
||||
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
const auto a_grid_desc_m_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(StrideA, I1));
|
||||
const auto a_grid_desc_mraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(I1, StrideA));
|
||||
const auto a_grid_desc_mraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto A_KRow = 2;
|
||||
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
|
||||
const auto A_KWmma = K / WmmaK;
|
||||
|
||||
const auto M0 = M / MPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
const auto b_grid_desc_n_k = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
const auto b_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
const auto b_grid_desc_nraw_kraw =
|
||||
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
}();
|
||||
|
||||
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
assert(K % K1 == 0);
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
const index_t K0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto B_KRow = 2;
|
||||
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
|
||||
const auto B_KWmma = K / WmmaK;
|
||||
|
||||
const auto N0 = N / NPerBlock;
|
||||
// 0 1 0 1 2 3 4 5 6
|
||||
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
|
||||
}
|
||||
}
|
||||
|
||||
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
|
||||
@@ -159,56 +230,58 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
|
||||
}
|
||||
|
||||
// Gridwise descriptor, mapping to whole given provblem.
|
||||
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
|
||||
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
|
||||
using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma<
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerWMMA,
|
||||
NPerWMMA,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
NumPrefetch,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_Wmma<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc,
|
||||
BGridDesc,
|
||||
CGridDesc_M_N,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
AEnableLds,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BEnableLds,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
NumPrefetch,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -230,7 +303,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
a_grid_desc_k0_m_k1_{},
a_grid_desc_{},
b_grid_desc_k0_n_k1_{},
c_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock{},
@@ -244,19 +317,15 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
NRaw_{N},
KRaw_{K}
{
a_grid_desc_k0_m_k1_ =
DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ =
DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
a_grid_desc_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmWmma_CShuffle::MakeBGridDescriptor(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);

block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);

if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
if(GridwiseGemm::CheckValidity(
a_grid_desc_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, block_2_ctile_map_))
{
c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -268,8 +337,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
AGridDesc a_grid_desc_;
BGridDesc b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -292,23 +361,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,

float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
#if 0
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;

std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;

std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
<< ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
}
#endif

if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
@@ -320,79 +373,58 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);

const auto K =
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
const auto K = [&]() {
if constexpr(AEnableLds)
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
}
else
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
}
}();
auto launch_kernel = [&](auto has_main_k_block_loop) {
const auto kernel = kernel_gemm_wmma<
GridwiseGemm,
ADataType,
BDataType,
CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
has_main_k_block_loop>;

float ave_time = 0;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
};

if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_wmma<
GridwiseGemm,
ADataType,
BDataType,
CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>; // Last Option is W/O

ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
return launch_kernel(integral_constant<bool, true>{});
}
else
{
const auto kernel = kernel_gemm_wmma<
GridwiseGemm,
ADataType,
BDataType,
CDataType,
remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
false>;

ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.block_2_ctile_map_);
return launch_kernel(integral_constant<bool, false>{});
}

return ave_time;
}

// polymorphic
@@ -413,13 +445,16 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
{
if(ck::is_navi3_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>))
{
printf("DeviceOp err: AccDataType");
return false;
}
}
else
{
printf("DeviceOp err: Arch");
return false;
}

@@ -485,7 +520,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
}
}

return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
return GridwiseGemm::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_k0_n_k1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_);
@@ -581,14 +616,18 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< KPerBlock << ", "
<< K1 << ", "
<< MPerWMMA << ", "
<< NPerWMMA << ", "
<< MPerWmma << ", "
<< NPerWmma << ", "
<< MRepeat << ", "
<< NRepeat
<< ">"
<< " NumPrefetch: "
<< " AEnableLds: "
<< AEnableLds << ", "
<< "BEnableLds: "
<< BEnableLds << ", "
<< "NumPrefetch: "
<< NumPrefetch << ", "
<< "LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
@@ -60,7 +60,9 @@ template <typename ADataType,
index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
typename ComputeType = CDataType,
PipelineVersion PipelineVer = PipelineVersion::v1,
LoopScheduler LoopSched = make_default_loop_scheduler()>
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename LDSTypeA = ComputeType,
typename LDSTypeB = ComputeType>

struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
BLayout,
@@ -81,6 +83,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
// TODO: should be exposed as Tparams.
static constexpr index_t NumGemmKPrefetchStage = 1;

using ComputeTypeA = ComputeType;
using ComputeTypeB = ComputeType;

using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
ADataType,
@@ -125,7 +130,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopSched,
PipelineVer,
ComputeType>;
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
LDSTypeB>;

struct Argument : public GridwiseGemm::Argument
{
@@ -196,7 +196,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
using EGridDesc_M_N = remove_cvref_t<tuple_element_t<3, ABDsEGridDesc>>;

// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
using GridwiseGemm = GridwiseGemmMultipleD_Wmma<
// DataType Family
ADataType,
BDataType,
@@ -217,7 +217,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
// Tiling Family
MPerBlock,
NPerBlock,
K0PerBlock,
KPerBlock,
MPerWMMA,
NPerWMMA,
K1,
@@ -232,6 +232,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
true,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
@@ -240,6 +241,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
true,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,

@@ -393,12 +393,14 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
using CShuffleDataType = AccDataType;

using GridwiseGemm = GridwiseGemmMultipleD_Wmma<
// DataType Family
ADataType,
BDataType,
AccDataType,
CDataType,
CShuffleDataType,
Tuple<>,
CDataType,
// InMemory Data Descriptor
@@ -414,7 +416,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
// Tiling Family
MPerBlock,
NPerBlock,
K0PerBlock,
KPerBlock,
MPerWMMA,
NPerWMMA,
K1,
@@ -429,6 +431,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
true,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
@@ -437,6 +440,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
true,
BBlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,

@@ -52,22 +52,23 @@ template <index_t NDimSpatial,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t MPerWMMA,
ck::index_t NPerWMMA,
ck::index_t MPerWmma,
ck::index_t NPerWmma,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
@@ -88,7 +89,6 @@ template <index_t NDimSpatial,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
@@ -109,11 +109,31 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle

static constexpr index_t NumDTensor = DsDataType::Size();

static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t KPerBlock = K0PerBlock * K1;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
// K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{};

static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = 16;

static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true;
static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;

// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = true;
static constexpr auto BEnableLds_manu = true;

static constexpr auto AEnableLds =
AEnableLds_auto || AEnableLds_manu || (NumGemmKPrefetchStage > 1);
static constexpr auto BEnableLds =
BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);

static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
@@ -122,17 +142,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

template <typename ALay>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
static auto MakeAGridDescriptor(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
@@ -149,13 +168,44 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);

return in_gemmm_gemmk_desc;
const auto M = in_gemmm_gemmk_desc.GetLength(I0);
const auto K = in_gemmm_gemmk_desc.GetLength(I1);
assert(K % K1 == 0);

if constexpr(AEnableLds)
{
const index_t K0 = K / K1;

return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
constexpr auto A_KRow = 2;
constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
const auto A_KWmma = K / WmmaK;

const auto M0 = M / MPerBlock;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return transform_tensor_descriptor(
in_gemmm_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(
A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
}

template <typename BLay>
static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
static auto MakeBGridDescriptor(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
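The non-LDS branch of MakeAGridDescriptor above unmerges K into (A_KWmma, A_K0PerWmma, A_KRow, A_K1). A minimal compile-time check of that decomposition, with hypothetical sizes (K = 128 and K1 = 4 are free choices here; WmmaK is 16 in this file):

int main()
{
    // Hypothetical sizes; any K, K1 satisfying assert(K % K1 == 0) would do.
    constexpr int K = 128, K1 = 4, WmmaK = 16, A_KRow = 2;
    constexpr int A_K0PerWmma = WmmaK / A_KRow / K1; // 2
    constexpr int A_KWmma     = K / WmmaK;           // 8
    // The unmerge in the non-LDS branch must reproduce K exactly.
    static_assert(A_KWmma * A_K0PerWmma * A_KRow * K1 == K, "K decomposition is exact");
    return 0;
}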
@@ -164,7 +214,39 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);

return wei_gemmn_gemmk_desc;
const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
assert(K % K1 == 0);

if constexpr(BEnableLds)
{
const index_t K0 = K / K1;

return transform_tensor_descriptor(
wei_gemmn_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
constexpr auto B_KRow = 2;
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
const auto B_KWmma = K / WmmaK;

const auto N0 = N / NPerBlock;
// 0 1 0 1 2 3 4 5 6
// N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
return transform_tensor_descriptor(
wei_gemmn_gemmk_desc,
make_tuple(make_unmerge_transform(make_tuple(
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
}

template <typename ELay>
@@ -197,53 +279,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}

// desc for problem definition
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using AGridDesc =
decltype(DeviceOp::MakeAGridDescriptor<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}));
using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor<BLayout>({}, {}));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;

// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);

const auto AK1 = K1;
const auto AK0 = K / AK1;

return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}

// B desc for source in blockwise copy
template <typename BGridDesc_N_K>
__host__ __device__ static constexpr auto
MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);

const auto BK1 = K1;
const auto BK0 = K / BK1;

return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}

using AGridDesc_AK0_M_AK1 = decltype(DeviceOp::MakeAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}));
using BGridDesc_BK0_N_BK1 = decltype(DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}));

// GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
using GridwiseOp = GridwiseGemmMultipleD_Wmma<
// DataType Family
ADataType,
BDataType,
@@ -252,8 +295,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
DsDataType,
EDataType,
// InMemory Data Descriptor
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
AGridDesc,
BGridDesc,
DsGridDesc_M_N,
EGridDesc_M_N,
// ElementwiseOp Family
@@ -264,9 +307,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// Tiling Family
MPerBlock,
NPerBlock,
K0PerBlock,
MPerWMMA,
NPerWMMA,
KPerBlock,
MPerWmma,
NPerWmma,
K1,
MRepeat,
NRepeat,
@@ -279,6 +322,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
AEnableLds,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
@@ -287,6 +331,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BEnableLds,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
@@ -327,23 +372,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
a_grid_desc_{DeviceOp::MakeAGridDescriptor<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_{
DeviceOp::MakeBGridDescriptor<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)},
@@ -395,8 +438,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle

void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
std::cout << "A[M, K]: " << a_grid_desc_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
@@ -411,14 +454,12 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle

// tensor descriptors for problem definition
index_t num_group_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;

// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
AGridDesc a_grid_desc_;
BGridDesc b_grid_desc_;
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -465,8 +506,17 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;

const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
const auto K = [&]() {
if constexpr(AEnableLds)
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
}
else
{
return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
}
}();

auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
@@ -480,8 +530,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::AGridDesc,
DeviceOp::BGridDesc,
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
@@ -501,8 +551,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_,
arg.b_grid_desc_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
@@ -670,8 +720,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}

// check Gridwise GEMM
return GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
return GridwiseOp::CheckValidity(arg.a_grid_desc_,
arg.b_grid_desc_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
@@ -790,9 +840,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< K1 << ", "
<< MPerWmma << ", "
<< NPerWmma << ", "
<< MRepeat << ", "
<< NRepeat
<< ">"
<< " AEnableLds: "
<< AEnableLds << ", "
<< "BEnableLds: "
<< BEnableLds << ", "
<< "ABlockTransferSrcScalarPerVector: "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector
<< ">";
<< "BBlockTransferSrcScalarPerVector: "
<< BBlockTransferSrcScalarPerVector;
// clang-format on

return str.str();
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -650,22 +650,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
constexpr auto Set = InMemoryDataOperationEnum::Set;

if(arg.k_batch_ > 1)
{
if(has_main_k_block_loop)
{
ave_time =
launch_kernel(integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
else
{
ave_time =
launch_kernel(integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
}
else
// For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced
// in IsSupportedArgument function
if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
{
if(has_main_k_block_loop)
{
@@ -678,6 +665,39 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
integral_constant<InMemoryDataOperationEnum, Set>{});
}
}
else
{
if(arg.k_batch_ > 1)
{
if(has_main_k_block_loop)
{
ave_time = launch_kernel(
integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
else
{
ave_time = launch_kernel(
integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
}
else
{
if(has_main_k_block_loop)
{
ave_time =
launch_kernel(integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
else
{
ave_time =
launch_kernel(integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
}
}

return ave_time;
}
@@ -718,6 +738,13 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
}
}

// For bf16 datatype only kbatch = 1 is supported since there is no AtomicAdd
// instruction that supports bf16 and we cannot use splitk because of that
if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
{
supported = supported & (arg.k_batch_ == 1);
}

return supported;
}
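The restructured branches above route bf16 through the Set path only, because split-K accumulation writes partial results with AtomicAdd and no bf16 AtomicAdd instruction exists. A standalone sketch of the same rule (BF16 is a stand-in for ck::bhalf_t; the authoritative check is the IsSupportedArgument hunk above):

#include <type_traits>

struct BF16 {}; // stand-in for ck::bhalf_t in this sketch

template <typename ADataType>
constexpr bool kbatch_supported(int k_batch)
{
    // No bf16 AtomicAdd exists, so split-K accumulation cannot be used.
    return std::is_same<ADataType, BF16>::value ? (k_batch == 1) : (k_batch >= 1);
}

static_assert(kbatch_supported<BF16>(1), "bf16 with k_batch == 1 is allowed");
static_assert(!kbatch_supported<BF16>(4), "bf16 split-K is rejected");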
File diff suppressed because it is too large
@@ -53,7 +53,10 @@ struct MaskOutUpperTrianglePredicate
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {}
__host__ __device__ C0MatrixMask_impl(index_t NRaw)
: NRaw_(NRaw), predicate_(MaskOutPredicate{})
{
}

__host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const
{

@@ -165,7 +165,7 @@ struct Subtract

struct Bilinear
{
Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){};
Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};

template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
@@ -184,6 +184,14 @@ struct Bilinear
y = alpha_ * x0 + beta_ * x1;
};

template <>
__host__ __device__ constexpr void
operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
{
y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
beta_ * type_convert<float>(x1));
};

template <>
__host__ __device__ constexpr void
operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
@@ -221,7 +229,8 @@ struct Bilinear
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
{
y = type_convert<std::int8_t>(x0 + ck::type_convert<std::int32_t>(x1));
y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
beta_ * type_convert<float>(x1));
};

float alpha_;
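The last Bilinear hunk fixes the int32-by-int8 path, which previously ignored alpha_ and beta_ entirely. A host-only check of the intended arithmetic (plain static_cast stands in for ck::type_convert here, an assumption for illustration):

#include <cstdint>
#include <cstdio>

int main()
{
    const float alpha = 2.0f, beta = 0.5f;
    const std::int8_t x0 = 10, x1 = 40;
    // y = alpha * x0 + beta * x1, computed in float and narrowed back to int8
    const std::int8_t y = static_cast<std::int8_t>(alpha * static_cast<float>(x0) +
                                                   beta * static_cast<float>(x1));
    std::printf("y = %d\n", static_cast<int>(y)); // 2 * 10 + 0.5 * 40 = 40
    return 0;
}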
@@ -21,50 +21,11 @@ struct PassThroughPack2
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;

__host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::half2_t& x) const
{
// fake conversion
uint16_t t = ck::bit_cast<uint32_t>(x);
y = ck::bit_cast<ck::f8x2_t>(t);
}

__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
}

__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::half2_t& x) const
{
y = x;
}

__host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::f8x2_t& x) const
{
y = x;
}

__host__ __device__ constexpr void operator()(ck::float2_t& y, const ck::float2_t& x) const
{
y = x;
}

__host__ __device__ constexpr void operator()(ck::int8x2_t& y, const ck::int8x2_t& x) const
{
y = x;
}

__host__ __device__ constexpr void operator()(ck::bhalf2_t& y, const ck::bhalf2_t& x) const
{
y = x;
}

__host__ __device__ constexpr void operator()(ck::double2_t& y, const ck::double2_t& x) const
{
y = x;
}

constexpr const static bool is_pack2_invocable = true;
};

struct PassThrough
@@ -162,6 +123,12 @@ struct PassThrough
y = type_convert<bhalf_t>(x);
}

template <>
__host__ __device__ void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
{
y = x;
}

template <>
__host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{
@@ -343,6 +310,12 @@ struct Scale
y = scale_ * x;
};

template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = ck::type_convert<int8_t>(scale_ * ck::type_convert<float>(x));
};

float scale_;
};
@@ -702,6 +675,76 @@ struct Elu
const float alpha_;
};

// support fastconvert of int8 to fp16

template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
struct FastNumericArrayConverter
{
};

template <>
struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
{
using InputArray = vector_type<uint8_t, 4>;
using OutputArray = vector_type<ck::half_t, 4>;

__device__ static OutputArray convert(InputArray const& Input)
{
OutputArray Output;

uint32_t* half_2 = reinterpret_cast<uint32_t*>(&Output);
uint32_t const uint8_4 = reinterpret_cast<uint32_t const&>(Input);

static constexpr uint32_t byte_selector_01 = 0x05010500;
static constexpr uint32_t byte_selector_23 = 0x05030502;
static constexpr uint32_t fp16_adder = 0x64646464;
half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);

static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[0])
: "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[1])
: "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));

return Output;
}

__device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
};

template <index_t N>
struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
{
static constexpr int VEC_WIDTH = 4;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

using InputArray = vector_type<uint8_t, N>;
using OutputArray = vector_type<ck::half_t, N>;

__device__ static OutputArray convert(InputArray const& Input)
{
FastNumericArrayConverter<uint8_t, ck::half_t, 4> converter;

OutputArray Output;

using Vec_InputArray = vector_type<uint8_t, 4>;
using Vec_OutputArray = vector_type<ck::half_t, 4>;

Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);

static_for<0, N / VEC_WIDTH, 1>{}(
[&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); });

return Output;
}

__device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
};

} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
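A host-side model of the byte trick in FastNumericArrayConverter<uint8_t, ck::half_t, 4>, as I read the code (an interpretation, not something stated in the diff): __builtin_amdgcn_perm splices each source byte into the mantissa of an fp16 whose exponent byte is 0x64, so the packed half 0x6400 | b has the value 1024 + b; the v_pk_add_f16 with negated second operand then subtracts the magic half 0x6480 (value 1152.0), leaving b - 128 per lane with no per-element int-to-float instruction:

#include <cstdio>

int main()
{
    for(int b : {0, 1, 127, 128, 255})
    {
        const double packed = 1024.0 + b; // value of the half 0x6400 | b
        const double magic  = 1152.0;     // value of the half 0x6480
        std::printf("byte %3d -> %6.1f\n", b, packed - magic); // b - 128
    }
    return 0;
}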
@@ -24,10 +24,10 @@ struct BlockToCTileMap_M00_N0_M01
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};

__host__ __device__ BlockToCTileMap_M00_N0_M01() = default;
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01() = default;

__host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 1)
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 1)
: M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
{
}
@@ -51,8 +51,8 @@ struct BlockToCTileMap_M00_N0_M01
}

template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
__host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
if constexpr(DeviceCTileIndexCheck)
return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
@@ -60,7 +60,7 @@ struct BlockToCTileMap_M00_N0_M01
return true;
}

__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
if constexpr(DeviceCTileIndexCheck)
return true; // validity check moved to kernel
@@ -120,18 +120,19 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};

__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;

__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) =
default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) =
default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
const BlockToCTileMap_M00_N0_M01Adapt&) = default;
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
BlockToCTileMap_M00_N0_M01Adapt&&) = default;
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
__host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;

__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
__host__
__device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
: M_(M), N_(N), M01_(M01)
{
#if 0
@@ -142,8 +143,9 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
}

template <typename CGridDesc_M_N>
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8)
__host__
__device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8)
: BlockToCTileMap_M00_N0_M01Adapt(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
{
@@ -164,7 +166,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
}

template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
@@ -237,8 +239,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
}

template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
__host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
@@ -616,7 +618,10 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
return true; // always valid provided that user gets grid size from CalculateGridSize()
}

__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}

private:
index_t M01_;
@@ -674,7 +679,7 @@ struct BlockToCTileMap_M00_N00_M01_N01
return true;
}

__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
if constexpr(DeviceCTileIndexCheck)
return true; // validity check moved to kernel
@@ -786,7 +791,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
return true;
}

__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
if constexpr(DeviceCTileIndexCheck)
return true; // validity check moved to kernel
@@ -910,7 +915,7 @@ struct OffsettedBlockToCTileMap
}

template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
@@ -967,7 +972,7 @@ struct BlockToCTileMap_3DGrid_KSplit
}

template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
@@ -116,7 +116,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle

using ThisThreadBlock = ThisThreadBlock<BlockSize>;

using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemm0KPrefetchStage>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemm0KPrefetchStage, true, true>;

// ck::Tuple<const D0DataType1*, const D0DataType2*, ...>
static constexpr auto MakeD0sGridPointer()

File diff suppressed because it is too large
@@ -0,0 +1,169 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/utility/common_header.hpp"

namespace ck {

template <typename GridwiseElementwiseFunctor,
typename InGridDescTuple,
typename OutGridDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename Block2TileMap,
typename ElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_elementwise(const InGridDescTuple in_grid_desc_tuple,
const OutGridDescTuple out_grid_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const Block2TileMap block_2_tile_map,
const ElementwiseOperation elementwise_op)
{
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
out_grid_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
block_2_tile_map,
elementwise_op);
}

template <typename InGridDescTuple,
typename OutGridDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename Block2TileMap,
typename ElementwiseOperation,
index_t BlockSize,
index_t M0PerBlock,
index_t M1PerBlock,
index_t M0PerThread,
index_t M1PerThread,
typename ThreadClusterArrangeOrder,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq,
bool InOutSameVectorDim>
struct GridwiseElementwise
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();

static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGridDescTuple::Size() && NumOutput == OutGridDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");

static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};

using PassThroughOp = tensor_operation::element_wise::PassThrough;

__device__ static void Run(const InGridDescTuple& in_grid_desc_tuple,
const OutGridDescTuple& out_grid_desc_tuple,
const InDataTypePointerTuple& p_in_global_tuple,
const OutDataTypePointerTuple& p_out_global_tuple,
const Block2TileMap& block_2_tile_map,
const ElementwiseOperation& elementwise_op)
{

constexpr auto src_datas = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;

return DataType{};
},
Number<NumInput>{});

constexpr auto dst_datas = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;

return DataType{};
},
Number<NumOutput>{});

const auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});

auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});

const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

const index_t m0_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
const index_t m1_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock);
const auto thread_grid_offset =
make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);

using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// If src and dst have same vector dim, then:
// M0 dim - for src and dst vector load/store
// else:
// M0 dim - for dst vector load
// M1 dim - for src vector store
using SrcDimAccessOrder = Sequence<0, 1>;
using DstDimAccessOrder =
std::conditional_t<InOutSameVectorDim, Sequence<0, 1>, Sequence<1, 0>>;
using SrcVectorDim = Number<1>;
using DstVectorDim = std::conditional_t<InOutSameVectorDim, Number<1>, Number<0>>;

using ThreadClusterLengths =
Sequence<Number<M0PerBlock / M0PerThread>{}, Number<M1PerBlock / M1PerThread>{}>;

auto global_to_global_transfer = ThreadGroupTensorSliceTransfer_v4r2<
ThisThreadBlock,
ElementwiseOperation,
uniform_sequence_gen_t<NumOutput, static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<M0PerBlock, M1PerBlock>,
ThreadClusterLengths,
ThreadClusterArrangeOrder,
decltype(src_datas),
decltype(dst_datas),
InGridDescTuple,
OutGridDescTuple,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim{},
DstVectorDim{},
InScalarPerVectorSeq,
OutScalarPerVectorSeq,
uniform_sequence_gen_t<NumInput, 1>,
uniform_sequence_gen_t<NumOutput, 1>,
uniform_sequence_gen_t<NumInput, false>,
uniform_sequence_gen_t<NumOutput, false>>{in_grid_desc_tuple,
thread_grid_offset,
out_grid_desc_tuple,
thread_grid_offset,
elementwise_op};
global_to_global_transfer.Run(
in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0);
}
};

} // namespace ck
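For the new GridwiseElementwise above, grid and block sizes fall out of the tile parameters; the numbers below are hypothetical, chosen only to illustrate the bookkeeping:

#include <cstdio>

int main()
{
    constexpr int M0 = 1024, M1 = 2048;             // problem shape (hypothetical)
    constexpr int M0PerBlock = 64, M1PerBlock = 64; // tile per workgroup
    constexpr int M0PerThread = 4, M1PerThread = 4; // tile per thread

    // One workgroup covers an M0PerBlock x M1PerBlock tile of the flattened problem,
    const int grid_size = (M0 / M0PerBlock) * (M1 / M1PerBlock);
    // and ThreadClusterLengths splits that tile among the block's threads.
    const int block_size = (M0PerBlock / M0PerThread) * (M1PerBlock / M1PerThread);

    std::printf("grid = %d blocks, block = %d threads\n", grid_size, block_size);
    return 0;
}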
1046 include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp (new file)
File diff suppressed because it is too large
@@ -264,7 +264,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
const Block2ETileMap&)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -310,10 +310,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}

// check block-to-E-tile
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
// if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
//{
// return false;
//}

// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
@@ -17,18 +17,21 @@ enum struct PipelineVersion
v2,
// v3 is only used in the Stream-K implementation.
v4,
weight_only,
};

template <PipelineVersion PipelineVer,
index_t NumPrefetch = 1,
LoopScheduler LoopSched = LoopScheduler::Default>
LoopScheduler LoopSched = LoopScheduler::Default,
bool AEnableLds = true,
bool BEnableLds = true>
constexpr auto GridwiseGemmPipeline_Selector()
{
if constexpr(PipelineVer == PipelineVersion::v1)
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return GridwiseGemmPipeline_v1<NumPrefetch>{};
return GridwiseGemmPipeline_v1<NumPrefetch, AEnableLds, BEnableLds>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
@@ -43,6 +46,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
{
return GridwiseGemmPipeline_v4<NumPrefetch>{};
}
else if constexpr(PipelineVer == PipelineVersion::weight_only)
{
return GridwiseGemmPipeline_v1_WeightOnly<NumPrefetch, AEnableLds, BEnableLds>{};
}
else
{
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
@@ -9,12 +9,12 @@

namespace ck {

template <index_t NumPrefetch>
template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1;

// 1-stage prefetch
template <>
struct GridwiseGemmPipeline_v1<1>
struct GridwiseGemmPipeline_v1<1, true, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -108,7 +108,7 @@ struct GridwiseGemmPipeline_v1<1>

// 2-stage prefetch
template <>
struct GridwiseGemmPipeline_v1<2>
struct GridwiseGemmPipeline_v1<2, true, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -254,6 +254,406 @@ struct GridwiseGemmPipeline_v1<2>
}
};

template <>
|
||||
struct GridwiseGemmPipeline_v1<1, false, true>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
|
||||
{
|
||||
return num_loop > 1;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename BlockwiseGemm,
|
||||
typename CThreadBuffer>
|
||||
__device__ static void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
const BlockwiseGemm& blockwise_gemm,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop)
|
||||
{
|
||||
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
|
||||
auto a_block_buf_switch = a_block_buf;
|
||||
|
||||
// preload data into LDS
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
a_blockwise_copy.Run(
|
||||
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.Run(
|
||||
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
a_block_buf = a_block_buf_switch;
|
||||
++i;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
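
The specialization above keeps the A tile in registers and ping-pongs between a_block_buf and a_block_buf_switch so that the next tile's global load can overlap the current tile's math, with one tail iteration that only computes. A host-side sketch of the same schedule, using hypothetical load_tile/gemm_tile helpers in place of the blockwise copy and GEMM (on the GPU the prefetch genuinely overlaps compute; here the ordering just illustrates the schedule):

```cpp
#include <array>
#include <cstdio>

using Tile = std::array<float, 4>; // stand-in for a K-slice of A

Tile load_tile(int k) { return {1.f * k, 2.f * k, 3.f * k, 4.f * k}; }
void gemm_tile(const Tile& t, float& acc) { for(float v : t) acc += v; }

// Same shape as GridwiseGemmPipeline_v1<1, false, true>::Run():
// preload, then {prefetch next into the spare buffer, compute current, swap},
// and a tail iteration that only computes.
float pipelined_sum(int num_loop)
{
    Tile buf = load_tile(0); // preload
    float acc = 0.f;

    for(int i = 0; i < num_loop - 1; ++i)
    {
        Tile buf_switch = load_tile(i + 1); // prefetch next tile
        gemm_tile(buf, acc);                // compute on current tile
        buf = buf_switch;                   // ping-pong swap
    }

    gemm_tile(buf, acc); // tail: last tile, nothing left to prefetch
    return acc;
}

int main() { std::printf("%f\n", pipelined_sum(4)); }
```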
template <>
struct GridwiseGemmPipeline_v1<1, true, false>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};

__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}

template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf;

// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

// Initialize C
c_thread_buf.Clear();

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

// main body
if constexpr(HasMainLoop)
{
index_t i = 0;

do
{
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);

block_sync_lds();

a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);

blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

block_sync_lds();

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

b_block_buf = b_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}

// tail
{
block_sync_lds();

blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

block_sync_lds();
}
}
};
template <>
struct GridwiseGemmPipeline_v1<1, false, false>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};

__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}

template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf;
auto a_block_buf_switch = a_block_buf;

// preload data into LDS
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

// Initialize C
c_thread_buf.Clear();

// main body
if constexpr(HasMainLoop)
{
index_t i = 0;

do
{
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);

block_sync_lds();

blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

block_sync_lds();

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

a_block_buf = a_block_buf_switch;
b_block_buf = b_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}

// tail
{
block_sync_lds();

blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

block_sync_lds();
}
}
};
template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
struct GridwiseGemmPipeline_v1_WeightOnly;

template <>
struct GridwiseGemmPipeline_v1_WeightOnly<1, true, true>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};

__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}

template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename ScaleGridDesc,
typename ScaleGridBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const ScaleGridDesc& scale_grid_desc,
const ScaleGridBuffer& scale_grid_buf,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// Global Prefetch Stage 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// Scale read once
b_blockwise_copy.RunScaleRead(scale_grid_desc, scale_grid_buf);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

// Initialize C
c_thread_buf.Clear();

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// Dequantization fused in blockwise_copy
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

// main body
if constexpr(HasMainLoop)
{
index_t i = 0;

do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);

block_sync_lds();

b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

block_sync_lds();

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

++i;
} while(i < (num_loop - 1));
}

// tail
{
block_sync_lds();

blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
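
In this weight-only pipeline the B operand arrives quantized, plus a scale that is read once up front (RunScaleRead); dequantization is fused into the copy's RunWrite, so the math loop only ever sees floating-point tiles. A simplified scalar sketch of that fusion, assuming hypothetical int8 weights with one scale per output column (CK's actual packing and data types differ):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Dequantize an int8 weight tile column-by-column while "writing to LDS";
// the GEMM that follows consumes plain floats.
std::vector<float> dequant_write(const std::vector<int8_t>& w,     // K x N, row-major
                                 const std::vector<float>& scale,  // per column, size N
                                 int K, int N)
{
    std::vector<float> lds(static_cast<size_t>(K) * N);
    for(int k = 0; k < K; ++k)
        for(int n = 0; n < N; ++n)
            lds[k * N + n] = static_cast<float>(w[k * N + n]) * scale[n];
    return lds;
}

int main()
{
    std::vector<int8_t> w = {10, -20, 30, -40}; // K=2, N=2
    std::vector<float> scale = {0.1f, 0.01f};
    auto lds = dequant_write(w, scale, 2, 2);
    std::printf("%g %g %g %g\n", lds[0], lds[1], lds[2], lds[3]);
}
```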
template <index_t NumPrefetch>
struct GridwiseGemmPipelineInterwave_v1;

@@ -349,7 +749,7 @@ struct GridwiseGemmPipelineInterwave_v1<1>

// Note: 2 stage prefetch not optimized for inter-wave loop scheduler
template <>
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true>
{
};

@@ -359,7 +759,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
{
return GridwiseGemmPipeline_v1<NumPrefetch>{};
return GridwiseGemmPipeline_v1<NumPrefetch, true, true>{};
}
else if constexpr(LoopSched == LoopScheduler::Interwave)
{
@@ -93,7 +93,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle

using ThisThreadBlock = ThisThreadBlock<BlockSize>;

using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage, true, true>;

__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
@@ -18,11 +18,11 @@
namespace ck {

template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename ADataType,
typename BDataType,
typename CDataType,
typename AGridDesc,
typename BGridDesc,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename AElementwiseOperation,
typename BElementwiseOperation,
@@ -33,31 +33,27 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_wmma(
const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
// const
// CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
// c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
CDataType* __restrict__ p_c_grid,
const AGridDesc a_grid_desc,
const BGridDesc b_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];

GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
a_grid_desc,
b_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
@@ -67,8 +63,8 @@ __global__ void
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = a_grid_desc;
ignore = b_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
@@ -78,21 +74,21 @@ __global__ void
}

template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatAcc,
typename FloatCShuffle,
typename FloatC,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename AGridDesc,
typename BGridDesc,
typename CGridDesc_M_N,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t KPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t K1Value,
@@ -105,6 +101,7 @@ template <index_t BlockSize,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
bool AEnableLds,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
@@ -113,6 +110,7 @@ template <index_t BlockSize,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
bool BEnableLds,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
@@ -121,7 +119,7 @@ template <index_t BlockSize,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
struct GridwiseGemm_Wmma
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -132,103 +130,277 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};

// K1 should be Number<...>
// FIX ME: To be deprecated
static constexpr auto K1 = Number<K1Value>{};

static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;

using ThisThreadBlock = ThisThreadBlock<BlockSize>;

using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
using GridwiseGemmPipe =
remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
NumGemmKPrefetchStage,
LoopSched,
AEnableLds,
BEnableLds>())>;

__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
// Describe how data store to (LDS/VGPR) buffer from Global memory
__host__ __device__ static constexpr auto MakeABlockDescriptor()
{
constexpr auto max_lds_align = K1;

// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
if constexpr(ABlockLdsExtraM)
constexpr auto a_block_desc = [&]() {
if constexpr(AEnableLds)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
// K0->M->K1 Per Block
constexpr auto K0PerBlock = KPerBlock / K1;
constexpr auto max_lds_align = K1;

if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
}
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{},
Number<MRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
K1,
K1,
K1,
I1));
}
}();

return a_block_desc_k0perblock_mperblock_k1;
return a_block_desc;
}

__host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
__host__ __device__ static constexpr auto MakeBBlockDescriptor()
{
constexpr auto max_lds_align = K1;

// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
if constexpr(BBlockLdsExtraN)
constexpr auto b_block_desc = [&]() {
if constexpr(BEnableLds)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
// K0->N->K1 Per Block
constexpr auto K0PerBlock = KPerBlock / K1;
constexpr auto max_lds_align = K1;

if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
}
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK / 2 / K1;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return make_naive_tensor_descriptor(
make_tuple(Number<KWmmaPerblock>{},
Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
K1),
make_tuple(Number<NRepeat>{} * Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
Number<K0PerWmma>{} * K1,
K1,
K1,
K1,
I1));
}
}();

return b_block_desc_k0perblock_nperblock_k1;
return b_block_desc;
}
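
The two layouts above are easy to sanity-check numerically. Taking, for example, KPerBlock = 64, K1 = 8 and WmmaK = 16, the LDS path gets K0PerBlock = 64 / 8 = 8, while the direct-to-VGPR path gets KWmmaPerblock = 64 / 16 = 4 and K0PerWmma = 16 / 2 / 8 = 1. A few constexpr assertions make the relationship explicit (the numbers are illustrative, not a required configuration):

```cpp
// Illustrative tile arithmetic for the two descriptor layouts above.
constexpr int KPerBlock = 64;
constexpr int K1        = 8;
constexpr int WmmaK     = 16;

constexpr int K0PerBlock    = KPerBlock / K1;    // LDS path: K0 -> M -> K1
constexpr int KWmmaPerblock = KPerBlock / WmmaK; // VGPR path: outer KWmma dim
constexpr int K0PerWmma     = WmmaK / 2 / K1;    // per-WMMA K0 (two K rows)

static_assert(K0PerBlock == 8);
static_assert(KWmmaPerblock == 4 && K0PerWmma == 1);
// Both factorizations cover the same K extent per block:
static_assert(KWmmaPerblock * K0PerWmma * 2 * K1 == K0PerBlock * K1);
```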
__host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
{
constexpr auto a_block_copy_step = [&]() {
if constexpr(AEnableLds)
{
constexpr auto K0PerBlock = KPerBlock / K1;

return make_multi_index(K0PerBlock, 0, 0);
}
else
{
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;

return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
}
}();

return a_block_copy_step;
}

__host__ __device__ static constexpr auto MakeBBlockSliceCopyStep()
{
constexpr auto b_block_copy_step = [&]() {
if constexpr(BEnableLds)
{
constexpr auto K0PerBlock = KPerBlock / K1;

return make_multi_index(K0PerBlock, 0, 0);
}
else
{
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;

return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
}
}();

return b_block_copy_step;
}
// Describe how data read from (LDS/VGPR) buffer
template <typename ABlockDesc_>
__host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
{

constexpr auto a_wave_desc = [&]() {
if constexpr(AEnableLds)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
constexpr auto A_KRow = I1;
return transform_tensor_descriptor(
ABlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
make_pass_through_transform(Number<A_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
else
{
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);

// Err: merge transform cause non-constexpr issue

// return transform_tensor_descriptor(
// ABlockDesc_{},
// make_tuple(make_merge_transform(make_tuple(Number<KWmma>{}, I1)),
// make_pass_through_transform(Number<MRepeat>{}),
// make_pass_through_transform(I1),
// make_pass_through_transform(I1),
// make_pass_through_transform(Number<A_K1>{})),
// make_tuple(Sequence<0, 3>{},
// Sequence<1>{},
// Sequence<2>{},
// Sequence<4>{},
// Sequence<5>{}),
// make_tuple(
// Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{},
// Sequence<4>{}));

// Workaround, Freeze transform
return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
Number<MRepeat>{},
I1,
Number<A_KRow>{},
I1,
Number<A_K1>{}));
}
}();

return a_wave_desc;
}

template <typename BBlockDesc_>
__host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
{
constexpr auto b_wave_desc = [&]() {
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
constexpr auto B_KRow = I1;
return transform_tensor_descriptor(
BBlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
else
{
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4);
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6);

// Workaround, Freeze transform
return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}));
}
}();

return b_wave_desc;
}
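
The make_unmerge_transform calls above split one linear dimension into several, e.g. M becomes (MRepeat, MWaves, MPerWmma). The underlying index arithmetic is ordinary mixed-radix decomposition; a small sketch under hypothetical names:

```cpp
#include <cassert>

// Unmerge a flat index m in [0, MRepeat*MWaves*MPerWmma) into the
// (m_repeat, m_wave, m_per_wmma) coordinates used by the wave descriptor.
struct MCoord { int repeat, wave, per_wmma; };

constexpr MCoord unmerge_m(int m, int MWaves, int MPerWmma)
{
    return {m / (MWaves * MPerWmma), // slowest-varying
            (m / MPerWmma) % MWaves,
            m % MPerWmma};           // fastest-varying
}

int main()
{
    // e.g. MWaves = 2, MPerWmma = 16: flat index 35 -> (1, 0, 3)
    constexpr MCoord c = unmerge_m(35, 2, 16);
    static_assert(c.repeat == 1 && c.wave == 0 && c.per_wmma == 3, "unmerge");
    // recomposing recovers the flat index
    assert(c.repeat * 2 * 16 + c.wave * 16 + c.per_wmma == 35);
}
```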
__host__ __device__ static constexpr auto
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);

constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
I1,
Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));

return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
}

__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_k0perblock_mperblock_k1 =
GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

constexpr auto b_block_desc_k0perblock_nperblock_k1 =
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

constexpr auto max_lds_align = K1;

constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);

constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);

return (a_block_space_size_aligned * sizeof(FloatA) +
b_block_space_size_aligned * sizeof(FloatB));
}

// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Block2CTileMap>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
@@ -237,23 +409,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
(NPerBlock % (NRepeat * NPerWmma)) == 0,
"Invalid tuning param!");

const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
const auto GetAProblemsizeMK = [&]() {
if constexpr(AEnableLds)
{
return make_tuple(a_grid_desc.GetLength(I1),
a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
}
else
{
return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
a_grid_desc.GetLength(I5),
a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
}
};

const auto GetBProblemsizeNK = [&]() {
if constexpr(BEnableLds)
{
return make_tuple(b_grid_desc.GetLength(I1),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
}
else
{
return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
b_grid_desc.GetLength(I5),
b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
}
};

const auto M = GetAProblemsizeMK()[I0];
const auto N = GetBProblemsizeNK()[I0];
const auto K = GetAProblemsizeMK()[I1];

if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
K == GetBProblemsizeNK()[I1]))
{
printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n",
GetAProblemsizeMK()[I0],
GetAProblemsizeMK()[I1],
GetBProblemsizeNK()[I0],
GetBProblemsizeNK()[I1],
c_grid_desc_m_n.GetLength(I0),
c_grid_desc_m_n.GetLength(I1));
printf("GridwiseOp err: ProblemSize check");
return false;
}

if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
printf("GridwiseOp err: ProblemSize division");
return false;
}

// check gridwise gemm pipeline
const auto num_k_loop = K0 / K0PerBlock;
const auto num_k_loop = K / KPerBlock;

if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
printf("GridwiseOp err: Pipeline not support this k_loop");
return false;
}

@@ -265,8 +480,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
constexpr long_index_t TwoGB = (long_index_t{1} << 31);

if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB &&
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB))
if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
{
return false;
}
@@ -275,7 +490,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma

__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / (K0PerBlock * K1);
const index_t num_loop = K / KPerBlock;

return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
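
All the pipelines above run one tail iteration outside the loop, so a main body is needed only when more than one K tile remains, which is exactly what CalculateHasMainLoop encodes. A tiny worked example, assuming KPerBlock = 32:

```cpp
// num_loop = K / KPerBlock; the tail handles the last tile, so the main
// body is needed only when num_loop > 1.
constexpr int KPerBlock = 32;

constexpr bool has_main_loop(int K) { return (K / KPerBlock) > 1; }

static_assert(!has_main_loop(32), "single tile: tail only");
static_assert(has_main_loop(64), "two tiles: one main iteration + tail");
static_assert(has_main_loop(256), "eight tiles: seven main iterations + tail");
```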
@@ -313,13 +528,44 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

struct SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment

static constexpr auto max_lds_align = K1;

static constexpr auto a_block_space_size_aligned =
AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
max_lds_align)
: 0;
static constexpr auto b_block_space_size_aligned =
BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
max_lds_align)
: 0;

static constexpr auto a_block_space_offset = 0;
static constexpr auto b_block_space_offset = a_block_space_size_aligned;

// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_space_size =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
.GetElementSpaceSize();

static constexpr auto c_shuffle_block_space_offset = 0;

static constexpr auto lds_size =
math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
a_block_space_size_aligned * sizeof(ADataType) +
b_block_space_size_aligned * sizeof(BDataType));
};
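
SharedMemTrait places A at offset 0 with B immediately after A's aligned size, and lets the C-shuffle stage reuse the same allocation, so the block needs max(A + B, Cshuffle) bytes rather than their sum. A standalone sketch of that budgeting (the byte counts are made up for illustration):

```cpp
#include <cstdio>

constexpr unsigned a_bytes         = 8192;  // aligned A tile in LDS
constexpr unsigned b_bytes         = 4096;  // aligned B tile in LDS
constexpr unsigned c_shuffle_bytes = 10240; // C-shuffle staging buffer

constexpr unsigned a_offset = 0;
constexpr unsigned b_offset = a_bytes; // B lives right after A
constexpr unsigned c_offset = 0;       // C-shuffle reuses the front

constexpr unsigned max_u(unsigned x, unsigned y) { return x > y ? x : y; }
// One allocation serves both phases because they never overlap in time:
constexpr unsigned lds_size = max_u(a_bytes + b_bytes, c_shuffle_bytes);

int main()
{
    std::printf("lds_size = %u (A@%u, B@%u, Cshuffle@%u)\n",
                lds_size, a_offset, b_offset, c_offset);
}
```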
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
__device__ static void Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
CDataType* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const AGridDesc& a_grid_desc,
const BGridDesc& b_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
@@ -331,9 +577,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/*******************************************************************************/
// Memory buffer zone.
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
p_a_grid, a_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
p_b_grid, b_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

@@ -351,24 +597,41 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

/*******************************************************************************/
// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destination of BlockWise_Copy
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
constexpr auto max_lds_align = K1;
constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock,
// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destination of BlockWise_Copy
const auto K = [&](){
if constexpr(AEnableLds){
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
}
else{
return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3)
* a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
}
}();

constexpr auto a_block_desc = MakeABlockDescriptor();
constexpr auto b_block_desc = MakeBBlockDescriptor();

auto a_block_trait = [&](){
// A matrix blockwise copy
if constexpr(AEnableLds)
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared),
SharedMemTrait::a_block_space_size_aligned);

auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
/* typename SrcElementwiseOperation, */ AElementwiseOperation,
/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
/* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
/* typename SrcData, */ FloatA,
/* typename DstData, */ FloatA,
/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1),
/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1),
/* typename SrcData, */ ADataType,
/* typename DstData, */ ADataType,
/* typename SrcDesc, */ decltype(a_grid_desc),
/* typename DstDesc, */ decltype(a_block_desc),
/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
@@ -378,99 +641,197 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/* index_t SrcScalarStrideInVector, */ 1,
/* index_t DstScalarStrideInVector, */ 1,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
a_grid_desc_k0_m_k1,
/* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
NumGemmKPrefetchStage>(
a_grid_desc,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_k0perblock_mperblock_k1,
a_block_desc,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});

// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatB,
FloatB,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0perblock_nperblock_k1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc_k0_n_k1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_k0perblock_nperblock_k1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
return make_tuple(a_block_buf, a_blockwise_copy);
}
else
{
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
a_block_desc.GetElementSpaceSize());

// Limitation: NumDim of Src and Dst descriptor should be identical
auto a_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<ADataType,
ADataType,
decltype(a_grid_desc),
decltype(a_block_desc),
Sequence<Number<KWmmaPerBlock>{},
Number<MRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
ABlockTransferSrcScalarPerVector,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_grid_desc,
make_multi_index(0,
m_block_data_idx_on_grid/(MWaves * MPerWmma),
get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));

return make_tuple(a_block_buf, a_blockwise_copy);
}
};

auto b_block_trait = [&](){
if constexpr(BEnableLds)
{
constexpr auto K0PerBlock = KPerBlock/ K1;
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
SharedMemTrait::b_block_space_size_aligned);

auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc),
decltype(b_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
b_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});

return make_tuple(b_block_buf, b_blockwise_copy);
}
else
{
// Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
constexpr auto K0PerWmma = WmmaK/2/K1Value;
auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_block_desc.GetElementSpaceSize());

// Limitation: NumDim of Src and Dst descriptor should be identical
auto b_blockwise_copy =
ThreadwiseTensorSliceTransfer_v2<BDataType,
BDataType,
decltype(b_grid_desc),
decltype(b_block_desc),
Sequence<Number<KWmmaPerBlock>{},
Number<NRepeat>{},
I1,
Number<K0PerWmma>{},
I1,
I1,
Number<K1Value>{}>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_grid_desc,
make_multi_index(0,
n_block_data_idx_on_grid/(NWaves * NPerWmma),
get_thread_local_1d_id() / 32,
0,
(get_thread_local_1d_id() % 32 )/ 16,
get_thread_local_1d_id() % 16,
0));

return make_tuple(b_block_buf, b_blockwise_copy);
}
};

auto a_block_buf = a_block_trait()[I0];
auto a_blockwise_copy = a_block_trait()[I1];

auto b_block_buf = b_block_trait()[I0];
auto b_blockwise_copy = b_block_trait()[I1];
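
The a_block_trait/b_block_trait lambdas bundle a buffer and its copy object into one tuple so the two if constexpr branches can return entirely different types from a single expression; the kernel then picks the pieces apart with [I0] and [I1]. A compact sketch of the pattern with hypothetical LdsBuf/VgprBuf types:

```cpp
#include <cstdio>
#include <tuple>

struct LdsBuf  { const char* where() const { return "LDS"; } };
struct VgprBuf { const char* where() const { return "VGPR"; } };

// The tuple's element types differ per branch, which is fine because
// `if constexpr` makes each instantiation contain only one return statement.
template <bool EnableLds>
auto make_block_trait()
{
    if constexpr(EnableLds)
        return std::make_tuple(LdsBuf{}, /*copy cfg:*/ 128);
    else
        return std::make_tuple(VgprBuf{}, /*copy cfg:*/ 32);
}

int main()
{
    // Mirrors the kernel above, which retrieves the buffer and the copy
    // object with two separate calls to the trait lambda.
    auto buf = std::get<0>(make_block_trait<true>());
    auto cfg = std::get<1>(make_block_trait<true>());
    std::printf("buffer in %s, cfg %d\n", buf.where(), cfg);
}
```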
/*******************************************************************************/
// GEMM
constexpr auto WmmaK = 16;
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

auto blockwise_gemm =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
FloatA,
FloatB,
FloatAcc,
decltype(a_block_desc_k0perblock_mperblock_k1),
decltype(b_block_desc_k0perblock_nperblock_k1),
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>{};
BlockwiseGemmWMMA<BlockSize,
ADataType,
BDataType,
AccDataType,
decltype(MakeAWaveDescriptor(a_block_desc)),
decltype(MakeBWaveDescriptor(b_block_desc)),
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack,
AEnableLds,
BEnableLds>{};

// Prepare Register for C matrix
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

/*******************************************************************************/
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for A and B: be careful of alignment
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatA*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(static_cast<FloatB*>(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());

/*******************************************************************************/
// Shift Per SUB_K
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();

// gridwise GEMM pipeline
const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
a_block_desc_k0perblock_mperblock_k1,
const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
a_block_desc,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_k0_n_k1,
b_block_desc_k0perblock_nperblock_k1,
b_grid_desc,
b_block_desc,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
K0BlockMainLoop);
KBlockMainLoop);
/*******************************************************************************/
// write out to C, implement shuffle
{
// C mapping in single thread.
constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

// This API provides all the dimensions (sizes) you need
// C mapping in single block
constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
@@ -485,8 +846,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatCShuffle*>(p_shared),
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
static_cast<CShuffleDataType*>(p_shared) + SharedMemTrait::c_shuffle_block_space_offset,
SharedMemTrait::c_shuffle_block_space_size);

constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
@@ -532,8 +893,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma

// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatCShuffle,
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
ck::tensor_operation::element_wise::PassThrough,
@@ -571,8 +932,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData,
FloatC, // typename DstData,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
@@ -636,6 +997,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
@@ -39,7 +39,7 @@ __global__ void
const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();

__shared__ uint8_t p_shared[shared_size];

@@ -9,7 +9,6 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -96,7 +95,10 @@ template <index_t BlockSize,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = FloatC>
typename ComputeTypeA = FloatC,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
static constexpr auto I0 = Number<0>{};
@@ -430,7 +432,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto c_block_size =
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();

return math::max((a_block_space_size + b_block_space_size) * sizeof(ComputeType),
return math::max(a_block_space_size * sizeof(LDSTypeA) +
b_block_space_size * sizeof(LDSTypeB),
c_block_size * sizeof(FloatC));
}
@@ -785,7 +788,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatA,
ComputeType,
LDSTypeA,
decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
@@ -815,7 +818,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatB,
ComputeType,
LDSTypeB,
decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
@@ -845,8 +848,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2

auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
ComputeType, // ComputeType A
ComputeType, // ComputeType B
LDSTypeA,
LDSTypeB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
@@ -855,7 +858,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
MRepeat,
NRepeat,
K1,
LoopSched>();
LoopSched,
ComputeTypeA,
ComputeTypeB>();

auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

@@ -863,8 +868,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

ComputeType* p_a_block = static_cast<ComputeType*>(p_shared_block);
ComputeType* p_b_block = static_cast<ComputeType*>(p_shared_block) + a_block_space_size;
auto p_a_block = reinterpret_cast<LDSTypeA*>(p_shared_block);
auto p_b_block = reinterpret_cast<LDSTypeB*>(p_a_block + a_block_space_size);

constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
@@ -8,6 +8,8 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Do the following things to avoid "alloca" in LLVM-IR, which would cause scratch memory usage
|
||||
@@ -1156,27 +1158,56 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
|
||||
i * src_scalar_step_in_vector);
|
||||
|
||||
// apply type convert
|
||||
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
|
||||
});
|
||||
}
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
|
||||
// TODO: if SrcData and DstData are vector types, then static_cast may not compile
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
dst_tmp_vector.template AsType<DstData>()(i) =
|
||||
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
|
||||
});
|
||||
if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
|
||||
is_same<remove_cvref_t<DstData>, half_t>::value &&
|
||||
SrcScalarPerVector % 2 == 0)
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
constexpr index_t pack_size = 2;
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
|
||||
using src_v_t = typename vector_type_maker_t<SrcData, pack_size>::type;
|
||||
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
|
||||
ck::tensor_operation::element_wise::PassThroughPack2{}(
|
||||
dst_tmp_vector.template AsType<dst_v_t>()(i),
|
||||
src_tmp_vector.template AsType<src_v_t>()[i]);
|
||||
});
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
|
||||
// TODO: if SrcData and DstData are vector types, then static_cast may not compile
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
dst_tmp_vector.template AsType<DstData>()(i) =
|
||||
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
|
||||
});
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
|
||||
});
|
||||
}
|
||||
});
|
||||
}
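A minimal standalone sketch of the fast path added above. Assumptions: plain host C++ and a placeholder fp8 decoder standing in for the real f8_t-to-half_t conversion; it only illustrates the shape of the "pack2" path, where an even-length vector is converted two elements at a time (what PassThroughPack2 maps to a packed conversion) instead of one by one.

#include <array>
#include <cstdint>
#include <cstdio>

// placeholder scalar decode; the real code uses type_convert / PassThroughPack2
static float decode_fp8_stub(std::uint8_t x) { return static_cast<float>(x) / 16.0f; }

// pack-2 path: one packed "instruction" per pair of elements
static void convert_pack2(const std::uint8_t* src, float* dst, int n)
{
    for(int i = 0; i < n; i += 2)
    {
        // a packed conversion would produce both lanes at once
        dst[i]     = decode_fp8_stub(src[i]);
        dst[i + 1] = decode_fp8_stub(src[i + 1]);
    }
}

int main()
{
    std::array<std::uint8_t, 4> src{16, 32, 48, 64};
    std::array<float, 4> dst{};
    convert_pack2(src.data(), dst.data(), static_cast<int>(src.size()));
    for(float v : dst)
        std::printf("%g\n", v); // 1 2 3 4
}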
|
||||
|
||||
@@ -1302,4 +1333,139 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
|
||||
ElementwiseOperation element_op_;
|
||||
};
|
||||
|
||||
// Specialized for WMMA
// A single Wave32 is composed of two rows
// Data exchange is allowed between these two rows
// This RowLane Dst buf will be filled from two Src bufs
// SrcA: from the specific thread buffer held by this RowLane on this row
// SrcB: from the specific thread buffer held by this RowLane on the other row
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename ElementwiseOperation,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t DstVectorDim,
|
||||
index_t DstScalarPerVector,
|
||||
uint32_t LowEightRowlaneIdx,
|
||||
uint32_t HighEightRowLaneIdx,
|
||||
bool IntraRowSwizzlePerm,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc need to known at compile-time");
|
||||
|
||||
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
ignore = src_idx;
|
||||
}
|
||||
|
||||
template <typename SrcSliceOriginIdx,
|
||||
typename DstSliceOriginIdx,
|
||||
typename SrcBuffer,
|
||||
typename DstBuffer>
|
||||
__device__ void Run(const SrcDesc&,
|
||||
const SrcSliceOriginIdx&,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc&,
|
||||
const DstSliceOriginIdx&,
|
||||
DstBuffer& dst_buf) const
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
|
||||
is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
|
||||
"wrong! SliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
|
||||
"wrong! Buffer need to be StaticBuffer");
|
||||
|
||||
// SrcDesc and src_slice_origin_idx are known at compile-time
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
|
||||
|
||||
// scalar per access on each dim
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_scalar_step_in_vector =
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>>;
|
||||
|
||||
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
|
||||
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
|
||||
// copy data from src_buf into dst_vector
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
// note: src_desc offset may fail to be constexpr when produced by a merge transform
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
SrcData v_this_row, v_theother_row;
|
||||
// int-typed temp value due to intrinsic requirements
|
||||
int temp = 0;
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v_this_row, src_buf[Number<src_offset>{}]);
|
||||
|
||||
// apply intra-row permute.
|
||||
if constexpr(IntraRowSwizzlePerm)
|
||||
{
|
||||
temp = __builtin_amdgcn_permlane16(
|
||||
temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
|
||||
v_this_row = type_convert_sp<SrcData>(temp);
|
||||
}
|
||||
|
||||
// apply inter-row permute.
|
||||
temp = __builtin_amdgcn_permlanex16(temp,
|
||||
type_convert_sp<int>(v_this_row),
|
||||
LowEightRowlaneIdx,
|
||||
HighEightRowLaneIdx,
|
||||
1,
|
||||
0);
|
||||
v_theother_row = type_convert_sp<SrcData>(temp);
|
||||
|
||||
if(get_thread_local_1d_id() % 32 < 16)
|
||||
{
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
|
||||
dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
|
||||
type_convert_sp<DstData>(v_theother_row);
|
||||
}
|
||||
else
|
||||
{
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
|
||||
type_convert_sp<DstData>(v_this_row);
|
||||
dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_theother_row);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
ElementwiseOperation element_op_{};
|
||||
};
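A host-side sketch of the cross-row exchange this struct relies on. Assumptions: this is a plain C++ emulation, not the real __builtin_amdgcn_permlanex16 intrinsic (the old-value operand and fi/bound_ctrl flags are ignored); it only illustrates how the two 32-bit selectors pick, nibble by nibble, which lane of the other 16-lane row each lane reads.

#include <array>
#include <cstdint>
#include <cstdio>

std::array<int, 32> permlanex16_emulated(const std::array<int, 32>& v,
                                         std::uint32_t sel_lo, // selectors for lanes 0..7 of a row
                                         std::uint32_t sel_hi) // selectors for lanes 8..15 of a row
{
    std::array<int, 32> out{};
    for(int lane = 0; lane < 32; ++lane)
    {
        const int row    = lane / 16; // 0 = low row, 1 = high row
        const int in_row = lane % 16;
        const std::uint32_t sel = in_row < 8 ? sel_lo : sel_hi;
        const int src_in_row    = (sel >> (4 * (in_row % 8))) & 0xF;
        const int src_lane      = (1 - row) * 16 + src_in_row; // always read the other row
        out[lane]               = v[src_lane];
    }
    return out;
}

int main()
{
    std::array<int, 32> v{};
    for(int i = 0; i < 32; ++i)
        v[i] = i;
    // identity selectors: each lane reads its mirror position on the other row
    auto r = permlanex16_emulated(v, 0x76543210u, 0xfedcba98u);
    std::printf("lane 0 sees %d, lane 16 sees %d\n", r[0], r[16]); // 16 and 0
}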
|
||||
|
||||
} // namespace ck
|
||||
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,804 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor/static_tensor.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Assume:
|
||||
// 1. src_desc and dst_desc are not known at compile-time
|
||||
// 2. SrcBuffer and DstBuffer are DynamicBuffer
|
||||
// 3. src_slice_origin and dst_slice_origin are not known at compile-time
|
||||
// 4. Use thread buffer
|
||||
template <typename SliceLengths,
|
||||
typename ElementwiseOperation,
|
||||
typename DstInMemOps, // Sequence
|
||||
typename SrcDatas,
|
||||
typename DstDatas,
|
||||
typename SrcDescs,
|
||||
typename DstDescs,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
typename SrcsScalarPerVector, // Sequence
|
||||
typename DstsScalarPerVector, // Sequence
|
||||
typename SrcsScalarStrideInVector, // Sequence
|
||||
typename DstsScalarStrideInVector, // Sequence
|
||||
typename SrcsResetCoordinateAfterRun, // control whether to move back src coordinate after
|
||||
// each RunRead(), will be fused with
|
||||
// MoveSrcSliceWindow to save addr computation
|
||||
typename DstsResetCoordinateAfterRun, // control whether to move back dst coordinate after
|
||||
// each RunWrite(), will be fused with
|
||||
// MoveDstSliceWindow to save addr computation
|
||||
index_t NumThreadScratch = 1>
|
||||
struct ThreadwiseTensorSliceTransfer_v3r2
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
static constexpr index_t nSrc = SrcDescs::Size();
|
||||
static constexpr index_t nDst = DstDescs::Size();
|
||||
|
||||
// return a tuple of coordinates for a tuple of tensors
|
||||
template <typename Descs,
|
||||
typename Indices,
|
||||
enable_if_t<Descs::Size() == Indices::Size(), bool> = false>
|
||||
static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
|
||||
{
|
||||
return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
|
||||
Number<Descs::Size()>{});
|
||||
}
|
||||
|
||||
using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray<Index, nSrc>{}));
|
||||
using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray<Index, nDst>{}));
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(
|
||||
const SrcDescs& src_descs,
|
||||
const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
|
||||
const DstDescs& dst_descs,
|
||||
const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
|
||||
const ElementwiseOperation& element_op)
|
||||
: src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
|
||||
dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
|
||||
element_op_(element_op)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
|
||||
__device__ void SetSrcSliceOrigins(const SrcDescs& src_descs,
|
||||
const Indices& src_slice_origin_idxs)
|
||||
{
|
||||
static_for<0, nSrc, 1>{}([&](auto src_i) {
|
||||
src_coords_(src_i) =
|
||||
make_tensor_coordinate(src_descs.At(src_i), src_slice_origin_idxs[src_i]);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
|
||||
__device__ void SetDstSliceOrigins(const DstDescs& dst_descs,
|
||||
const Indices& dst_slice_origin_idxs)
|
||||
{
|
||||
static_for<0, nDst, 1>{}([&](auto dst_i) {
|
||||
dst_coords_(dst_i) =
|
||||
make_tensor_coordinate(dst_descs.At(dst_i), dst_slice_origin_idxs[dst_i]);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunRead(const SrcDescs& src_descs,
|
||||
const SrcBuffers& src_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access_tuple = generate_tuple(
|
||||
[&](auto src_i) {
|
||||
return generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim,
|
||||
SrcsScalarPerVector::At(src_i)>{},
|
||||
Number<nDim>{});
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
constexpr auto src_access_lengths_tuple = generate_tuple(
|
||||
[&](auto src_i) {
|
||||
static_assert(
    SliceLengths::At(SrcVectorDim) % SrcsScalarPerVector::At(src_i) == 0,
    "SliceLengths[SrcVectorDim] must be divisible by SrcsScalarPerVector");
return SliceLengths{} / src_scalar_per_access_tuple.At(src_i);
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths_tuple = generate_tuple(
|
||||
[&](auto src_i) {
|
||||
return container_reorder_given_new2old(src_access_lengths_tuple.At(src_i),
|
||||
src_dim_access_order);
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
// make forward steps
|
||||
const auto src_forward_steps_tuple = generate_tuple(
|
||||
[&](auto src_i) {
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) =
|
||||
(i.value == j.value) ? src_scalar_per_access_tuple.At(src_i)[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(src_descs.At(src_i), forward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
// make backward steps
|
||||
const auto src_backward_steps_tuple = generate_tuple(
|
||||
[&](auto src_i) {
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value)
|
||||
? -src_scalar_per_access_tuple.At(src_i)[i]
|
||||
: 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(src_descs.At(src_i), backward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_for<0, nSrc, 1>{}([&](auto src_i) {
|
||||
static_ford<remove_cvref_t<decltype(ordered_src_access_lengths_tuple.At(src_i))>>{}(
|
||||
[&](auto ordered_src_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths_tuple.At(src_i)[j] +
|
||||
ordered_src_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i]
|
||||
? ordered_src_access_idx[i]
|
||||
: ordered_src_access_lengths_tuple.At(src_i)[i] -
|
||||
1 - ordered_src_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access_tuple.At(src_i);
|
||||
}();
|
||||
|
||||
constexpr auto src_data_idx_seq =
|
||||
generate_sequence_v2([&](auto i) { return Number<src_data_idx[i]>{}; },
|
||||
Number<src_data_idx.Size()>{});
|
||||
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(
|
||||
src_descs.At(src_i), src_coords_.At(src_i));
|
||||
|
||||
using src_vector_type = vector_type_maker_t<tuple_element_t<src_i, SrcDatas>,
|
||||
SrcsScalarPerVector::At(src_i)>;
|
||||
using src_vector_t = typename src_vector_type::type;
|
||||
|
||||
// copy data from src_buf into src_vector_container
|
||||
auto src_vector_container =
|
||||
src_vector_type{src_bufs.At(src_i).template Get<src_vector_t>(
|
||||
src_coords_.At(src_i).GetOffset(), is_src_valid)};
|
||||
|
||||
// copy data from src_vector_container into src_thread_scratch_
|
||||
src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.At(src_i)
|
||||
.template SetAsType<src_vector_t>(
|
||||
src_data_idx_seq,
|
||||
src_vector_container.template AsType<src_vector_t>()[I0]);
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_src_access_idx[i] <
|
||||
ordered_src_access_lengths_tuple.At(src_i)[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &=
|
||||
ordered_src_access_idx[j] ==
|
||||
ordered_src_access_lengths_tuple.At(src_i)[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// move src coord
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_descs.At(src_i),
|
||||
src_coords_.At(src_i),
|
||||
src_forward_steps_tuple.At(src_i)[src_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_descs.At(src_i),
|
||||
src_coords_.At(src_i),
|
||||
src_backward_steps_tuple.At(src_i)[src_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, nSrc, 1>{}([&](auto src_i) {
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcsResetCoordinateAfterRun::At(src_i))
|
||||
{
|
||||
const auto src_reset_step = make_tensor_coordinate_step(
|
||||
src_descs.At(src_i), GetSrcCoordinateResetStep<src_i>());
|
||||
|
||||
move_tensor_coordinate(src_descs.At(src_i), src_coords_.At(src_i), src_reset_step);
|
||||
}
|
||||
});
|
||||
}
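A standalone sketch (plain C++, hypothetical 3x4 access lengths) of the forward_sweep ordering computed in RunRead above: each inner dimension reverses direction whenever the combined index of the outer dimensions is odd, so consecutive accesses differ by a single +/-1 step and the per-dimension forward/backward coordinate steps built earlier are all that is needed.

#include <cstdio>

int main()
{
    const int len0 = 3, len1 = 4; // access lengths per dimension
    for(int i0 = 0; i0 < len0; ++i0)
    {
        for(int i1 = 0; i1 < len1; ++i1)
        {
            const bool fwd1 = (i0 % 2 == 0);            // forward_sweep for dim 1
            const int idx1  = fwd1 ? i1 : len1 - 1 - i1; // actual data index visited
            std::printf("(%d,%d) ", i0, idx1);
        }
    }
    std::printf("\n"); // (0,0)(0,1)(0,2)(0,3)(1,3)(1,2)(1,1)(1,0)(2,0)(2,1)(2,2)(2,3)
}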
|
||||
|
||||
template <index_t ThreadScratchId>
|
||||
__device__ void
|
||||
TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
|
||||
{
|
||||
// TODO: Add support for CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
|
||||
// (it requires adding Elementwise support in transpose_vectors)
|
||||
static_ford<SliceLengths>{}([&](auto idx) {
|
||||
const auto src_data_refs = generate_tie(
|
||||
[&](auto src_i) -> const auto& {
|
||||
return src_thread_scratch_tuple_[thread_scratch_id].At(src_i)[idx];
|
||||
},
|
||||
Number<nSrc>{});
|
||||
|
||||
auto dst_data_refs = generate_tie(
|
||||
[&](auto dst_i) -> auto& { return dst_thread_scratch_tuple_.At(dst_i)(idx); },
|
||||
Number<nDst>{});
|
||||
unpack2(element_op_, dst_data_refs, src_data_refs);
|
||||
});
|
||||
}
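A small sketch (plain C++, with a hypothetical two-input elementwise op) of what the static_ford + unpack2 pattern above amounts to at each index: the functor receives references to one element from every destination scratch followed by one element from every source scratch, so a single operator can fuse multiple tensors.

#include <cstdio>

struct AddAndScale // hypothetical two-in, one-out elementwise op
{
    void operator()(float& d, const float& s0, const float& s1) const { d = 2.0f * (s0 + s1); }
};

int main()
{
    float src0[4] = {1, 2, 3, 4};
    float src1[4] = {10, 20, 30, 40};
    float dst[4];
    AddAndScale op;
    for(int i = 0; i < 4; ++i)
        op(dst[i], src0[i], src1[i]); // what unpack2(element_op_, dsts, srcs) does per index
    std::printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]); // 22 44 66 88
}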
|
||||
|
||||
template <typename DstBuffers, index_t ThreadScratchId = 0>
|
||||
__device__ void RunWrite(const DstDescs& dst_descs,
|
||||
DstBuffers& dst_bufs,
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// if there is transpose, it's done here
|
||||
// TODO move this elsewhere
|
||||
TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);
|
||||
|
||||
// dst scalar per access on each dim
|
||||
// TODO: don't use this
|
||||
constexpr auto dst_scalar_per_access_tuple = generate_tuple(
|
||||
[&](auto dst_i) {
|
||||
return generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim,
|
||||
DstsScalarPerVector::At(dst_i)>{},
|
||||
Number<nDim>{});
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
constexpr auto dst_access_lengths_tuple = generate_tuple(
|
||||
[&](auto dst_i) { return SliceLengths{} / dst_scalar_per_access_tuple.At(dst_i); },
|
||||
Number<nDst>{});
|
||||
|
||||
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_dst_access_lengths_tuple = generate_tuple(
|
||||
[&](auto dst_i) {
|
||||
return container_reorder_given_new2old(dst_access_lengths_tuple.At(dst_i),
|
||||
dst_dim_access_order);
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
// make forward steps
|
||||
const auto dst_forward_steps_tuple = generate_tuple(
|
||||
[&](auto dst_i) {
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) =
|
||||
(i.value == j.value) ? dst_scalar_per_access_tuple.At(dst_i)[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(dst_descs.At(dst_i), forward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
// make backward steps
|
||||
const auto dst_backward_steps_tuple = generate_tuple(
|
||||
[&](auto dst_i) {
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value)
|
||||
? -dst_scalar_per_access_tuple.At(dst_i)[i]
|
||||
: 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(dst_descs.At(dst_i), backward_step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
},
|
||||
Number<nDst>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_for<0, nDst, 1>{}([&](auto dst_i) {
|
||||
static_ford<remove_cvref_t<decltype(ordered_dst_access_lengths_tuple.At(dst_i))>>{}(
|
||||
[&](auto ordered_dst_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_dst_access_lengths_tuple.At(dst_i)[j] +
|
||||
ordered_dst_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index
|
||||
constexpr auto dst_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i]
|
||||
? ordered_dst_access_idx[i]
|
||||
: ordered_dst_access_lengths_tuple.At(dst_i)[i] -
|
||||
1 - ordered_dst_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_scalar_per_access_tuple.At(dst_i);
|
||||
}();
|
||||
|
||||
constexpr auto dst_data_idx_seq =
|
||||
generate_sequence_v2([&](auto i) { return Number<dst_data_idx[i]>{}; },
|
||||
Number<dst_data_idx.Size()>{});
|
||||
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(
|
||||
dst_descs.At(dst_i), dst_coords_.At(dst_i));
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<tuple_element_t<dst_i, DstDatas>,
|
||||
DstsScalarPerVector::At(dst_i)>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
|
||||
// copy data from dst_thread_scratch_ into dst_vector_container
|
||||
auto dst_vector_container = dst_vector_type{
|
||||
dst_thread_scratch_tuple_.At(dst_i).template GetAsType<dst_vector_t>(
|
||||
dst_data_idx_seq)};
|
||||
|
||||
constexpr InMemoryDataOperationEnum DstInMemOp =
|
||||
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(dst_i.value));
|
||||
|
||||
// copy data from dst_vector_container to dst_buf
|
||||
dst_bufs.At(dst_i).template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_coords_.At(dst_i).GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector_container.template AsType<dst_vector_t>()[I0]);
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_dst_access_idx[i] <
|
||||
ordered_dst_access_lengths_tuple.At(dst_i)[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &=
|
||||
ordered_dst_access_idx[j] ==
|
||||
ordered_dst_access_lengths_tuple.At(dst_i)[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// move dst coord
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_descs.At(dst_i),
|
||||
dst_coords_.At(dst_i),
|
||||
dst_forward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_descs.At(dst_i),
|
||||
dst_coords_.At(dst_i),
|
||||
dst_backward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
static_for<0, nDst, 1>{}([&](auto dst_i) {
|
||||
if constexpr(DstsResetCoordinateAfterRun::At(dst_i))
|
||||
{
|
||||
const auto dst_reset_step = make_tensor_coordinate_step(
|
||||
dst_descs.At(dst_i), GetDstCoordinateResetStep<dst_i>());
|
||||
|
||||
move_tensor_coordinate(dst_descs.At(dst_i), dst_coords_.At(dst_i), dst_reset_step);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t src_i>
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
{
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcsScalarPerVector::At(src_i)>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_lengths[I0] - 1;
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index after the last iteration in RunRead(), if it has not been reset
// by RunRead()
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_src_data_step = [&]() {
|
||||
Index reset_src_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
|
||||
|
||||
return reset_src_data_step_;
|
||||
}();
|
||||
|
||||
return reset_src_data_step;
|
||||
}
|
||||
|
||||
template <index_t dst_i>
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
{
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstsScalarPerVector::At(dst_i)>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_dst_access_lengths =
|
||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_lengths[I0] - 1;
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index after the last iteration in RunWrite(), if it has not been reset
// by RunWrite()
|
||||
constexpr auto dst_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_scalar_per_access;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_dst_data_step = [&]() {
|
||||
Index reset_dst_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
|
||||
|
||||
return reset_dst_data_step_;
|
||||
}();
|
||||
|
||||
return reset_dst_data_step;
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx needs to be known at compile-time for performance reasons
|
||||
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
|
||||
const Index& src_slice_origin_step_idx)
|
||||
{
|
||||
static_for<0, nSrc, 1>{}([&](auto src_i) {
|
||||
// if the src coord was not reset by RunRead(), the step needs to be adjusted here
|
||||
const auto adjusted_step_idx =
|
||||
SrcsResetCoordinateAfterRun::At(src_i)
|
||||
? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep<src_i>();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step =
|
||||
make_tensor_coordinate_step(src_descs.At(src_i), adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_descs.At(src_i), src_coords_.At(src_i), adjusted_step);
|
||||
});
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx needs to be known at compile-time for performance reasons
|
||||
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
|
||||
const Index& dst_slice_origin_step_idx)
|
||||
{
|
||||
static_for<0, nDst, 1>{}([&](auto dst_i) {
|
||||
// if the dst coord was not reset by RunWrite(), the step needs to be adjusted here
|
||||
const auto adjusted_step_idx =
|
||||
DstsResetCoordinateAfterRun::At(dst_i)
|
||||
? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep<dst_i>();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step =
|
||||
make_tensor_coordinate_step(dst_descs.At(dst_i), adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_descs.At(dst_i), dst_coords_.At(dst_i), adjusted_step);
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t src_i>
|
||||
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcsScalarPerVector::At(src_i)>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_access_lengths_and_vector_length =
|
||||
container_push_back(sequence_to_tuple_of_number(src_access_lengths),
|
||||
Number<SrcsScalarPerVector::At(src_i)>{});
|
||||
|
||||
// 1st stage of transforms
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(src_access_lengths_and_vector_length[i],
|
||||
src_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == SrcVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
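A worked example (plain C++, hypothetical sizes) of the two-stage descriptor construction above: with SliceLengths = [4, 8], SrcVectorDim = 1 and SrcScalarPerVector = 4, stage 1 builds a packed [4, 2, 4] descriptor (accesses x accesses x vector length), and stage 2 merges the vector length back into dimension 1 so the scratch is still addressed in [4, 8] slice coordinates.

#include <cstdio>

int main()
{
    const int slice[2]       = {4, 8};
    const int vector_dim     = 1;
    const int scalar_per_vec = 4;
    const int access[2]      = {slice[0], slice[1] / scalar_per_vec}; // [4, 2]
    std::printf("desc0 = [%d, %d, %d]\n", access[0], access[1], scalar_per_vec);
    std::printf("merged dim %d length = %d\n",
                vector_dim, access[vector_dim] * scalar_per_vec); // back to 8
}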
|
||||
|
||||
template <index_t dst_i>
|
||||
__device__ static constexpr auto GetDstThreadScratchDescriptor()
|
||||
{
|
||||
// 1st stage of transforms
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstsScalarPerVector::At(dst_i)>{},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
constexpr auto dst_access_lengths_and_vector_length =
|
||||
container_push_back(sequence_to_tuple_of_number(dst_access_lengths),
|
||||
Number<DstsScalarPerVector::At(dst_i)>{});
|
||||
|
||||
constexpr auto desc0 =
|
||||
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
|
||||
|
||||
// 2nd stage of transforms
|
||||
constexpr auto transforms = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return make_merge_transform_v3_division_mod(
|
||||
make_tuple(dst_access_lengths_and_vector_length[i],
|
||||
dst_access_lengths_and_vector_length[Number<nDim>{}]));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto low_dim_idss = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i == DstVectorDim)
|
||||
{
|
||||
return Sequence<i.value, nDim>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Sequence<i.value>{};
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto up_dim_idss =
|
||||
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
|
||||
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto MakeSrcThreadScratchTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto src_i) {
|
||||
constexpr auto src_thread_scratch_desc =
|
||||
decltype(GetSrcThreadScratchDescriptor<src_i>()){};
|
||||
using SrcThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
tuple_element_t<src_i, SrcDatas>,
|
||||
SrcsScalarPerVector::At(src_i),
|
||||
decltype(src_thread_scratch_desc),
|
||||
true>;
|
||||
return SrcThreadScratch{};
|
||||
},
|
||||
Number<nSrc>{});
|
||||
}
|
||||
|
||||
__device__ static constexpr auto MakeDstThreadScratchTuple()
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto dst_i) {
|
||||
constexpr auto dst_thread_scratch_desc =
|
||||
decltype(GetDstThreadScratchDescriptor<dst_i>()){};
|
||||
using DstThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
tuple_element_t<dst_i, DstDatas>,
|
||||
DstsScalarPerVector::At(dst_i),
|
||||
decltype(dst_thread_scratch_desc),
|
||||
true>;
|
||||
return DstThreadScratch{};
|
||||
},
|
||||
Number<nDst>{});
|
||||
}
|
||||
|
||||
private:
|
||||
using SrcThreadScratchTuple = decltype(MakeSrcThreadScratchTuple());
|
||||
using DstThreadScratchTuple = decltype(MakeDstThreadScratchTuple());
|
||||
|
||||
StaticallyIndexedArray<SrcThreadScratchTuple, NumThreadScratch> src_thread_scratch_tuple_;
|
||||
|
||||
DstThreadScratchTuple dst_thread_scratch_tuple_;
|
||||
|
||||
SrcCoords src_coords_;
|
||||
DstCoords dst_coords_;
|
||||
const ElementwiseOperation element_op_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -89,6 +89,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
// * Thread mapping inside a wave: num_thread_per_subgroups is always along the N direction
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
@@ -100,7 +101,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
|
||||
// * num_acc_vgprs_per_wave along the M direction
// * num_subgroups along the M direction
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
@@ -129,6 +130,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave-mode-dependent property
|
||||
@@ -136,7 +138,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
|
||||
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
@@ -153,7 +155,6 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
|
||||
}
|
||||
};
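A quick arithmetic check (plain C++) of the num_acc_vgprs_per_wave formula used throughout these wmma_type specializations: the accumulator tile holds m_per_wmma * n_per_wmma values spread across the wave, a VGPR lane is 4 bytes wide, and the new acc_pack_number factor accounts for two 16-bit accumulator values sharing one lane slot.

#include <cstdio>

constexpr int acc_vgprs(int m, int n, int acc_size, int acc_pack, int wave)
{
    return m * n * acc_size * acc_pack / wave / 4;
}

static_assert(acc_vgprs(16, 16, 4, 1, 32) == 8, "f32 acc, wave32");
static_assert(acc_vgprs(16, 16, 4, 1, 64) == 4, "f32 acc, wave64");
static_assert(acc_vgprs(16, 16, 2, 2, 32) == 8, "packed f16 acc, wave32");

int main() { std::printf("%d\n", acc_vgprs(16, 16, 4, 1, 32)); }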
|
||||
|
||||
#ifdef CK_UNPACKED_ACC_DESC_LOGIC
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
|
||||
WaveSize,
|
||||
@@ -166,6 +167,7 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 2;
|
||||
static constexpr index_t acc_pack_number = 2;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave-mode-dependent property
|
||||
@@ -173,28 +175,22 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
|
||||
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t Opsel,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
|
||||
intrin_wmma_f16_16x16x16_f16_w32<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
|
||||
}
|
||||
else if constexpr(wave_size == 64)
|
||||
{
|
||||
intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
|
||||
intrin_wmma_f16_16x16x16_f16_w64<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
|
||||
WaveSize,
|
||||
@@ -207,6 +203,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 2;
|
||||
static constexpr index_t acc_pack_number = 2;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave-mode-dependent property
|
||||
@@ -214,7 +211,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
|
||||
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma,
|
||||
@@ -227,17 +224,15 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
|
||||
{
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
|
||||
intrin_wmma_bf16_16x16x16_bf16_w32<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
|
||||
}
|
||||
else if constexpr(wave_size == 64)
|
||||
{
|
||||
intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, Opsel>::Run(a, b, reg_c);
|
||||
intrin_wmma_bf16_16x16x16_bf16_w64<MPerWmma, NPerWmma, false>::Run(a, b, reg_c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
|
||||
WaveSize,
|
||||
@@ -250,6 +245,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave-mode-dependent property
|
||||
@@ -257,7 +253,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
|
||||
static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
|
||||
static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
|
||||
static constexpr index_t num_acc_vgprs_per_wave =
|
||||
m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
|
||||
m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma,
|
||||
@@ -346,7 +342,7 @@ struct WmmaSelector
|
||||
static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_K must equal 16");
|
||||
|
||||
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
|
||||
selected_wmma.acc_data_size ==
|
||||
selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
|
||||
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
|
||||
"WRONG! Invalid Number of Accumulator Register");
|
||||
}
|
||||
@@ -358,7 +354,8 @@ template <typename src_type_a,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t KPack,
|
||||
bool TransposeC = false>
|
||||
bool TransposeC = false,
|
||||
bool AssemblyBackend = false>
|
||||
struct WmmaGemm
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -369,14 +366,14 @@ struct WmmaGemm
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
using CIndex = MultiIndex<2>;
|
||||
using CIndex4D = MultiIndex<4>;
|
||||
using CIndex3D = MultiIndex<3>;
|
||||
|
||||
__host__ __device__ constexpr WmmaGemm()
|
||||
{
|
||||
static_assert(NPerWmma == 16 && MPerWmma == 16,
|
||||
"Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
|
||||
|
||||
static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
|
||||
static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma");
|
||||
}
|
||||
|
||||
// WMMA output supporting C = A * B
|
||||
@@ -421,9 +418,49 @@ struct WmmaGemm
|
||||
Sequence<5>{}));
|
||||
}
|
||||
|
||||
// Transposed WMMA Output C' = B' * A'
|
||||
template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(
|
||||
const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
|
||||
{
|
||||
const auto MBlockxRepeat =
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
|
||||
const auto NBlockxRepeat =
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
|
||||
const auto MWave =
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
|
||||
const auto NWave =
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
|
||||
make_tuple(
|
||||
make_pass_through_transform(MBlockxRepeat),
|
||||
make_pass_through_transform(MWave),
|
||||
make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{}),
|
||||
make_pass_through_transform(NBlockxRepeat),
|
||||
make_pass_through_transform(NWave),
|
||||
make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
|
||||
Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5, 6>{}));
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegSizePerWmma()
|
||||
{
|
||||
return wmma_instr.num_acc_vgprs_per_wave;
|
||||
return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number;
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
|
||||
@@ -449,14 +486,16 @@ struct WmmaGemm
|
||||
,
|
||||
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
|
||||
"(int8, int32) or (int4, int32)!");
|
||||
if constexpr(!TransposeC)
|
||||
{
|
||||
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave, p_b_wave, p_c_thread);
|
||||
}
|
||||
else
|
||||
{
|
||||
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave, p_a_wave, p_c_thread);
|
||||
}
|
||||
static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
|
||||
if constexpr(!TransposeC)
|
||||
{
|
||||
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
|
||||
}
|
||||
else
|
||||
{
|
||||
wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
__device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
|
||||
@@ -477,12 +516,12 @@ struct WmmaGemm
|
||||
|
||||
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
return GetSwizzledLaneIdLow();
|
||||
return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
return GetLaneIdUnderSubGroup();
|
||||
return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
|
||||
}
|
||||
|
||||
__device__ static CIndex GetBeginOfThreadBlk()
|
||||
@@ -493,6 +532,14 @@ struct WmmaGemm
|
||||
return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
|
||||
}
|
||||
|
||||
__device__ static CIndex3D GetBeginOfThreadBlk3D()
|
||||
{
|
||||
index_t n_offset = GetLaneIdUnderSubGroup();
|
||||
index_t m_offset = GetSubGroupId();
|
||||
|
||||
return TransposeC ? CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0};
|
||||
}
|
||||
|
||||
static constexpr auto wmma =
|
||||
WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
|
||||
static constexpr auto wmma_instr = wmma.selected_wmma;
|
||||
@@ -500,7 +547,10 @@ struct WmmaGemm
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
|
||||
{
|
||||
return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
|
||||
return make_tuple(I1,
|
||||
I1,
|
||||
Number<wmma_instr.num_acc_vgprs_per_wave>{},
|
||||
Number<wmma_instr.acc_pack_number>{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,391 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
|
||||
template <index_t NumDimG,
|
||||
index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
device::TensorSpecialization TensorSpec>
|
||||
__host__ __device__ static auto
|
||||
MakeGridDescriptorPair(const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_lengths_vec,
|
||||
const std::array<index_t, NumDimG + NumDimM + NumDimN>& gs_ms_ns_strides_vec)
|
||||
{
|
||||
// if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
|
||||
// gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
|
||||
// {
|
||||
// throw std::runtime_error("wrong! dimension must match input lengths");
|
||||
// }
|
||||
|
||||
const auto to_tuple = [&](auto& vec, auto start, auto end) {
|
||||
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
|
||||
};
|
||||
|
||||
const auto gs_ms_ns_lengths =
|
||||
to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
|
||||
const auto gs_ms_ns_strides =
|
||||
to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
|
||||
|
||||
// dimension Ids for G0, G1, ...
|
||||
constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
|
||||
|
||||
// dimension Ids for M0, M1, ...
|
||||
constexpr auto mDimIds =
|
||||
typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
|
||||
|
||||
// dimension Ids for N0, N1, ...
|
||||
constexpr auto nDimIds =
|
||||
typename arithmetic_sequence_gen<NumDimG + NumDimM, NumDimG + NumDimM + NumDimN, 1>::type{};
|
||||
|
||||
// lengths for G0, G1, ...
|
||||
const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);
|
||||
|
||||
// lengths for M0, M1, ...
|
||||
const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);
|
||||
|
||||
// lengths for N0, N1, ...
|
||||
const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);
|
||||
|
||||
if constexpr(TensorSpec == device::TensorSpecialization::Packed)
|
||||
{
|
||||
auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
|
||||
auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
|
||||
auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
|
||||
const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
|
||||
make_tuple(G, M, N),
|
||||
make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
|
||||
gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
|
||||
gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
|
||||
|
||||
const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
|
||||
make_tuple(M, N),
|
||||
make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
|
||||
gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
|
||||
|
||||
return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
|
||||
}
|
||||
else
|
||||
{
|
||||
// naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
|
||||
const auto grid_desc_gs_ms_ns =
|
||||
make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);
|
||||
|
||||
// transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
|
||||
// N2 * ...]
|
||||
// Note: This does not require padding as it only provides the G offset calculation.
// Technically a descriptor for only G is needed; we return G_M_N for backward
// compatibility.
|
||||
const auto grid_desc_g_mraw_nraw =
|
||||
transform_tensor_descriptor(grid_desc_gs_ms_ns,
|
||||
make_tuple(make_merge_transform(gLengths),
|
||||
make_merge_transform(mLengths),
|
||||
make_merge_transform(nLengths)),
|
||||
make_tuple(gDimIds, mDimIds, nDimIds),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto c_ms_ns_lengths = to_tuple(
|
||||
gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
|
||||
const auto c_ms_ns_strides = to_tuple(
|
||||
gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
|
||||
|
||||
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
|
||||
// N2 * ...]
|
||||
const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
|
||||
|
||||
const auto grid_desc_mraw_nraw = transform_tensor_descriptor(
|
||||
grid_desc_ms_ns,
|
||||
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
|
||||
make_tuple(mDimIds - Number<NumDimG>{}, nDimIds - Number<NumDimG>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
|
||||
}
|
||||
}
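A worked example (plain C++, hypothetical shapes) of the Packed branch above: a tensor C[G0, G1, M0, M1, N0] with NumDimG = 2, NumDimM = 2, NumDimN = 1 collapses to [G, M, N] = [G0*G1, M0*M1, N0], and because the layout is packed the stride of each merged dimension is simply the stride of its innermost constituent dimension.

#include <cstdio>

int main()
{
    const int lengths[5] = {2, 3, 4, 5, 6};       // G0 G1 M0 M1 N0
    const int strides[5] = {360, 120, 30, 6, 1};  // packed layout
    const int G = lengths[0] * lengths[1];        // 6
    const int M = lengths[2] * lengths[3];        // 20
    const int N = lengths[4];                     // 6
    // strides picked at indices NumDimG-1, NumDimG+NumDimM-1, NumDimG+NumDimM+NumDimN-1
    std::printf("G=%d (stride %d), M=%d (stride %d), N=%d (stride %d)\n",
                G, strides[1], M, strides[3], N, strides[4]);
}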

template <typename NumDims_G_M_N_K_O, // Sequence<>
          typename PerBlock_M_N_K_O,  // Sequence<>
          device::GemmSpecialization GemmSpec,
          device::TensorSpecialization ASpec,
          device::TensorSpecialization B0Spec,
          device::TensorSpecialization B1Spec,
          device::TensorSpecialization CSpec>
struct TransformBatchedContractionContractionToBatchedGemmGemm_Wmma
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};

    static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0);
    static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1);
    static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2);
    static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3);
    static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4);

    static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0);
    static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1);
    static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2);
    static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3);

    static constexpr auto matrix_padder =
        device::GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
            MPerBlock, NPerBlock, KPerBlock, OPerBlock};

    //
    // A
    //
    __host__ __device__ static auto MakeAGridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
                                                                        a_gs_ms_ks_strides_vec);
    }

    // TODO: rename to G_MRaw_KRaw
    __host__ __device__ static auto MakeAGridDescriptor_G_M_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
    {
        return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
    }
    __host__ __device__ static auto MakeAGridDescriptor_M_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
    {
        return matrix_padder.PadADescriptor_M_K(
            MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
    }

    template <typename AGridDesc_M_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK0 = K / AK1;

        return transform_tensor_descriptor(a_grid_desc_m_k,
                                           make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                                                      make_pass_through_transform(M)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }
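
    // Illustrative numbers: for an unpadded a_grid_desc_m_k with M = 256, K = 64 and AK1 = 8,
    // AK0 = K / AK1 = 8 and the K axis is unmerged into (AK0, AK1), giving an
    // (AK0, M, AK1) = (8, 256, 8) descriptor; K is assumed to be divisible by AK1.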

    template <typename AGridDesc_M_K,
              typename WmmaK,
              typename MRepeat,
              typename MWaves,
              typename MPerWmma,
              typename AK1>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
        const AGridDesc_M_K& a_grid_desc_m_k,
        const WmmaK&,
        const MRepeat&,
        const MWaves&,
        const MPerWmma&,
        const AK1&)
    {
        const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
        const auto K = a_grid_desc_m_k.GetLength(I1);
        const auto AKWmma = K / WmmaK{};
        constexpr auto AKRow = 2;
        constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{};

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(
                           make_tuple(AKWmma, Number<AK0PerWmma>{}, Number<AKRow>{}, AK1{})),
                       make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }
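
    // Illustrative numbers: with WmmaK = 16, AK1 = 4 and AKRow = 2, AK0PerWmma = 16 / 2 / 4 = 2,
    // so K is unmerged into (AKWmma, AK0PerWmma, AKRow, AK1) = (K / 16, 2, 2, 4) and M into
    // (M0 * MRepeat, MWaves, MPerWmma); Sequence<0, 3, 4, 6> and Sequence<1, 2, 5> interleave
    // the K- and M-derived dims into the 7-d order named by the function.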

    //
    // B (alias of B0)
    //
    __host__ __device__ static auto MakeB0GridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
                                                                         b0_gs_ns_ks_strides_vec);
    }

    // TODO: rename to G_MRaw_NRaw
    __host__ __device__ static auto MakeB0GridDescriptor_G_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
    {
        return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
    }
    __host__ __device__ static auto MakeB0GridDescriptor_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ns_ks_strides_vec)
    {
        // alias of matrix_padder.PadB0Descriptor_N_K
        return matrix_padder.PadBDescriptor_N_K(
            MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second);
    }

    template <typename BGridDesc_N_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1)
    {
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK0 = K / BK1;

        return transform_tensor_descriptor(b_grid_desc_n_k,
                                           make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                                                      make_pass_through_transform(N)),
                                           make_tuple(Sequence<1>{}, Sequence<0>{}),
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    template <typename BGridDesc_L_K,
              typename WmmaK,
              typename LRepeat,
              typename LWaves,
              typename LPerWmma,
              typename BK1>
    __host__ __device__ static constexpr auto
    MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
        const BGridDesc_L_K& b_grid_desc_l_k,
        const WmmaK&,
        const LRepeat&,
        const LWaves&,
        const LPerWmma&,
        const BK1&)
    {
        const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
        const auto K = b_grid_desc_l_k.GetLength(I1);
        const auto BKWmma = K / WmmaK{};
        constexpr auto BKRow = 2;
        constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{};

        return transform_tensor_descriptor(
            b_grid_desc_l_k,
            make_tuple(make_unmerge_transform(
                           make_tuple(BKWmma, Number<BK0PerWmma>{}, Number<BKRow>{}, BK1{})),
                       make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }

    //
    // B1
    //
    __host__ __device__ static auto MakeB1GridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
                                                                         b1_gs_os_ns_strides_vec);
    }

    // TODO: rename to G_NRaw_KRaw
    __host__ __device__ static auto MakeB1GridDescriptor_G_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
    {
        return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
    }
    __host__ __device__ static auto MakeB1GridDescriptor_N_K(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_os_ns_strides_vec)
    {
        // alias of matrix_padder.PadB1Descriptor_O_N
        return matrix_padder.PadB1Descriptor_N_K(
            MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second);
    }

    template <typename B1GridDesc_N_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1)
    {
        const auto N = b1_grid_desc_n_k.GetLength(I0);
        const auto K = b1_grid_desc_n_k.GetLength(I1);

        const auto B1K0 = K / B1K1;

        return transform_tensor_descriptor(
            b1_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    template <typename BGridDesc_N_L,
              typename WmmaL,
              typename NRepeat,
              typename NWaves,
              typename NPerWmma,
              typename BL1>
    __host__ __device__ static constexpr auto
    MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
        const BGridDesc_N_L& b_grid_desc_n_l,
        const WmmaL&,
        const NRepeat&,
        const NWaves&,
        const NPerWmma&,
        const BL1&)
    {
        const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
        const auto L = b_grid_desc_n_l.GetLength(I1);
        const auto BLWmma = L / WmmaL{};
        constexpr auto BLRow = 2;
        constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{};

        return transform_tensor_descriptor(
            b_grid_desc_n_l,
            make_tuple(make_unmerge_transform(
                           make_tuple(BLWmma, Number<BL0PerWmma>{}, Number<BLRow>{}, BL1{})),
                       make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }

    //
    // C
    //
    __host__ __device__ static auto MakeCGridDescriptorPair(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
                                                                        c_gs_ms_os_strides_vec);
    }

    // TODO: rename to G_MRaw_NRaw
    __host__ __device__ static auto MakeCGridDescriptor_G_M_N(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
    {
        return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
    }
    __host__ __device__ static auto MakeCGridDescriptor_M_N(
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_lengths_vec,
        const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_os_strides_vec)
    {
        return matrix_padder.PadCDescriptor_M_N(
            MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
    }
};

} // namespace tensor_operation
} // namespace ck

@@ -417,7 +417,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
        (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
        (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
        (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
        (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
        (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
        (is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
        "wrong! not implemented");

    using r_t = typename vector_type<T, N>::type;

@@ -220,8 +220,8 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0
                 "0"(c0),
                 "1"(c1));
#else
    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
#endif
}

@@ -257,10 +257,10 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
                 "2"(c2),
                 "3"(c3));
#else
    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
    c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
    c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
    c0 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b0), c0, false);
    c1 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b1), c1, false);
    c2 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b2), c2, false);
    c3 = __builtin_amdgcn_sdot4(bit_cast<int32_t>(a), bit_cast<int32_t>(b3), c3, false);
#endif
}

@@ -355,17 +355,5 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
                 c3);
}

// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
{
#if defined(__gfx11__)
    asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
#else
    ignore = a;
    ignore = b;
    ignore = c;
#endif
}
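
// Note: the "0"(c) constraint ties the input accumulator to the "=v"(c) output, so on gfx11
// this computes c += a * b in place; on other targets the function is a no-op that merely
// consumes its arguments.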

} // namespace ck
#endif

@@ -5,7 +5,7 @@

namespace ck {
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif

@@ -133,6 +133,13 @@ struct scalar_type<int8_t>
    static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<uint8_t>
{
    using type = uint8_t;
    static constexpr index_t vector_size = 1;
};

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct scalar_type<int4_t>
@@ -189,7 +196,7 @@ struct vector_type<T, 1>
    }
};

__device__ int static err = 0;
int static err = 0;
template <typename T>
struct vector_type<T, 2>
{
@@ -1037,6 +1044,14 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
// u8
// i8
using uint8x2_t = typename vector_type<uint8_t, 2>::type;
using uint8x4_t = typename vector_type<uint8_t, 4>::type;
using uint8x8_t = typename vector_type<uint8_t, 8>::type;
using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using uint8x32_t = typename vector_type<uint8_t, 32>::type;
using uint8x64_t = typename vector_type<uint8_t, 64>::type;

template <typename T>
struct NumericLimits

@@ -9,7 +9,7 @@

namespace ck {
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif

@@ -99,6 +99,63 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
    return type_convert<bhalf_t>(x_fp32);
}

// Convert X to Y
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x)
{
    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);

    return static_cast<Y>(x);
}

template <>
inline __host__ __device__ constexpr int type_convert_sp<int, float>(float x)
{
    union
    {
        float fp32;
        int int32;
    } u = {x};

    return u.int32;
}

template <>
inline __host__ __device__ constexpr float type_convert_sp<float, int>(int x)
{
    union
    {
        int int32;
        float fp32;
    } u = {x};

    return u.fp32;
}

template <>
inline __host__ __device__ constexpr int type_convert_sp<int, half_t>(half_t x)
{
    union
    {
        half_t fp16;
        int int32;
    } u = {x};

    return u.int32;
}

template <>
inline __host__ __device__ constexpr half_t type_convert_sp<half_t, int>(int x)
{
    union
    {
        int int32;
        half_t fp16;
    } u = {x};

    return u.fp16;
}
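
// The specializations above reinterpret bits rather than convert values (a union-based bit
// cast). A minimal generic sketch, assuming C++20 std::bit_cast is available to the device
// compiler:
//
//     template <typename Y, typename X>
//     __host__ __device__ constexpr Y bit_convert(X x) { return std::bit_cast<Y>(x); }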

// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
@@ -107,21 +164,24 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
    constexpr int seed = 42;
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx94__)
    float max_fp8 = 240.0f;
    x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval = x;
    uint32_t ival = 0;
    ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val = ival;
    val.fval = x;
    uint32_t ival = 0;
    const float max_fp8 = 240.0f;
    // if x is not +/- infinity or nan
    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
        // clip float value
        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
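    // fmed3 returns the median of its three operands, so fmed3(x, 240.0f, -240.0f) is a
    // single-instruction clamp of x into [-240.0f, 240.0f], the largest finite fp8 magnitude
    // on this target; the exponent-mask test above skips the clamp for NaN/Inf inputs.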
    ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val = ival;
    return val.i8val[0]; // little endian
#else
    constexpr bool negative_zero_nan = true;
@@ -144,7 +204,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
    constexpr bool negative_zero_nan = true;
    constexpr bool clip = true;
    constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
    constexpr int seed = 42;
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
    return utils::
        cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
@@ -156,7 +216,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
    constexpr int seed = 42;
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx94__)
    union
@@ -165,10 +225,15 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval = x;
    uint32_t ival = 0;
    ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val = ival;
    val.fval = x;
    uint32_t ival = 0;
    const float max_bf8 = 57344.0f;
    // if x is not +/- infinity or nan
    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
        // clip float value
        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
    ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
    val.i32val = ival;
    return val.i8val[0]; // little endian
#else
    constexpr bool negative_zero_nan = true;
@@ -191,7 +256,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
    constexpr bool negative_zero_nan = true;
    constexpr bool clip = true;
    constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
    constexpr int seed = 42;
    constexpr int seed = 1254739;
    uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
    return utils::
        cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
@@ -208,16 +273,19 @@ template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
{
#if defined(__gfx94__)
    float max_fp8 = 240.0f;
    x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
    union
    {
        float fval;
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval = x;
    uint32_t ival = 0;
    val.fval = x;
    uint32_t ival = 0;
    const float max_fp8 = 240.0f;
    // if x is not +/- infinity or nan
    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
        // clip float value
        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
    ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
    val.i32val = ival;
    return val.i8val[0];

@@ -261,8 +329,13 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
        uint32_t i32val;
        uint8_t i8val[4]; // not endian independent
    } val;
    val.fval = x;
    uint32_t ival = 0;
    val.fval = x;
    uint32_t ival = 0;
    const float max_bf8 = 57344.0f;
    // if x is not +/- infinity or nan
    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
        // clip float value
        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
    ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
    val.i32val = ival;
    return val.i8val[0];

@@ -5,8 +5,11 @@

#include "ck/wrapper/utils/layout_utils.hpp"

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond

/**
 * \brief Layout wrapper that performs the tensor descriptor logic.
@@ -19,6 +22,8 @@ namespace wrapper {
template <typename Shape, typename UnrolledDescriptorType>
struct Layout
{
    // Disable from doxygen docs generation
    /// @cond INTERNAL
    private:
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
@@ -246,6 +251,7 @@ struct Layout
    using Descriptor1dType =
        remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
    using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
    /// @endcond

    public:
    using LayoutShape = Shape;
@@ -457,6 +463,8 @@ struct Layout
        return unrolled_descriptor_;
    }

    // Disable from doxygen docs generation
    /// @cond INTERNAL
    private:
    // All dimensions are unrolled
    UnrolledDescriptorType unrolled_descriptor_;
@@ -469,6 +477,7 @@ struct Layout
    // Descriptor1dType lengths: (8)
    // MergedNestsDescriptorType lengths: (4, 2)
    const Shape shape_;
    /// @endcond
};

} // namespace wrapper

@@ -12,8 +12,11 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond

/**
 * \brief Perform optimized copy between two tensor partitions (threadwise copy).
@@ -61,12 +64,12 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
                     decltype(dim_access_order),
                     VectorDim,
                     ScalarPerVector,
                     Sequence<false>,
                     Sequence<false>>{in_grid_desc,
                                      make_tuple(src_tensor.GetMultiIdxOffsets()),
                                      out_grid_desc,
                                      make_tuple(dst_tensor.GetMultiIdxOffsets()),
                                      tensor_operation::element_wise::PassThrough{}};
                     Sequence<true>,
                     Sequence<true>>{in_grid_desc,
                                     make_tuple(src_tensor.GetMultiIdxOffsets()),
                                     out_grid_desc,
                                     make_tuple(dst_tensor.GetMultiIdxOffsets()),
                                     tensor_operation::element_wise::PassThrough{}};

    transfer.Run(tie(in_grid_desc),
                 tie(src_tensor.GetBuffer()),
@@ -104,37 +107,25 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
    else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
    {
        // Perform copy from DynamicBuffer to StaticBuffer
        const auto src_dst_slice_origin =
        const auto dst_slice_origin_idxs =
            generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
        constexpr auto src_vector_tensor_lengths = generate_sequence_v2(
            [&](auto I) {
                if constexpr(I == VectorDim)
                {
                    return Number<ScalarPerVector>{};
                }
                else
                {
                    return I1;
                }
            },
            Number<num_dims>{});

        auto transfer =
            ThreadwiseTensorSliceTransfer_v4r1<typename SrcTensorType::TensorElementType,
                                               typename DstTensorType::TensorElementType,
                                               remove_cvref_t<decltype(in_grid_desc)>,
                                               remove_cvref_t<decltype(out_grid_desc)>,
                                               decltype(thread_slice_lengths),
                                               decltype(dim_access_order),
                                               decltype(src_vector_tensor_lengths),
                                               decltype(dim_access_order)>{
                src_tensor.GetMultiIdxOffsets()};
        auto transfer = ThreadwiseTensorSliceTransfer_v2<
            std::remove_const_t<typename SrcTensorType::TensorElementType>,
            std::remove_const_t<typename DstTensorType::TensorElementType>,
            remove_cvref_t<decltype(in_grid_desc)>,
            remove_cvref_t<decltype(out_grid_desc)>,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            VectorDim,
            ScalarPerVector,
            I1,
            false,
            false>{in_grid_desc, src_tensor.GetMultiIdxOffsets()};

        transfer.Run(in_grid_desc,
                     src_dst_slice_origin,
                     src_tensor.GetBuffer(),
                     out_grid_desc,
                     src_dst_slice_origin,
                     dst_slice_origin_idxs,
                     dst_tensor.GetBuffer());
    }
    else

@@ -183,10 +174,12 @@ template <typename DimAccessOrderTuple,
          index_t ScalarPerVector,
          typename SrcTensorType,
          typename DstTensorType,
          typename ThreadLayoutTuple>
__device__ void blockwise_copy(const SrcTensorType& src_tensor,
                               DstTensorType& dst_tensor,
                               [[maybe_unused]] ThreadLayoutTuple& thread_layout)
          typename ThreadShape,
          typename ThreadUnrolledDesc>
__device__ void
blockwise_copy(const SrcTensorType& src_tensor,
               DstTensorType& dst_tensor,
               [[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout)
{
    static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
    static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
@@ -199,12 +192,12 @@ __device__ void blockwise_copy(const SrcTensorType& src_tensor,

    constexpr auto tile_lengths_seq =
        generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
    constexpr auto thread_layout_seq = generate_sequence_v2(
        [](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number<num_dims>{});
    constexpr auto thread_layout_seq =
        generate_sequence_v2([](auto I) { return size<I>(ThreadShape{}); }, Number<num_dims>{});
    constexpr auto dim_access_order = generate_sequence_v2(
        [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});

    using ThisThreadBlock = ThisThreadBlock<size(ThreadLayoutTuple{})>;
    using ThisThreadBlock = ThisThreadBlock<size(ThreadShape{})>;

    // Perform copy between DynamicBuffers
    auto transfer = ThreadGroupTensorSliceTransfer_v7<

@@ -9,9 +9,14 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace {
namespace detail {
/**
@@ -45,11 +50,13 @@ __device__ constexpr auto GetBlockDescriptor()

} // namespace detail
} // namespace
/// @endcond

/**
 * \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
 * stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) and B
 * data layout must be (NPerBlock, KPerBlock).
 * stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) or
 * (K0PerBlock, MPerBlock, K1) and B data layout must be (NPerBlock, KPerBlock)
 * or (K0PerBlock, NPerBlock, K1).
 *
 * \note C output Vgpr register layout (8D):
 * - MXdlPerWave - The number of MFMA instructions run by single wave in M
@@ -71,9 +78,9 @@ __device__ constexpr auto GetBlockDescriptor()
 * \tparam BlockSize Number of threads in the block.
 * \tparam GemmTraits Traits of gemm xdl operation.
 * \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
 * (MPerBlock, KPerBlock) layout.
 * (MPerBlock, KPerBlock) or (K0PerBlock, MPerBlock, K1) layout.
 * \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
 * (NPerBlock, KPerBlock) layout.
 * (NPerBlock, KPerBlock) or (K0PerBlock, NPerBlock, K1) layout.
 * \param c_reg_tensor C tensor VGPR memory for blockwise gemm.
 */
template <typename DataType,
@@ -86,6 +93,8 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
                                   const BTensorType& b_local_tile_tensor,
                                   CTensorType& c_reg_tensor)
{
    constexpr auto I3 = Number<3>{};

    static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
    static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
    static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
@@ -99,10 +108,18 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
    using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
    using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;

    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
                  typename BTileLayout::LayoutShape{}.Size());
    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;

    using ABlockDesc_K0_M_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
        conditional_t<is_3d_desc,
                      typename ATileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
    using BBlockDesc_K0_N_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
        conditional_t<is_3d_desc,
                      typename BTileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
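    // A 3-d (K0PerBlock, M/NPerBlock, K1) tile layout is already in the K0-M-K1 form the
    // blockwise gemm expects, so its unrolled descriptor is used directly; a 2-d
    // (M/NPerBlock, KPerBlock) layout is first reshaped by detail::GetBlockDescriptor.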

    BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                        DataType,
@@ -168,14 +185,22 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
                  typename BTileLayout::LayoutShape{}.Size());

    constexpr bool is_integer =
        is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
    using ABlockDesc_K0_M_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
        conditional_t<is_3d_desc,
                      typename ATileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
    using BBlockDesc_K0_N_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
        conditional_t<is_3d_desc,
                      typename BTileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;

    using BlockwiseGemmXdlops =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
@@ -233,19 +258,45 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)

    const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
        layout(c_local_tile_tensor).GetUnrolledDescriptor());

    const auto lower_upper_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<8>{});

    auto sliced_desc = transform_tensor_descriptor(
        partition_desc,
        make_tuple(
            make_slice_transform(partition_shape.At(Number<0>{}),
                                 m_thread_data_on_grid_idx[I0],
                                 partition_shape.At(Number<0>{}) + m_thread_data_on_grid_idx[I0]),
            make_slice_transform(partition_shape.At(Number<1>{}),
                                 n_thread_data_on_grid_idx[I0],
                                 partition_shape.At(Number<1>{}) + n_thread_data_on_grid_idx[I0]),
            make_slice_transform(partition_shape.At(Number<2>{}),
                                 m_thread_data_on_grid_idx[I1],
                                 partition_shape.At(Number<2>{}) + m_thread_data_on_grid_idx[I1]),
            make_slice_transform(partition_shape.At(Number<3>{}),
                                 n_thread_data_on_grid_idx[I1],
                                 partition_shape.At(Number<3>{}) + n_thread_data_on_grid_idx[I1]),
            make_slice_transform(partition_shape.At(Number<4>{}),
                                 m_thread_data_on_grid_idx[I2],
                                 partition_shape.At(Number<4>{}) + m_thread_data_on_grid_idx[I2]),
            make_slice_transform(partition_shape.At(Number<5>{}),
                                 m_thread_data_on_grid_idx[I3],
                                 partition_shape.At(Number<5>{}) + m_thread_data_on_grid_idx[I3]),
            make_slice_transform(partition_shape.At(Number<6>{}),
                                 m_thread_data_on_grid_idx[I4],
                                 partition_shape.At(Number<6>{}) + m_thread_data_on_grid_idx[I4]),
            make_slice_transform(partition_shape.At(Number<7>{}),
                                 n_thread_data_on_grid_idx[I2],
                                 partition_shape.At(Number<7>{}) + n_thread_data_on_grid_idx[I2])),
        lower_upper_dims,
        lower_upper_dims);

    const auto partition_layout =
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(partition_desc)>(
            partition_shape, partition_desc);
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
            partition_shape, sliced_desc);
    auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
        c_local_tile_tensor.GetPointer(), partition_layout);
    partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0],
                                                        n_thread_data_on_grid_idx[I0],
                                                        m_thread_data_on_grid_idx[I1],
                                                        n_thread_data_on_grid_idx[I1],
                                                        m_thread_data_on_grid_idx[I2],
                                                        m_thread_data_on_grid_idx[I3],
                                                        m_thread_data_on_grid_idx[I4],
                                                        n_thread_data_on_grid_idx[I2]));
    return partition_tensor;
}

@@ -292,14 +343,22 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
                  typename BTileLayout::LayoutShape{}.Size());

    constexpr bool is_integer =
        is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
    using ABlockDesc_K0_M_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
        conditional_t<is_3d_desc,
                      typename ATileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
    using BBlockDesc_K0_N_K1_Type =
        decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
        conditional_t<is_3d_desc,
                      typename BTileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;

    using BlockwiseGemmXdlops =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
@@ -326,9 +385,8 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
    const auto vgpr_layout = Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(
        vgpr_shape, vgpr_desc);
    // Get vector type for Vgpr
    using BlockwiseGemmCThreadBufferType =
        remove_reference_t<decltype(BlockwiseGemmXdlops{}.GetCThreadBuffer())>;
    using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V;
    constexpr index_t ScalarPerVector = BlockwiseGemmXdlops::xdlops_gemm.GetRegSizePerXdlops();
    using VgprVectorType = typename vector_type<GemmAccDataType, ScalarPerVector>::type;
    return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
        vgpr_layout);
}

@@ -7,9 +7,14 @@
#include "utils/tensor_partition.hpp"
#include "utils/layout_utils.hpp"

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace {
namespace detail {
/**
@@ -172,10 +177,10 @@ __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>&
    }
}

template <typename... Ts, typename Shape, typename FlattenDescriptor>
template <typename... Ts, typename Shape, typename UnrolledDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
                                                            const Shape& shape,
                                                            const FlattenDescriptor& flatten_desc)
                                                            const UnrolledDescriptor& flatten_desc)
{
    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();

@@ -189,6 +194,7 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>&
}
} // namespace detail
} // namespace
/// @endcond

/**
 * \brief Tensor wrapper that performs static and dynamic buffer logic.
@@ -394,6 +400,8 @@ struct Tensor
    }

    private:
    // Disable from doxygen docs generation
    /// @cond INTERNAL
    using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
                                            ElementType,
                                            ElementSpaceSize,
@@ -428,6 +436,7 @@ struct Tensor
    // tensor descriptor (thus all its transforms) and is linear (1D).
    // We store base_offset_ to avoid multiple recalculations.
    index_t base_offset_;
    /// @endcond
};

} // namespace wrapper

@@ -5,8 +5,11 @@

#include "ck/ck.hpp"

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond

/**
 * \brief Traits for blockwise gemm xdl.
@@ -20,48 +23,57 @@ namespace wrapper {
 * \tparam K1Value The number of K-dim elements that are packed together as
 * a separate logical dimension. Usually aligns with vector load size.
 */
template <index_t MPerXDLValue,
          index_t NPerXDLValue,
          index_t MXdlPerWaveValue,
          index_t NXdlPerWaveValue,
          index_t K1Value>
template <typename MPerXDLValue,
          typename NPerXDLValue,
          typename MXdlPerWaveValue,
          typename NXdlPerWaveValue,
          typename K1Value>
struct BlockwisGemmXdlTraits
{
    static constexpr index_t MPerXDL = MPerXDLValue;
    static constexpr index_t NPerXDL = NPerXDLValue;
    static constexpr index_t MXdlPerWave = MXdlPerWaveValue;
    static constexpr index_t NXdlPerWave = NXdlPerWaveValue;
    static constexpr index_t K1 = K1Value;
    static constexpr auto MPerXDL = MPerXDLValue{};
    static constexpr auto NPerXDL = NPerXDLValue{};
    static constexpr auto MXdlPerWave = MXdlPerWaveValue{};
    static constexpr auto NXdlPerWave = NXdlPerWaveValue{};
    static constexpr auto K1 = K1Value{};
};
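// Example: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<8>>
// describes a 32x32 XDL instruction tile repeated 4x in M and 2x in N per wave, with K packed
// in groups of K1 = 8 elements.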

// K1 = 4
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4>
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<4>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4>
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<4>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4>
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<4>>
{
};
// K1 = 8
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8>
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<8>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8>
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<8>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8>
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<8>>
{
};
// K1 = 16
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16>
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<16>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16>
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<16>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16>
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<16>>
{
};

17
include/ck/wrapper/utils/kernel_utils.hpp
Normal file
@@ -0,0 +1,17 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond

#define __CK_WRAPPER_LAUNCH_BOUNDS__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

} // namespace wrapper
} // namespace ck
@@ -15,12 +15,16 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"

namespace ck {
namespace wrapper {
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"

// Disable from doxygen docs generation
/// @cond
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond

// Disable from doxygen docs generation
/// @cond INTERNAL
// forward declaration
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;
@@ -29,6 +33,7 @@ template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

namespace {
namespace detail {
/**
 * \brief Generate packed (column-major) strides if not passed
 *
@@ -83,6 +88,7 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
    return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
}
} // namespace detail
} // namespace

/// @endcond
@@ -98,8 +104,9 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
{
    using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
    return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
    using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Strides{}));
    return Layout<Shape, UnrolledDescriptorType>(shape,
                                                 detail::MakeUnrolledDescriptor(shape, strides));
}

/**
@@ -112,13 +119,12 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides
template <typename Shape>
__host__ __device__ constexpr auto make_layout(const Shape& shape)
{
    using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
    return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
    using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
    return Layout<Shape, UnrolledDescriptorType>(shape,
                                                 detail::MakeUnrolledDescriptor(shape, Tuple<>{}));
}

// Layout helpers
// get

/**
 * \private
 * \brief Get dim.
@@ -152,8 +158,8 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
 * \param layout Layout to create sub layout.
 * \return Requested sub layout.
 */
template <index_t idx, typename Shape, typename FlattenDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
template <index_t idx, typename Shape, typename UnrolledDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, UnrolledDesc>& layout)
{
    const auto& shape = layout.GetShape();
    const auto new_shape = get<idx>(shape);
@@ -427,5 +433,91 @@ __host__ __device__ constexpr const auto& shape(const LayoutType& layout)
    return layout.GetShape();
}

// pad
/**
 * \brief Pad layout shape so that it is aligned to the tile lengths.
 *
 * \param layout Layout to pad.
 * \param tile_lengths Tile lengths to align layout shape.
 * \return Padded layout.
 */
template <typename Shape, typename UnrolledDesc, typename TileLengths>
__host__ __device__ constexpr auto pad(const Layout<Shape, UnrolledDesc>& layout,
                                       const TileLengths& tile_lengths)
{
    auto& unrolled_desc = layout.GetUnrolledDescriptor();
    // Generate sequence with ones to mark that all dims will be padded
    constexpr auto do_pads_seq =
        generate_sequence_v2([](auto) { return Number<1>{}; }, Number<Shape::Size()>{});
    // Create descriptor with padding
    auto padded_desc =
        tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
    // Generate padded shape
    const auto padded_shape = generate_tuple(
        [&](auto i) { return padded_desc.GetLength(Number<i>{}); }, Number<TileLengths::Size()>{});
    // Create layout
    return Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
}
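
// Usage sketch (illustrative shapes): padding a (30, 70) layout with tile_lengths (8, 16)
// rounds each dim up to the next tile multiple, producing a (32, 80) layout whose descriptor
// still maps in-range coordinates onto the original 30 x 70 data.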

// unmerge
/**
 * \brief Unmerge selected dim in layout.
 *
 * \tparam Idx Index to dimension being unmerged.
 * \param layout Layout to unmerge.
 * \param new_lengths Dimensions into which the indicated dimension will be divided.
 * \param new_indexes Indexes to shuffle dims. Dims for unmerged dim should be nested.
 * \return Unmerged layout.
 */
template <index_t Idx, typename Shape, typename UnrolledDesc, typename NewLengths, typename NewIdxs>
__host__ __device__ constexpr auto unmerge(const Layout<Shape, UnrolledDesc>& layout,
                                           const NewLengths& new_lengths,
                                           [[maybe_unused]] const NewIdxs& new_indexes)
{
    const auto& layout_shape = shape(layout);
    auto& unrolled_desc = layout.GetUnrolledDescriptor();
    constexpr auto dims = Shape::Size();
    // Generate transforms
    const auto transforms = generate_tuple(
        [&](auto i) {
            if constexpr(i == Idx)
            {
                return make_unmerge_transform(new_lengths);
            }
            else
            {
                return make_pass_through_transform(layout_shape.At(i));
            }
        },
        Number<dims>{});

    constexpr auto lower_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
    constexpr auto upper_dims = generate_tuple(
        [&](auto i) {
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, NewIdxs>>::value)
            {
                constexpr auto idxs_tuple = tuple_element_t<i.value, NewIdxs>{};
                return to_sequence(idxs_tuple);
            }
            else
            {
                constexpr index_t index = tuple_element_t<i.value, NewIdxs>{};
                return Sequence<index>{};
            }
        },
        Number<dims>{});

    const auto unmerged_desc =
        transform_tensor_descriptor(unrolled_desc, transforms, lower_dims, upper_dims);
    const auto unmerged_shape =
        generate_tuple([&](auto i) { return unmerged_desc.GetLength(Number<i>{}); },
                       Number<decltype(unmerged_desc)::GetNumOfVisibleDimension()>{});

    // Create layout
    return Layout<decltype(unmerged_shape), decltype(unmerged_desc)>(unmerged_shape, unmerged_desc);
}
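
// Usage sketch (illustrative): unmerging dim 0 of an (8, 3) layout with new_lengths (2, 4)
// and new_indexes (make_tuple(Number<0>{}, Number<1>{}), Number<2>{}) yields a (2, 4, 3)
// layout; the nested tuple in new_indexes names the output dims the unmerged dim expands into.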

} // namespace wrapper
} // namespace ck

@@ -6,13 +6,17 @@
#include "tensor_utils.hpp"
#include "layout_utils.hpp"

#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace {

namespace detail {
@@ -44,8 +48,9 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
 * \brief Apply projection.
 *
 * \param base_tuple Tuple to apply projection.
 * \param projection Projection to remove selected dim from partitioning.
 * slice(X) to remove, where X is dim size, Number<1>{} to keep.
 * \param projection Projection is used to remove selected dim from
 * partitioning. Use `slice(X)` to remove dimension, where X is dim
 * size. Use `Number<1>{}` to keep it.
 * \return Multi index after projection.
 */
template <typename MultiIndex, typename ProjectionTuple>
@@ -73,7 +78,7 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
            }
            else
            {
                return base_tuple.At(i_num);
                return make_tuple(base_tuple.At(i_num));
            }
        },
        Number<MultiIndex::Size()>{});
@@ -86,8 +91,9 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
 * \brief Calculate shape with dims from projection.
 *
 * \param shape Base tensor shape.
 * \param projection Projection to remove selected dim from partitioning.
 * slice(X) to remove, where X is dim size, Number<1>{} to keep.
 * \param projection Projection is used to remove selected dim from
 * partitioning. Use `slice(X)` to remove dimension, where X is dim
 * size. Use `Number<1>{}` to keep it.
 * \return Shape with dims from projection
 */
template <typename... Ts, typename... Ps>
@@ -119,22 +125,14 @@ __host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts..
 *
 * \param shape Base tensor shape.
 * \param tile_shape Tile shape.
 * \param projection Projection is used to remove selected dim from
 * partitioning. Use `slice(X)` to remove dimension, where X is dim
 * size. Use `Number<1>{}` to keep it.
 * \return Tuple with blocks number.
 */
template <typename... Ts, typename... Ls, typename... Ps>
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
                                                     const Tuple<Ls...>& tile_shape,
                                                     const Tuple<Ps...>& projection)
                                                     const Tuple<Ls...>& tile_shape)
{
    auto shape_with_projection = CalculateShapeWithProjection(shape, projection);
    return generate_tuple(
        [&](auto i) {
            return ck::math::integer_divide_ceil(size<i>(shape_with_projection),
                                                 size<i>(tile_shape));
        },
        [&](auto i) { return ck::math::integer_divide_ceil(size<i>(shape), size<i>(tile_shape)); },
        Number<Tuple<Ls...>::Size()>{});
}
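
// Example: shape (256, 128) tiled with tile_shape (64, 32) gives a grid of
// (ceil(256 / 64), ceil(128 / 32)) = (4, 4) blocks; integer_divide_ceil rounds up, so
// partially filled edge tiles still receive a block.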

@@ -155,6 +153,54 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
    return thread_idxs * partition_lengths_seq + old_offset_idxs;
}

/**
 * \brief Select dims to partition (skip if slice).
 *
 * \param block_idxs Input block indexes.
 * \return Partitioned dims.
 */
template <typename BlockIdxs>
__host__ __device__ constexpr auto GetDimsToPartition([[maybe_unused]] const BlockIdxs& block_idxs)
{
    const auto dims_to_partition = generate_tuple(
        [&](auto i) {
            if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
            {
                return Number<i>{};
            }
            else
            {
                return Tuple<>{};
            }
        },
        Number<BlockIdxs::Size()>{});
    // Remove empty tuples
    return UnrollNestedTuple<0, 1>(dims_to_partition);
}

/**
 * \brief Replace slices with zeros (Slice dims are not partitioned).
 *
 * \param block_idxs Input block indexes.
 * \return Parsed dims.
 */
template <typename BlockIdxs>
__host__ __device__ constexpr auto ReplaceSlicesWithZeros(const BlockIdxs& block_idxs)
{
    return generate_tuple(
        [&](auto i) {
            if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
            {
                return block_idxs.At(i);
            }
            else
            {
                return Number<0>{};
            }
        },
        Number<BlockIdxs::Size()>{});
}
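
// Example: block_idxs (1, slice(4), 2) keeps dims 0 and 2 for partitioning, so
// GetDimsToPartition returns (0, 2), while ReplaceSlicesWithZeros rewrites the indexes to
// (1, 0, 2) so that sliced dims contribute no block offset.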

/**
 * \brief Calculate default projection.
 *
@@ -168,59 +214,96 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
    return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
}

/**
 * \brief Calculate thread multi index from 1d thread index.
 *
 * \param thread_layout Layout of threads (must not be nested).
 * \param thread_id Thread index represented as integer.
 * \return Multi index.
 */
template <typename ThreadShape, typename ThreadUnrolledDesc>
__host__ __device__ constexpr auto CalculateThreadMultiIdx(
    [[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
    const index_t thread_id)
{
    static_assert(ThreadUnrolledDesc::GetNumOfTransform() == 1,
                  "Thread layout should not be transformed.");
    constexpr auto embed_transform = ThreadUnrolledDesc{}.GetTransforms().At(Number<0>{});
    constexpr auto shape           = ThreadShape{};
    constexpr auto strides         = embed_transform.coefficients_;

    return generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            return (thread_id / strides.At(num_i)) % shape.At(num_i);
        },
        Number<ThreadShape::Size()>{});
}
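The stride/modulo decomposition above can be verified with plain integers. A standalone sketch, assuming a packed row-major 4x8 thread layout (illustrative values, not repo code):

#include <cstdio>

int main()
{
    const int shape[2]   = {4, 8};
    const int strides[2] = {8, 1}; // packed row-major: stride_i = product of later dims
    const int thread_id  = 13;
    int idx[2];
    for(int i = 0; i < 2; ++i)
    {
        idx[i] = (thread_id / strides[i]) % shape[i]; // same rule as above
    }
    std::printf("thread %d -> (%d, %d)\n", thread_id, idx[0], idx[1]); // (1, 5)
    return 0;
}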

} // namespace detail
} // namespace
/// @endcond

/**
 * \brief Create local partition for thread (currently only packed partition
 * is supported).
 *
 * \param tensor Tensor for partition.
 * \param thread_lengths Layout of threads (must not be nested).
 * \param thread_layout Layout of threads (must not be transformed).
 * \param thread_id Thread index represented as integer.
 * \param projection Projection is used to remove selected dim from
 * partitioning. Use `slice(X)` to remove dimension, where X is dim
 * size. Use `Number<1>{}` to keep it.
 * \return Partition tensor.
 */
template <typename TensorType, typename ThreadLengthsTuple, typename ProjectionTuple>
template <typename TensorType,
          typename ThreadShape,
          typename ThreadUnrolledDesc,
          typename ProjectionTuple>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
                     [[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
                     [[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
                     const index_t thread_id,
                     const ProjectionTuple& projection)
{
    static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
    static_assert(!IsNestedTuple(ThreadShape{}));
    // Calculate new partition shape
    const auto& tensor_shape = shape(tensor);
    // Calculate projected thread lengths
    constexpr auto projected_thread_lengths =
        detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{});
        detail::ApplyProjection(ThreadShape{}, ProjectionTuple{});
    constexpr auto partition_shape =
        detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
    // Create Thread Cluster Descriptor
    constexpr auto partition_shape_seq =
        generate_sequence_v2([&](auto I) { return size<I>(partition_shape); },
                             Number<decltype(partition_shape)::Size()>{});
    constexpr auto thread_lengths_seq =
        generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
                             Number<ThreadLengthsTuple::Size()>{});
    constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
    // Calculate thread idxs and offsets
    const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
    const auto thread_idxs = detail::CalculateThreadMultiIdx(thread_layout, thread_id);
    // Apply projection on thread idxs to remove unneeded idxs
    const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
    const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
        projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
    // Create new layout and tensor
    auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
    // Slice descriptor
    const auto transforms = generate_tuple(
        [&](auto i) {
            return make_slice_transform(partition_shape.At(i),
                                        offset_multi_idxs.At(i),
                                        partition_shape.At(i) + offset_multi_idxs.At(i));
        },
        Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
    const auto lower_upper_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; },
                       Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
    auto sliced_desc =
        transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims);
    // Create layout
    const auto partition_layout =
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(unrolled_desc)>(
            partition_shape, unrolled_desc);
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
            partition_shape, sliced_desc);
    auto partition_tensor =
        make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
    // Apply offsets
    partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
    return partition_tensor;
}
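The offsets computed by make_local_partition reduce to "thread multi-index times per-thread partition shape". A standalone numeric sketch with made-up sizes (assumes the tensor shape divides evenly by the thread shape):

#include <cstdio>

int main()
{
    const int tensor_shape[2] = {256, 128};
    const int thread_shape[2] = {4, 8}; // 32 threads
    const int thread_idx[2]   = {1, 5}; // e.g. thread_id 13 in a 4x8 layout
    int partition_shape[2];
    int offset[2];
    for(int i = 0; i < 2; ++i)
    {
        partition_shape[i] = tensor_shape[i] / thread_shape[i];  // CalculateLocalPartitionShape
        offset[i]          = thread_idx[i] * partition_shape[i]; // CalculateOffsetMultiIdxs
    }
    std::printf("partition %dx%d at (%d, %d)\n",
                partition_shape[0], partition_shape[1], offset[0], offset[1]); // 64x16 at (64, 80)
    return 0;
}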

@@ -233,12 +316,13 @@ make_local_partition(TensorType& tensor,
 * \param thread_id Thread index represented as integer.
 * \return Partition tensor.
 */
template <typename TensorType, typename ThreadLengthsTuple>
__host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
                                                        const ThreadLengthsTuple& thread_lengths,
                                                        const index_t thread_id)
template <typename TensorType, typename ThreadShape, typename ThreadUnrolledDesc>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
                     const Layout<ThreadShape, ThreadUnrolledDesc>& thread_lengths,
                     const index_t thread_id)
{
    const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{});
    const auto projection = detail::GenerateDefaultProjection(ThreadShape{});
    return make_local_partition(tensor, thread_lengths, thread_id, projection);
}

@@ -252,21 +336,24 @@ __host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
 *
 * \param tensor Tensor for partition.
 * \param tile_shape Shapes of requested tile.
 * \param block_id Block index represented as integer.
 * \param projection Projection to remove selected dim from partitioning.
 * slice(X) to remove, where X is dim size, Number<1>{} to keep.
 * \param block_idxs Tuple of block indexes represented as integers. If a dim
 * is sliced, the whole dim is taken.
 * \param projection Projection is used to remove selected dim from
 * partitioning. Use `slice(X)` to remove dimension, where X is dim
 * size. Use `Number<1>{}` to keep it.
 * \return Tile tensor.
 */
template <typename TensorType, typename BlockShapeTuple, typename ProjectionTuple>
template <typename TensorType,
          typename BlockShapeTuple,
          typename BlockIdxs,
          typename ProjectionTuple>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
                                                   const BlockShapeTuple& tile_shape,
                                                   const index_t block_id,
                                                   const BlockIdxs& block_idxs,
                                                   const ProjectionTuple& projection)
{
    static_assert(!IsNestedTuple(BlockShapeTuple{}));

    constexpr bool is_default_projection =
        is_same_v<ProjectionTuple, decltype(detail::GenerateDefaultProjection(BlockShapeTuple{}))>;
    static_assert(!IsNestedTuple(BlockIdxs{}));

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
@@ -274,49 +361,77 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,

    auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();

    // TODO: Enable block_2_tile_map partitioning for non-default projection.
    if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection)
    constexpr auto projected_tile_shape =
        detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
    // Number of dims which are partitioned
    constexpr auto dims_to_partition = detail::GetDimsToPartition(BlockIdxs{});
    const auto parsed_block_idxs     = detail::ReplaceSlicesWithZeros(block_idxs);
    if constexpr(decltype(dims_to_partition)::Size() == I2)
    {
        // Optimized version for 2d tile shape [MxK]
        const auto shape_with_projection_dims =
            detail::CalculateShapeWithProjection(shape(tensor), projection);
        // Set value for M, N partition
        const auto M = shape_with_projection_dims.At(dims_to_partition.At(I0));
        const auto N = shape_with_projection_dims.At(dims_to_partition.At(I1));
        constexpr auto MPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I0));
        constexpr auto NPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I1));
        auto m_n_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N));
        // Get 1D block id
        const auto grid_size = detail::CalculateGridSize(shape_with_projection_dims, tile_shape);
        const auto block_lengths_desc = make_naive_tensor_descriptor_packed(grid_size);
        const index_t block_id_1d = block_lengths_desc.CalculateOffset(parsed_block_idxs);
        // Optimized version for 2d tile shape [MxN]
        const auto block_2_tile_map =
            BlockToCTileMap_M00_N0_M01Adapt<BlockShapeTuple{}.At(I0),
                                            BlockShapeTuple{}.At(I1),
                                            remove_cvref_t<decltype(aligned_desc)>>(aligned_desc);
            BlockToCTileMap_M00_N0_M01Adapt<MPerBlock,
                                            NPerBlock,
                                            remove_cvref_t<decltype(m_n_desc)>>(m_n_desc);
        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
            block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id_1d));
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape));
        const index_t k_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape));
        const auto offset_multi_idxs =
            make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid);
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
        // Apply 0 for non-partitioned dims
        const auto offset_multi_idxs = generate_tuple(
            [&](auto i) {
                if constexpr(i == dims_to_partition.At(I0))
                {
                    return m_block_data_idx_on_grid;
                }
                else if constexpr(i == dims_to_partition.At(I1))
                {
                    return n_block_data_idx_on_grid;
                }
                else
                {
                    return Number<0>{};
                }
            },
            Number<BlockShapeTuple::Size()>{});
        const auto projected_offset_multi_idxs =
            detail::ApplyProjection(offset_multi_idxs, projection);
        // Create new layout and tensor
        const auto tile_layout =
            Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
                                                                                     aligned_desc);
            Layout<remove_reference_t<decltype(projected_tile_shape)>, decltype(aligned_desc)>(
                projected_tile_shape, aligned_desc);
        auto tile_tensor =
            make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
        // Apply offsets
        tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
        tile_tensor.SetMultiIdxOffset(to_multi_index(projected_offset_multi_idxs));
        return tile_tensor;
    }
    else
    {
        // Calculate offsets
        // Sequence with data to process per block
        constexpr auto projected_tile_shape =
            detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
        using ProjectedTileShapeTuple = decltype(projected_tile_shape);
        constexpr auto projected_tile_shape_seq =
            generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
                                 Number<ProjectedTileShapeTuple::Size()>{});
        // Tuple with number of blocks
        const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection);
        const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
        const auto block_idxs =
            block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
        const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection);
        const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
        const auto projected_block_idxs =
            to_multi_index(detail::ApplyProjection(parsed_block_idxs, projection));
        const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
            projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
        // Create new layout and tensor
        const auto tile_layout =
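The hunk above cuts off mid-statement; to make the offset computation in both branches concrete, here is illustrative arithmetic only (made-up values, plain C++, not repo code). The 2D path first linearizes the parsed block indexes over the grid, and both paths place the tile at block index times tile length:

#include <cstdio>

int main()
{
    const int grid[2]       = {8, 4};    // blocks per dim, from CalculateGridSize
    const int block_idxs[2] = {2, 3};    // parsed block indexes (slices already -> 0)
    const int tile[2]       = {128, 64}; // MPerBlock, NPerBlock
    // Row-major linearization, as a packed descriptor's CalculateOffset does:
    const int block_id_1d = block_idxs[0] * grid[1] + block_idxs[1]; // 11
    // Tile origin = block index * tile length per partitioned dim:
    const int origin[2] = {block_idxs[0] * tile[0], block_idxs[1] * tile[1]};
    std::printf("block_id_1d = %d, origin = (%d, %d)\n",
                block_id_1d, origin[0], origin[1]); // 11, (256, 192)
    return 0;
}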

@@ -338,52 +453,17 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
 *
 * \param tensor Tensor for partition.
 * \param tile_shape Shapes of requested tile.
 * \param block_id Block index represented as integer.
 * \param block_idxs Tuple of block indexes represented as integers. If a dim
 * is sliced, the whole dim is taken.
 * \return Tile tensor.
 */
template <typename TensorType, typename BlockShapeTuple>
__host__ __device__ constexpr auto
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
template <typename TensorType, typename BlockShapeTuple, typename BlockIdxs>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
                                                   const BlockShapeTuple& tile_shape,
                                                   const BlockIdxs& block_idxs)
{
    const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
    return make_local_tile(tensor, tile_shape, block_id, projection);
}

/**
 * \brief Pad tensor shapes to be adjusted to tile lengths.
 *
 * \param tensor Tensor to pad.
 * \param tile_lengths Tile lengths to align tensor shape.
 * \return Padded tensor.
 */
template <typename TensorType, typename TileLengths>
__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths)
{
    const auto& tensor_shape = shape(tensor);
    using TensorShapeType = remove_reference_t<decltype(tensor_shape)>;
    auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
    // Generate sequence with ones to mark that all dims will be padded
    constexpr auto do_pads_seq =
        generate_sequence_v2([](auto) { return Number<1>{}; }, Number<TensorShapeType::Size()>{});
    // Create descriptor with padding
    auto padded_desc =
        tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
    // Generate padded shape
    const auto padded_shape = generate_tuple(
        [&](auto i) {
            const auto& dim = size<i>(tensor_shape);
            const auto& tile_length = size<i>(tile_lengths);
            return ck::math::integer_divide_ceil(dim, tile_length) * tile_length;
        },
        Number<TileLengths::Size()>{});
    // Create layout and tensor
    const auto padded_layout =
        Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
    auto partition_tensor =
        make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), padded_layout);
    partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets());
    return partition_tensor;
    return make_local_tile(tensor, tile_shape, block_idxs, projection);
}
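The pad helper removed above rounds each dim up to the nearest multiple of its tile length. A quick standalone check of that rule (example values):

#include <cstdio>

int main()
{
    const int dims[2]  = {1000, 500};
    const int tiles[2] = {128, 64};
    for(int i = 0; i < 2; ++i)
    {
        // integer_divide_ceil(dim, tile) * tile
        const int padded = (dims[i] + tiles[i] - 1) / tiles[i] * tiles[i];
        std::printf("%d -> %d\n", dims[i], padded); // 1000 -> 1024, 500 -> 512
    }
    return 0;
}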

} // namespace wrapper

@@ -13,8 +13,11 @@
#include "ck/utility/amd_address_space.hpp"
#include "ck/utility/multi_index.hpp"

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond

/**
 * \brief Memory type, allowed members:
@@ -27,7 +30,7 @@ namespace wrapper {
using MemoryTypeEnum = AddressSpaceEnum;

// Disable from doxygen docs generation
/// @cond
/// @cond INTERNAL
// forward declarations
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;

@@ -1,4 +1,4 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
@@ -6,4 +6,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
        target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance)
        set(target 1)
    endif()
endforeach()
endforeach()

@@ -1,4 +1,4 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
@@ -10,4 +10,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
            set(target 1)
        endif()
    endif()
endforeach()
endforeach()

@@ -1,4 +1,4 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,4 +1,4 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
@@ -10,4 +10,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
            set(target 1)
        endif()
    endif()
endforeach()
endforeach()

@@ -1,4 +1,4 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
@@ -26,4 +26,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
        endif()
        set(target 1)
    endif()
endforeach()
endforeach()

@@ -1,4 +1,4 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,4 +1,4 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)
@@ -6,4 +6,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
        target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance)
        set(target 1)
    endif()
endforeach()
endforeach()

@@ -1,4 +1,4 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,4 +1,4 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,4 +1,4 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,5 +1,5 @@
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)

@@ -1,5 +1,5 @@
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103)

set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
@@ -17,4 +17,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
        target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility)
        set(target 1)
    endif()
endforeach()
endforeach()

@@ -1,4 +1,4 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,8 +1,8 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "test_permute_scale_impl.hpp"
#include "profiler/profile_permute_scale_impl.hpp"

using F16 = ck::half_t;
using F32 = float;
@@ -15,15 +15,32 @@ class TestPermute : public ::testing::Test
    using ADataType = std::tuple_element_t<0, Tuple>;
    using BDataType = std::tuple_element_t<1, Tuple>;

    void Run()
    constexpr bool skip_case()
    {
        std::vector<std::vector<ck::index_t>> lengths = {
            {4, 2, 1, 8}, {1, 1, 1, 1}, {16, 8, 32, 64}, {32, 64, 128, 128}};

        for(auto length : lengths)
#ifndef CK_ENABLE_FP16
        if constexpr(ck::is_same_v<ADataType, F16> || ck::is_same_v<BDataType, F16>)
        {
            bool success =
                ck::test_permute_scale_impl<ADataType, BDataType, 4>(true, 2, false, false, length);
            return true;
        }
#endif
#ifndef CK_ENABLE_FP32
        if constexpr(ck::is_same_v<ADataType, F32> || ck::is_same_v<BDataType, F32>)
        {
            return true;
        }
#endif
        return false;
    }

    template <ck::index_t NDims>
    void Run(std::vector<ck::index_t> lengths,
             std::vector<ck::index_t> input_strides,
             std::vector<ck::index_t> output_strides)
    {
        if(!skip_case())
        {
            bool success = ck::profiler::profile_permute_scale_impl<ADataType, BDataType, NDims>(
                true, 2, false, false, lengths, input_strides, output_strides);
            EXPECT_TRUE(success);
        }
    }
@@ -32,5 +49,52 @@ class TestPermute : public ::testing::Test
using KernelTypes = ::testing::Types<std::tuple<F16, F16>, std::tuple<F32, F32>>;

TYPED_TEST_SUITE(TestPermute, KernelTypes);
TYPED_TEST(TestPermute, Test_FP16) { this->Run(); }
TYPED_TEST(TestPermute, Test_FP32) { this->Run(); }
TYPED_TEST(TestPermute, Test1D)
{
    constexpr ck::index_t NumDims = 1;
    this->template Run<NumDims>({16}, {1}, {1});
    this->template Run<NumDims>({16}, {1}, {2});
    this->template Run<NumDims>({1}, {1}, {1});
}

TYPED_TEST(TestPermute, Test2D)
{
    constexpr ck::index_t NumDims = 2;
    this->template Run<NumDims>({8, 16}, {16, 1}, {1, 8});
    this->template Run<NumDims>({8, 16}, {1, 8}, {16, 1});
    this->template Run<NumDims>({1, 1}, {1, 1}, {1, 1});
}

TYPED_TEST(TestPermute, Test3D)
{
    constexpr ck::index_t NumDims = 3;
    this->template Run<NumDims>({8, 2, 8}, {16, 8, 1}, {1, 8, 16});
    this->template Run<NumDims>({8, 2, 8}, {1, 8, 16}, {16, 8, 1});
    this->template Run<NumDims>({1, 1, 1}, {1, 1, 1}, {1, 1, 1});
}

TYPED_TEST(TestPermute, Test4D)
{
    constexpr ck::index_t NumDims = 4;
    this->template Run<NumDims>({8, 2, 3, 8}, {48, 24, 8, 1}, {1, 8, 16, 48});
    this->template Run<NumDims>({8, 2, 3, 8}, {1, 8, 16, 48}, {48, 24, 8, 1});
    this->template Run<NumDims>({1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1});
}

TYPED_TEST(TestPermute, Test5D)
{
    constexpr ck::index_t NumDims = 5;
    this->template Run<NumDims>({8, 2, 3, 4, 8}, {192, 96, 32, 8, 1}, {1, 8, 16, 48, 192});
    this->template Run<NumDims>({8, 2, 3, 4, 8}, {1, 8, 16, 48, 192}, {192, 96, 32, 8, 1});
    this->template Run<NumDims>({1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1});
}

TYPED_TEST(TestPermute, Test6D)
{
    constexpr ck::index_t NumDims = 6;
    this->template Run<NumDims>(
        {8, 2, 3, 4, 5, 8}, {960, 480, 160, 40, 8, 1}, {1, 8, 16, 48, 192, 960});
    this->template Run<NumDims>(
        {8, 2, 3, 4, 5, 8}, {1, 8, 16, 48, 192, 960}, {960, 480, 160, 40, 8, 1});
    this->template Run<NumDims>({1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1});
}

@@ -1,4 +1,4 @@
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST gpu_list AND target EQUAL 0)

@@ -1,14 +1,21 @@
add_gtest_executable(test_layout test_layout.cpp)
target_link_libraries(test_layout PRIVATE utility)
add_gtest_executable(test_tensor test_tensor.cpp)
target_link_libraries(test_tensor PRIVATE utility)
add_gtest_executable(test_copy test_copy.cpp)
target_link_libraries(test_copy PRIVATE utility)
add_gtest_executable(test_partition test_partition.cpp)
target_link_libraries(test_partition PRIVATE utility)
add_custom_target(test_wrapper)

add_gtest_executable(test_wrapper_layout test_wrapper_layout.cpp)
target_link_libraries(test_wrapper_layout PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_layout)
add_gtest_executable(test_wrapper_tensor test_wrapper_tensor.cpp)
target_link_libraries(test_wrapper_tensor PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_tensor)
add_gtest_executable(test_wrapper_copy test_wrapper_copy.cpp)
target_link_libraries(test_wrapper_copy PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_copy)
add_gtest_executable(test_wrapper_partition test_wrapper_partition.cpp)
target_link_libraries(test_wrapper_partition PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_partition)
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR
   GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR
   GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx950")
    add_gtest_executable(test_gemm test_gemm.cpp)
    target_link_libraries(test_gemm PRIVATE utility)
   GPU_TARGETS MATCHES "gfx942")
    add_gtest_executable(test_wrapper_gemm test_wrapper_gemm.cpp)
    target_link_libraries(test_wrapper_gemm PRIVATE utility)
    add_dependencies(test_wrapper test_wrapper_gemm)
endif()
