mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Merge remote-tracking branch 'upstream/develop' into ck_migraphx_integration
This commit is contained in:
111
Jenkinsfile
vendored
111
Jenkinsfile
vendored
@@ -100,7 +100,15 @@ def getDockerImage(Map conf=[:]){
|
||||
dockerArgs = dockerArgs + " --no-cache "
|
||||
}
|
||||
echo "Docker Args: ${dockerArgs}"
|
||||
def image = getDockerImageName()
|
||||
def image
|
||||
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
|
||||
image = conf.get("docker_name", "")
|
||||
echo "Using legacy docker: ${image}"
|
||||
}
|
||||
else{
|
||||
image = getDockerImageName()
|
||||
echo "Using default docker: ${image}"
|
||||
}
|
||||
//Check if image exists
|
||||
def retimage
|
||||
try
|
||||
@@ -125,7 +133,9 @@ def buildDocker(install_prefix){
|
||||
def image_name = getDockerImageName()
|
||||
echo "Building Docker for ${image_name}"
|
||||
def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' --build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}' "
|
||||
|
||||
if(params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
|
||||
dockerArgs = dockerArgs + " --no-cache "
|
||||
}
|
||||
echo "Build Args: ${dockerArgs}"
|
||||
try{
|
||||
if(params.BUILD_DOCKER){
|
||||
@@ -259,6 +269,7 @@ def cmake_build(Map conf=[:]){
|
||||
""")
|
||||
sh cmd3
|
||||
}
|
||||
|
||||
// reduce parallelism when compiling, clang uses too much memory
|
||||
def nt = nthreads()
|
||||
def cmd
|
||||
@@ -273,7 +284,7 @@ def cmake_build(Map conf=[:]){
|
||||
}
|
||||
else{
|
||||
setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
|
||||
build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}")
|
||||
build_cmd = conf.get("build_cmd", "${build_envs} make -j${nt} ${config_targets}")
|
||||
}
|
||||
cmd = conf.get("cmd", """
|
||||
${setup_cmd}
|
||||
@@ -292,8 +303,8 @@ def cmake_build(Map conf=[:]){
|
||||
dir("build"){
|
||||
//build CK
|
||||
sh cmd
|
||||
//run tests
|
||||
if(!setup_args.contains("NO_CK_BUILD")){
|
||||
//run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set
|
||||
if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){
|
||||
if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
|
||||
sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json"
|
||||
archiveArtifacts "ck_build_trace.json"
|
||||
@@ -330,7 +341,15 @@ def buildHipClangJob(Map conf=[:]){
|
||||
env.HSA_ENABLE_SDMA=0
|
||||
checkout scm
|
||||
|
||||
def image = getDockerImageName()
|
||||
def image
|
||||
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
|
||||
image = conf.get("docker_name", "")
|
||||
echo "Using legacy docker: ${image}"
|
||||
}
|
||||
else{
|
||||
image = getDockerImageName()
|
||||
echo "Using default docker: ${image}"
|
||||
}
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
@@ -512,7 +531,16 @@ def Build_CK(Map conf=[:]){
|
||||
env.DOCKER_BUILDKIT=1
|
||||
checkout scm
|
||||
|
||||
def image = getDockerImageName()
|
||||
def image
|
||||
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
|
||||
image = conf.get("docker_name", "")
|
||||
echo "Using legacy docker: ${image}"
|
||||
}
|
||||
else{
|
||||
image = getDockerImageName()
|
||||
echo "Using default docker: ${image}"
|
||||
}
|
||||
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
@@ -524,6 +552,9 @@ def Build_CK(Map conf=[:]){
|
||||
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){
|
||||
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
|
||||
}
|
||||
if(params.BUILD_LEGACY_OS){
|
||||
dockerOpts = dockerOpts + " --env LD_LIBRARY_PATH='/opt/Python-3.8.13/lib' "
|
||||
}
|
||||
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
|
||||
def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3')
|
||||
dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} "
|
||||
@@ -707,7 +738,8 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM
|
||||
0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true
|
||||
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false''' : ""
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false
|
||||
0 13 * * * % BUILD_LEGACY_OS=true ''' : ""
|
||||
|
||||
pipeline {
|
||||
agent none
|
||||
@@ -794,6 +826,10 @@ pipeline {
|
||||
name: "NINJA_BUILD_TRACE",
|
||||
defaultValue: false,
|
||||
description: "Generate a ninja build trace (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_LEGACY_OS",
|
||||
defaultValue: false,
|
||||
description: "Try building CK with legacy OS dockers: RHEL8 and SLES15 (default: OFF)")
|
||||
}
|
||||
environment{
|
||||
dbuser = "${dbuser}"
|
||||
@@ -946,7 +982,6 @@ pipeline {
|
||||
{
|
||||
parallel
|
||||
{
|
||||
|
||||
stage("Run CK_TILE_GEMM Tests on gfx90a")
|
||||
{
|
||||
when {
|
||||
@@ -965,7 +1000,6 @@ pipeline {
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
}
|
||||
|
||||
}
|
||||
stage("Run CK_TILE_GEMM Tests on gfx942")
|
||||
{
|
||||
@@ -988,15 +1022,54 @@ pipeline {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage("Build CK and run Tests")
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Build CK with RHEL8")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_rhel8_rocm6.3"
|
||||
setup_args = """ -DGPU_TARGETS="gfx942" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " \
|
||||
-DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
|
||||
execute_args = " "
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK with SLES15")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_sles15_rocm6.3"
|
||||
setup_args = """ -DGPU_TARGETS="gfx942" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " \
|
||||
-DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
|
||||
execute_args = " "
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK for all gfx9 targets")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_FULL_QA.toBoolean() }
|
||||
expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
@@ -1018,7 +1091,7 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_FULL_QA.toBoolean() }
|
||||
expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx942") }
|
||||
environment{
|
||||
@@ -1038,7 +1111,7 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
@@ -1058,7 +1131,7 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() }
|
||||
expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
@@ -1077,7 +1150,7 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx1030") }
|
||||
environment{
|
||||
@@ -1097,7 +1170,7 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx1101") }
|
||||
environment{
|
||||
@@ -1117,7 +1190,7 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.BUILD_GFX12.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
|
||||
expression { params.BUILD_GFX12.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx1201") }
|
||||
environment{
|
||||
@@ -1144,7 +1217,7 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() }
|
||||
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
options { retry(1) }
|
||||
agent{ label rocmnode("gfx90a")}
|
||||
@@ -1165,7 +1238,7 @@ pipeline {
|
||||
stage("Process results"){
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() }
|
||||
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent { label 'mici' }
|
||||
steps{
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core==1.8.0
|
||||
rocm-docs-core==1.8.1
|
||||
sphinxcontrib-bibtex==2.6.3
|
||||
|
||||
@@ -103,7 +103,7 @@ requests==2.32.3
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core==1.8.0
|
||||
rocm-docs-core==1.8.1
|
||||
# via -r requirements.in
|
||||
six==1.16.0
|
||||
# via pybtex
|
||||
|
||||
@@ -179,9 +179,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf,
|
||||
|
||||
std::cout << "The overall perfomance of the GEMM with "
|
||||
<< "[" << data_type << "]"
|
||||
<< "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K
|
||||
<< "is: \n";
|
||||
std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n"
|
||||
<< "batch size: " << batch_size << ". m:" << M << ", n:" << N << ", k:" << K
|
||||
<< " is: \n";
|
||||
std::cout << "Running time: " << ave_time << "ms, Throughput " << gb_per_sec << "GB/s \n"
|
||||
<< std::flush;
|
||||
|
||||
return ave_time;
|
||||
@@ -235,7 +235,7 @@ int main(int argc, char* argv[])
|
||||
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadA = true;
|
||||
constexpr bool kPadB = true;
|
||||
constexpr bool kPadC = false;
|
||||
constexpr bool kPadC = true;
|
||||
|
||||
// This part comes from the Codegen
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
@@ -348,7 +348,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref);
|
||||
|
||||
std::cout << "The GPU veification result is:" << (pass_gpu ? "correct" : "fail")
|
||||
std::cout << "The GPU veification result is: " << (pass_gpu ? "correct" : "fail")
|
||||
<< std::flush;
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
@@ -22,7 +23,6 @@
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -257,6 +257,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
KPerBlock / K1Number,
|
||||
ConvBackwardWeightSpecialization>{};
|
||||
|
||||
static constexpr index_t ClusterLengthMPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
|
||||
static constexpr auto conv_ngchw_to_nhwgc_transformer =
|
||||
TransformConvNGCHWToNHWGC<InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
NDimSpatial,
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock>{};
|
||||
|
||||
static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default;
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -359,141 +372,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
batch)[I2];
|
||||
}
|
||||
|
||||
static constexpr index_t ClusterLengthMPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Hi = g_n_c_wis_lengths[3];
|
||||
const index_t& Wi = g_n_c_wis_lengths[4];
|
||||
|
||||
const index_t& GStride = g_n_c_wis_strides[0];
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t& CStride = g_n_c_wis_strides[2];
|
||||
const index_t& HiStride = g_n_c_wis_strides[3];
|
||||
const index_t& WiStride = g_n_c_wis_strides[4];
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Hi = g_n_c_wis_lengths[3];
|
||||
const index_t& Wi = g_n_c_wis_lengths[4];
|
||||
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t HiStride = Wi * G * C;
|
||||
const index_t WiStride = G * C;
|
||||
const index_t GStride = C;
|
||||
const index_t CStride = 1;
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Di = g_n_c_wis_lengths[3];
|
||||
const index_t& Hi = g_n_c_wis_lengths[4];
|
||||
const index_t& Wi = g_n_c_wis_lengths[5];
|
||||
|
||||
const index_t& GStride = g_n_c_wis_strides[0];
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t& CStride = g_n_c_wis_strides[2];
|
||||
const index_t& DiStride = g_n_c_wis_strides[3];
|
||||
const index_t& HiStride = g_n_c_wis_strides[4];
|
||||
const index_t& WiStride = g_n_c_wis_strides[5];
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Di, Hi, Wi),
|
||||
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Di, Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[0];
|
||||
const index_t& N = g_n_c_wis_lengths[1];
|
||||
const index_t& C = g_n_c_wis_lengths[2];
|
||||
const index_t& Di = g_n_c_wis_lengths[3];
|
||||
const index_t& Hi = g_n_c_wis_lengths[4];
|
||||
const index_t& Wi = g_n_c_wis_lengths[5];
|
||||
|
||||
const index_t& NStride = g_n_c_wis_strides[1];
|
||||
const index_t DiStride = Hi * Wi * G * C;
|
||||
const index_t HiStride = Wi * G * C;
|
||||
const index_t WiStride = G * C;
|
||||
const index_t GStride = C;
|
||||
const index_t CStride = 1;
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Di, Hi, Wi),
|
||||
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Di, Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return PadTensorDescriptor(
|
||||
merged_desc,
|
||||
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
|
||||
Sequence<true, true>{});
|
||||
}
|
||||
|
||||
using InputTransposeDescType =
|
||||
remove_cvref_t<decltype(MakeInputTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using OutputTransposeDescType =
|
||||
remove_cvref_t<decltype(MakeOutputTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using NGCHWTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using NHWGCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
|
||||
|
||||
@@ -572,8 +456,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
I1>;
|
||||
|
||||
using GridwiseElementwiseTranspose =
|
||||
GridwiseElementwise<Tuple<InputTransposeDescType>,
|
||||
Tuple<OutputTransposeDescType>,
|
||||
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<const ADataType*>,
|
||||
Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
@@ -652,43 +536,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
|
||||
b_g_n_c_wis_strides;
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths,
|
||||
b_g_n_c_wis_strides);
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
a_g_n_k_wos_strides;
|
||||
|
||||
// NGKHW - transpose needed
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
b_g_n_c_wis_strides_transposed[I0] = Conv_C_;
|
||||
b_g_n_c_wis_strides_transposed[I2] = I1;
|
||||
a_g_n_k_wos_strides_transposed[I0] = Conv_K_;
|
||||
a_g_n_k_wos_strides_transposed[I2] = I1;
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
b_g_n_c_wis_strides_transposed[I3] =
|
||||
input_spatial_lengths_[I1] * Conv_G_ * Conv_C_;
|
||||
b_g_n_c_wis_strides_transposed[I4] = Conv_G_ * Conv_C_;
|
||||
a_g_n_k_wos_strides_transposed[I3] =
|
||||
output_spatial_lengths_[I1] * Conv_G_ * Conv_K_;
|
||||
a_g_n_k_wos_strides_transposed[I4] = Conv_G_ * Conv_K_;
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
b_g_n_c_wis_strides_transposed[I3] =
|
||||
input_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
|
||||
b_g_n_c_wis_strides_transposed[I4] =
|
||||
input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
|
||||
b_g_n_c_wis_strides_transposed[I5] = Conv_G_ * Conv_C_;
|
||||
a_g_n_k_wos_strides_transposed[I3] = output_spatial_lengths_[I1] *
|
||||
input_spatial_lengths_[I2] * Conv_G_ *
|
||||
Conv_K_;
|
||||
a_g_n_k_wos_strides_transposed[I4] =
|
||||
input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
|
||||
a_g_n_k_wos_strides_transposed[I5] = Conv_G_ * Conv_K_;
|
||||
}
|
||||
}
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
|
||||
const auto descs =
|
||||
conv_to_gemm_transformer_v2
|
||||
@@ -755,14 +607,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
a_in_transpose_desc_ =
|
||||
MakeInputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
|
||||
a_out_transpose_desc_ =
|
||||
MakeOutputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
|
||||
|
||||
b_in_transpose_desc_ =
|
||||
MakeInputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
b_out_transpose_desc_ =
|
||||
MakeOutputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
@@ -816,8 +672,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_b_;
|
||||
|
||||
InputTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
|
||||
OutputTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
|
||||
NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
|
||||
NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
|
||||
@@ -1569,13 +1425,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
|
||||
sizeof(BDataType);
|
||||
|
||||
// Different data type for A and B is not supported
|
||||
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
|
||||
ck::Tuple<InputTransposeDescType>,
|
||||
ck::Tuple<InputTransposeDescType>,
|
||||
ck::Tuple<OutputTransposeDescType>,
|
||||
ck::Tuple<OutputTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<BDataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
@@ -15,9 +15,11 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
@@ -307,6 +309,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
|
||||
// NGCHW is not supported for multiAB
|
||||
static_assert(!(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>()) ||
|
||||
!(isMultiA || isMultiB));
|
||||
|
||||
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
|
||||
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
@@ -315,6 +322,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
ConvForwardSpecialization,
|
||||
@@ -323,14 +332,33 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
EDataType,
|
||||
NumGroupsToMerge>;
|
||||
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
|
||||
static constexpr auto conv_ngchw_to_nhwgc_transformer =
|
||||
TransformConvNGCHWToNHWGC<ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
NDimSpatial,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay>
|
||||
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NDHWGC,
|
||||
ALay>>;
|
||||
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
@@ -353,8 +381,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
template <typename ELay>
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NDHWGK,
|
||||
ELay>>;
|
||||
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
@@ -442,6 +478,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
|
||||
|
||||
using NGCHWTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using NHWGCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
|
||||
|
||||
using GridwiseElementwiseInputTranspose =
|
||||
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<const ADataType*>,
|
||||
Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
ElementwiseBlocksize,
|
||||
NPerBlock,
|
||||
NPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
using GridwiseElementwiseOutputTranspose =
|
||||
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<const EDataType*>,
|
||||
Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
ElementwiseBlocksize,
|
||||
NPerBlock,
|
||||
NPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
I0,
|
||||
I1>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -471,17 +553,31 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
p_bs_grid_{},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
num_group_{a_g_n_c_wis_lengths[0]},
|
||||
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads},
|
||||
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
|
||||
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
|
||||
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
|
||||
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
|
||||
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
num_group_{a_g_n_c_wis_lengths_[0]},
|
||||
conv_to_gemm_transformer_{a_g_n_c_wis_lengths_,
|
||||
a_g_n_c_wis_strides_,
|
||||
b_g_k_c_xs_lengths_,
|
||||
b_g_k_c_xs_strides_,
|
||||
e_g_n_k_wos_lengths_,
|
||||
e_g_n_k_wos_strides_,
|
||||
conv_filter_strides_,
|
||||
conv_filter_dilations_,
|
||||
input_left_pads_,
|
||||
input_right_pads_},
|
||||
conv_N_per_block_{conv_to_gemm_transformer_.N_},
|
||||
a_grid_desc_m_k_{
|
||||
DeviceOp::MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
|
||||
@@ -501,19 +597,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
compute_ptr_offset_of_n_{},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
|
||||
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
|
||||
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
|
||||
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
|
||||
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
// A/B/E Batch Stride
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
@@ -521,7 +605,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
// Init compute_ptr_offset_of_groups_ for multiple AB
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_(i) =
|
||||
a_g_n_c_wis_strides[0] * NumGroupsToMerge;
|
||||
a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
|
||||
|
||||
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
|
||||
// type is not tuple)
|
||||
@@ -537,20 +621,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
// in case of MultiA is false but isMultiB is true
|
||||
// BatchStrideA_ is not tuple.
|
||||
compute_ptr_offset_of_n_.BatchStrideA_(i) =
|
||||
a_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
a_g_n_c_wis_strides_[1] * conv_N_per_block_;
|
||||
}
|
||||
else
|
||||
{
|
||||
// if MultiB and not MultiA then p_as is single pointer
|
||||
p_as_grid_(i) = static_cast<const DataType*>(p_as);
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ =
|
||||
a_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
a_g_n_c_wis_strides_[1] * conv_N_per_block_;
|
||||
}
|
||||
});
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
// Init compute_ptr_offset_of_groups_ for multiple AB
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_(i) =
|
||||
b_g_k_c_xs_strides[0] * NumGroupsToMerge;
|
||||
b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
|
||||
|
||||
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
|
||||
// It is possible that one of the AB is a pointer and one is a tuple.
|
||||
@@ -571,10 +655,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
else
|
||||
{
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_ =
|
||||
a_g_n_c_wis_strides[0] * NumGroupsToMerge;
|
||||
a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_ =
|
||||
b_g_k_c_xs_strides[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ =
|
||||
a_g_n_c_wis_strides_[1] * conv_N_per_block_;
|
||||
|
||||
// p_as and p_bs are pointers
|
||||
p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
|
||||
@@ -591,27 +676,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
// D batch stride
|
||||
compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
|
||||
ds_g_n_k_wos_strides[i][0] * NumGroupsToMerge;
|
||||
ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
|
||||
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
|
||||
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths,
|
||||
ds_g_n_k_wos_strides[i],
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths_,
|
||||
a_g_n_c_wis_strides_,
|
||||
b_g_k_c_xs_lengths_,
|
||||
b_g_k_c_xs_strides_,
|
||||
e_g_n_k_wos_lengths_,
|
||||
ds_g_n_k_wos_strides_[i],
|
||||
conv_filter_strides_,
|
||||
conv_filter_dilations_,
|
||||
input_left_pads_,
|
||||
input_right_pads_};
|
||||
|
||||
// D desc
|
||||
ds_grid_desc_m_n_(i) =
|
||||
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
|
||||
});
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ =
|
||||
e_g_n_k_wos_strides_[0] * NumGroupsToMerge;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
|
||||
|
||||
// populate desc for Ds/E
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
@@ -653,6 +739,54 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
ds_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
// Use not modified base strides
|
||||
a_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
|
||||
a_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
|
||||
|
||||
e_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
|
||||
e_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
|
||||
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
// Transpose require workspace for A and B
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
@@ -671,6 +805,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
typename GridwiseGemm::DsGridPointer p_ds_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
index_t num_group_;
|
||||
|
||||
@@ -692,6 +840,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_e_;
|
||||
|
||||
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
|
||||
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
|
||||
@@ -702,20 +855,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -723,7 +862,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
@@ -794,6 +933,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
else
|
||||
{
|
||||
const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
|
||||
EDataType* p_e_grid = arg.p_e_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
}
|
||||
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
const ADataType*,
|
||||
@@ -820,10 +970,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple
|
||||
p_a_grid, // Pass just A descriptor instead of tuple
|
||||
arg.p_bs_grid_.At(I0), // Pass just B descriptor instead of tuple
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
@@ -847,6 +997,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float avg_time = 0.f;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
arg.a_in_transpose_desc_);
|
||||
|
||||
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
|
||||
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.a_in_transpose_desc_),
|
||||
make_tuple(arg.a_out_transpose_desc_),
|
||||
make_tuple(arg.p_as_grid_.At(I0)),
|
||||
make_tuple(p_a_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_,
|
||||
element_wise::PassThrough{});
|
||||
}
|
||||
|
||||
avg_time += RunGemm(arg, stream_config);
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
arg.e_in_transpose_desc_);
|
||||
|
||||
const EDataType* p_e_out_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
|
||||
EDataType* p_e_in_grid = arg.p_e_grid_;
|
||||
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<const EDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.e_in_transpose_desc_),
|
||||
make_tuple(arg.e_out_transpose_desc_),
|
||||
make_tuple(p_e_out_grid),
|
||||
make_tuple(p_e_in_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_,
|
||||
element_wise::PassThrough{});
|
||||
}
|
||||
|
||||
return avg_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
@@ -941,7 +1164,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
|
||||
if constexpr(!(is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -953,14 +1177,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
|
||||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
|
||||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
|
||||
is_same_v<ALayout, ctc::NDHWGC>)
|
||||
is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
|
||||
is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
|
||||
{
|
||||
// Check access per C
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
// If not possible, check access per G
|
||||
if(!(ABlockTransferSrcVectorDim == 1 && C == 1 &&
|
||||
is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
|
||||
if(!(ABlockTransferSrcVectorDim == 1 && (C == 1 || NumGroupsToMerge == 1) &&
|
||||
(is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()) &&
|
||||
G % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
@@ -1036,6 +1262,35 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(!valid)
|
||||
{
|
||||
return false;
|
||||
@@ -1046,7 +1301,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
|
||||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
|
||||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
|
||||
is_same_v<ELayout, ctc::NDHWGK>)
|
||||
is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
|
||||
is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
@@ -1352,6 +1608,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = dynamic_cast<const Argument*>(p_arg);
|
||||
if(arg)
|
||||
{
|
||||
return arg->GetWorkspaceSizeBytes();
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* p_arg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(p_arg_)
|
||||
{
|
||||
p_arg_->p_workspace_ = p_workspace;
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -15,10 +15,12 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -292,6 +294,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
|
||||
ConvForwardSpecialization,
|
||||
@@ -302,13 +306,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
|
||||
static constexpr auto conv_ngchw_to_nhwgc_transformer =
|
||||
TransformConvNGCHWToNHWGC<ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
NDimSpatial,
|
||||
MPerBlock / ClusterLengthNPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock>{};
|
||||
|
||||
template <typename ALay>
|
||||
static auto
|
||||
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NDHWGC,
|
||||
ALay>>;
|
||||
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
@@ -351,8 +374,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
|
||||
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
using Layout = std::conditional_t<
|
||||
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
|
||||
ctc::NDHWGK,
|
||||
ELay>>;
|
||||
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
@@ -385,6 +416,53 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<GridwiseGemmV3TemplateParams>;
|
||||
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, NPerBlock>;
|
||||
|
||||
using NGCHWTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using NHWGCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
|
||||
|
||||
using GridwiseElementwiseInputTranspose =
|
||||
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<const ADataType*>,
|
||||
Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
ElementwiseBlocksize,
|
||||
NPerBlock,
|
||||
NPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
using GridwiseElementwiseOutputTranspose =
|
||||
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<const EDataType*>,
|
||||
Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
ElementwiseBlocksize,
|
||||
NPerBlock,
|
||||
NPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
I0,
|
||||
I1>;
|
||||
|
||||
static auto
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
|
||||
{
|
||||
@@ -428,17 +506,29 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
: p_a_grid_{},
|
||||
p_b_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
num_group_{a_g_n_c_wis_lengths[0]},
|
||||
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads},
|
||||
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
|
||||
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
|
||||
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
num_group_{a_g_n_c_wis_lengths_[0]},
|
||||
conv_to_gemm_transformer_{a_g_n_c_wis_lengths_,
|
||||
a_g_n_c_wis_strides_,
|
||||
b_g_k_c_xs_lengths_,
|
||||
b_g_k_c_xs_strides_,
|
||||
e_g_n_k_wos_lengths_,
|
||||
e_g_n_k_wos_strides_,
|
||||
conv_filter_strides_,
|
||||
conv_filter_dilations_,
|
||||
input_left_pads_,
|
||||
input_right_pads_},
|
||||
conv_N_per_block_{conv_to_gemm_transformer_.N_},
|
||||
a_grid_desc_ak0_m_ak1_{
|
||||
MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
|
||||
@@ -451,32 +541,70 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
compute_ptr_offset_of_n_{},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
|
||||
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
|
||||
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
// A/B/E Batch/N Stride
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0];
|
||||
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0];
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;
|
||||
|
||||
// p_as and p_bs are pointers
|
||||
p_a_grid_ = static_cast<const ADataType*>(p_as);
|
||||
p_b_grid_ = static_cast<const BDataType*>(p_bs);
|
||||
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0];
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
|
||||
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
// Use not modified base strides
|
||||
a_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
|
||||
a_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
|
||||
|
||||
e_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
|
||||
e_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
|
||||
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
// Transpose require workspace for A and B
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
@@ -492,6 +620,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
const BDataType* p_b_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
index_t num_group_;
|
||||
|
||||
@@ -514,17 +654,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
// block-to-e-tile map
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_e_;
|
||||
|
||||
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
|
||||
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -532,7 +667,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
@@ -561,8 +696,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
EDataType* p_e_grid = arg.p_e_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
}
|
||||
|
||||
typename GridwiseGemm::Argument gemm_arg{
|
||||
arg.p_a_grid_, arg.p_b_grid_, arg.p_e_grid_, GemmM, GemmN, GemmK, I0, I0, I0, I1};
|
||||
p_a_grid, arg.p_b_grid_, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
@@ -857,6 +1003,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float avg_time = 0.f;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
arg.a_in_transpose_desc_);
|
||||
|
||||
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
|
||||
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.a_in_transpose_desc_),
|
||||
make_tuple(arg.a_out_transpose_desc_),
|
||||
make_tuple(arg.p_a_grid_),
|
||||
make_tuple(p_a_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_,
|
||||
element_wise::PassThrough{});
|
||||
}
|
||||
|
||||
avg_time += RunGemm(arg, stream_config);
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
arg.e_in_transpose_desc_);
|
||||
|
||||
const EDataType* p_e_out_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
|
||||
EDataType* p_e_in_grid = arg.p_e_grid_;
|
||||
|
||||
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<const EDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough>;
|
||||
|
||||
avg_time += launch_and_time_kernel(stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.e_in_transpose_desc_),
|
||||
make_tuple(arg.e_out_transpose_desc_),
|
||||
make_tuple(p_e_out_grid),
|
||||
make_tuple(p_e_in_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_,
|
||||
element_wise::PassThrough{});
|
||||
}
|
||||
|
||||
return avg_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
@@ -868,6 +1087,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
{
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
|
||||
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
|
||||
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
|
||||
|
||||
// check device
|
||||
if(get_device_name() == "gfx908")
|
||||
{
|
||||
@@ -924,10 +1147,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
|
||||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
|
||||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
|
||||
is_same_v<ALayout, ctc::NDHWGC>)
|
||||
is_same_v<ALayout, ctc::NDHWGC> || is_same_v<ALayout, ctc::NGCW> ||
|
||||
is_same_v<ALayout, ctc::NGCHW> || is_same_v<ALayout, ctc::NGCDHW>)
|
||||
{
|
||||
const index_t C = arg.a_g_n_c_wis_lengths_[2];
|
||||
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
@@ -947,8 +1169,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
is_same_v<BLayout, ctc::KZYXGC>)
|
||||
|
||||
{
|
||||
const index_t C = arg.b_g_k_c_xs_lengths_[2];
|
||||
|
||||
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
@@ -959,15 +1179,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// check vector access of E
|
||||
if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
|
||||
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
|
||||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
|
||||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
|
||||
is_same_v<ELayout, ctc::NDHWGK>)
|
||||
is_same_v<ELayout, ctc::NDHWGK> || is_same_v<ELayout, ctc::NGKW> ||
|
||||
is_same_v<ELayout, ctc::NGKHW> || is_same_v<ELayout, ctc::NGKDHW>)
|
||||
{
|
||||
const index_t K = arg.e_g_n_k_wos_lengths_[2];
|
||||
|
||||
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
@@ -1279,6 +1527,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = dynamic_cast<const Argument*>(p_arg);
|
||||
if(arg)
|
||||
{
|
||||
return arg->GetWorkspaceSizeBytes();
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* p_arg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(p_arg_)
|
||||
{
|
||||
p_arg_->p_workspace_ = p_workspace;
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -26,6 +26,15 @@ constexpr bool is_GNWC_GKXC_GNWK()
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NGCW_GKXC_NGKW()
|
||||
{
|
||||
return is_same_v<InLayout, tensor_layout::convolution::NGCW> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NGKW>;
|
||||
}
|
||||
|
||||
// 2d
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NHWGC_GKYXC_NHWGK()
|
||||
@@ -91,6 +100,14 @@ constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
|
||||
is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
|
||||
}
|
||||
|
||||
template <typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial()
|
||||
{
|
||||
return is_NGCW_GKXC_NGKW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>();
|
||||
}
|
||||
|
||||
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
|
||||
@@ -0,0 +1,236 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
index_t NDimSpatial,
|
||||
index_t MPerThread,
|
||||
index_t NPerThread>
|
||||
struct TransformConvNGCHWToNHWGC
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t& N = g_n_c_wis_lengths[I1];
|
||||
const index_t& C = g_n_c_wis_lengths[I2];
|
||||
const index_t& Wi = g_n_c_wis_lengths[I3];
|
||||
|
||||
const index_t& GStride = g_n_c_wis_strides[I0];
|
||||
const index_t& NStride = g_n_c_wis_strides[I1];
|
||||
const index_t& CStride = g_n_c_wis_strides[I2];
|
||||
const index_t& WiStride = g_n_c_wis_strides[I3];
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return device::PadTensorDescriptor(
|
||||
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t& N = g_n_c_wis_lengths[I1];
|
||||
const index_t& C = g_n_c_wis_lengths[I2];
|
||||
const index_t& Wi = g_n_c_wis_lengths[I3];
|
||||
|
||||
const index_t& NStride = g_n_c_wis_strides[I1];
|
||||
const index_t WiStride = G * C;
|
||||
const index_t GStride = C;
|
||||
const index_t CStride = 1;
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return device::PadTensorDescriptor(
|
||||
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t& N = g_n_c_wis_lengths[I1];
|
||||
const index_t& C = g_n_c_wis_lengths[I2];
|
||||
const index_t& Hi = g_n_c_wis_lengths[I3];
|
||||
const index_t& Wi = g_n_c_wis_lengths[I4];
|
||||
|
||||
const index_t& GStride = g_n_c_wis_strides[I0];
|
||||
const index_t& NStride = g_n_c_wis_strides[I1];
|
||||
const index_t& CStride = g_n_c_wis_strides[I2];
|
||||
const index_t& HiStride = g_n_c_wis_strides[I3];
|
||||
const index_t& WiStride = g_n_c_wis_strides[I4];
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return device::PadTensorDescriptor(
|
||||
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t& N = g_n_c_wis_lengths[I1];
|
||||
const index_t& C = g_n_c_wis_lengths[I2];
|
||||
const index_t& Hi = g_n_c_wis_lengths[I3];
|
||||
const index_t& Wi = g_n_c_wis_lengths[I4];
|
||||
|
||||
const index_t& NStride = g_n_c_wis_strides[I1];
|
||||
const index_t HiStride = Wi * G * C;
|
||||
const index_t WiStride = G * C;
|
||||
const index_t GStride = C;
|
||||
const index_t CStride = 1;
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return device::PadTensorDescriptor(
|
||||
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t& N = g_n_c_wis_lengths[I1];
|
||||
const index_t& C = g_n_c_wis_lengths[I2];
|
||||
const index_t& Di = g_n_c_wis_lengths[I3];
|
||||
const index_t& Hi = g_n_c_wis_lengths[I4];
|
||||
const index_t& Wi = g_n_c_wis_lengths[I5];
|
||||
|
||||
const index_t& GStride = g_n_c_wis_strides[I0];
|
||||
const index_t& NStride = g_n_c_wis_strides[I1];
|
||||
const index_t& CStride = g_n_c_wis_strides[I2];
|
||||
const index_t& DiStride = g_n_c_wis_strides[I3];
|
||||
const index_t& HiStride = g_n_c_wis_strides[I4];
|
||||
const index_t& WiStride = g_n_c_wis_strides[I5];
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Di, Hi, Wi),
|
||||
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Di, Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return device::PadTensorDescriptor(
|
||||
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t& N = g_n_c_wis_lengths[I1];
|
||||
const index_t& C = g_n_c_wis_lengths[I2];
|
||||
const index_t& Di = g_n_c_wis_lengths[I3];
|
||||
const index_t& Hi = g_n_c_wis_lengths[I4];
|
||||
const index_t& Wi = g_n_c_wis_lengths[I5];
|
||||
|
||||
const index_t& NStride = g_n_c_wis_strides[I1];
|
||||
const index_t DiStride = Hi * Wi * G * C;
|
||||
const index_t HiStride = Wi * G * C;
|
||||
const index_t WiStride = G * C;
|
||||
const index_t GStride = C;
|
||||
const index_t CStride = 1;
|
||||
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, G, C, Di, Hi, Wi),
|
||||
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
|
||||
const auto merged_desc =
|
||||
transform_tensor_descriptor(desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, G, C)),
|
||||
make_merge_transform(make_tuple(Di, Hi, Wi))),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
return device::PadTensorDescriptor(
|
||||
merged_desc, make_tuple(MPerThread, NPerThread), Sequence<true, true>{});
|
||||
}
|
||||
|
||||
static auto TransposeStrides(const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& g_n_c_wis_strides)
|
||||
{
|
||||
if constexpr(device::is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
|
||||
device::is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> g_n_c_wis_strides_transposed;
|
||||
const auto G = g_n_c_wis_lengths[I0];
|
||||
const auto C = g_n_c_wis_lengths[I2];
|
||||
|
||||
g_n_c_wis_strides_transposed[I0] = C;
|
||||
g_n_c_wis_strides_transposed[I1] = g_n_c_wis_strides[I1];
|
||||
g_n_c_wis_strides_transposed[I2] = I1;
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
g_n_c_wis_strides_transposed[I3] = g_n_c_wis_lengths[I4] * G * C;
|
||||
g_n_c_wis_strides_transposed[I4] = G * C;
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
g_n_c_wis_strides_transposed[I3] =
|
||||
g_n_c_wis_lengths[I4] * g_n_c_wis_lengths[I5] * G * C;
|
||||
g_n_c_wis_strides_transposed[I4] = g_n_c_wis_lengths[I5] * G * C;
|
||||
g_n_c_wis_strides_transposed[I5] = G * C;
|
||||
}
|
||||
return g_n_c_wis_strides_transposed;
|
||||
}
|
||||
else
|
||||
{
|
||||
// transpose not needed
|
||||
return g_n_c_wis_strides;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -516,7 +516,7 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Add,
|
||||
static constexpr bool value =
|
||||
is_same<DataType, float>::value || is_same<DataType, double>::value ||
|
||||
is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
|
||||
is_same<DataType, int32_t>::value || is_same<DataType, f8_t>::value;
|
||||
is_same<DataType, int32_t>::value;
|
||||
};
|
||||
|
||||
} // namespace reduce
|
||||
|
||||
@@ -215,8 +215,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
{
|
||||
|
||||
@@ -268,7 +268,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
|
||||
// check early exit
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||
{
|
||||
if(num_total_loop <= 0)
|
||||
|
||||
@@ -123,14 +123,26 @@ struct GemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto ABlockWindow = make_tile_window(
|
||||
auto a_pad_view = pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
sequence < 0,
|
||||
GemmPipeline::kPadA ? 1 : 0 > {});
|
||||
|
||||
auto ABlockWindow = make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
{i_m, 0});
|
||||
|
||||
auto BBlockWindow = make_tile_window(
|
||||
auto b_pad_view = pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
sequence < 0,
|
||||
GemmPipeline::kPadB ? 1 : 0 > {});
|
||||
|
||||
auto BBlockWindow = make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
{i_n, 0});
|
||||
|
||||
// allocate LDS
|
||||
@@ -163,12 +175,16 @@ struct GemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto CBlockWindow = make_tile_window(
|
||||
auto c_pad_view = pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
sequence < 0,
|
||||
GemmPipeline::kPadC ? 1 : 0 > {});
|
||||
auto CBlockWindow_pad = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
{i_m, i_n});
|
||||
// epilogue.
|
||||
EpiloguePipeline{}(CBlockWindow, acc);
|
||||
EpiloguePipeline{}(CBlockWindow_pad, acc);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -29,6 +29,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t AlignmentB = Problem::AlignmentB;
|
||||
static constexpr index_t AlignmentC = Problem::AlignmentC;
|
||||
|
||||
static constexpr bool kPadA = Problem::kPadA;
|
||||
static constexpr bool kPadB = Problem::kPadB;
|
||||
static constexpr bool kPadC = Problem::kPadC;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
{
|
||||
return ck_tile::integer_divide_ceil(
|
||||
|
||||
@@ -28,9 +28,9 @@ struct BlockGemmPipelineProblem
|
||||
static constexpr bool kPadB = kPadB_;
|
||||
static constexpr bool kPadC = kPadC_;
|
||||
|
||||
static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1;
|
||||
static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1;
|
||||
static constexpr index_t AlignmentC = kPadC ? VectorLoadSize / sizeof(CDataType) : 1;
|
||||
static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
|
||||
static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType);
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -249,6 +249,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NGCHW> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NGKHW>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float> && is_same_v<AComputeType, float> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
is_same_v<BComputeType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
|
||||
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
|
||||
|
||||
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
|
||||
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
|
||||
@@ -57,6 +57,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
|
||||
@@ -171,6 +171,39 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
|
||||
|
||||
@@ -39,6 +39,20 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -55,6 +69,20 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
@@ -102,12 +102,14 @@ function(add_instance_library INSTANCE_NAME)
|
||||
set(FMHA_FWD_FAST_EXP2 true)
|
||||
endif()
|
||||
if(FMHA_FWD_FAST_EXP2)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
list(APPEND FMHA_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
list(APPEND FMHA_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
|
||||
target_compile_options(device_mha_instance PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
|
||||
list(APPEND FMHA_COMPILE_OPTIONS -Wno-float-equal)
|
||||
list(APPEND FMHA_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
|
||||
list(APPEND FMHA_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
|
||||
target_compile_options(device_mha_instance PRIVATE ${FMHA_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
target_compile_features(${INSTANCE_NAME} PUBLIC)
|
||||
|
||||
@@ -9,6 +9,9 @@ add_instance_library(device_grouped_conv2d_fwd_instance
|
||||
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
# NGCHW, GKYXC, NGKHW
|
||||
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
|
||||
# large tensor
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
@@ -19,6 +22,9 @@ add_instance_library(device_grouped_conv2d_fwd_instance
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
# NGCHW, GKYXC, NGKHW
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instance.cpp
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instance.cpp
|
||||
#mem
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
|
||||
@@ -28,11 +34,20 @@ add_instance_library(device_grouped_conv2d_fwd_instance
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp
|
||||
# NGCHW, GKYXC, NGKHW
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instance.cpp
|
||||
# NGCHW, GKYXC, NGKHW
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instance.cpp
|
||||
#comp
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp
|
||||
# NGCHW, GKYXC, NGKHW
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instance.cpp
|
||||
#dl
|
||||
# GNHWC, GKYXC, GNHWK
|
||||
dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f32_comp_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f16_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault,
|
||||
Intrawave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault,
|
||||
Intrawave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f16_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwd3x3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwdDefault>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_f32_instances<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGKHW,
|
||||
ConvFwd3x3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -32,23 +32,33 @@ if(EXISTS ${FMHA_CPP_FOLDER}/blob_list.txt)
|
||||
file(REMOVE ${FMHA_CPP_FOLDER}/blob_list.txt)
|
||||
endif()
|
||||
|
||||
set(FMHA_KNOWN_APIS "fwd,fwd_splitkv,fwd_appendkv,bwd")
|
||||
|
||||
# generate a list of kernels, but not actually emit files at config stage
|
||||
# Note: The receipt 3 arg filters the generated backwards instances to reduce compilation time.
|
||||
# With receipt 3 set, we are generating instances for datatype == {fp16 || bfp16}, bias == {no || alibi}, deterministic == off, and dpad == dvpad.
|
||||
execute_process(
|
||||
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
|
||||
COMMAND ${PYTHON_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py
|
||||
--list_blobs ${FMHA_CPP_FOLDER}/blob_list.txt
|
||||
--api ${FMHA_KNOWN_APIS}
|
||||
--receipt 3
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile MHA FAILED to genrate a list of kernels via Python.")
|
||||
else()
|
||||
file(STRINGS ${FMHA_CPP_FOLDER}/blob_list.txt FMHA_FWD_GEN_BLOBS)
|
||||
file(STRINGS ${FMHA_CPP_FOLDER}/blob_list.txt FMHA_GEN_BLOBS)
|
||||
endif()
|
||||
|
||||
# actually generate the kernel content now
|
||||
# Note: The receipt 3 arg filters the generated backwards instances to reduce compilation time.
|
||||
# With receipt 3 set, we are generating instances for datatype == {fp16 || bfp16}, bias == {no || alibi}, deterministic == off, and dpad == dvpad.
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_FWD_GEN_BLOBS}
|
||||
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_SOURCE_DIR}/example/ck_tile/01_fmha/generate.py
|
||||
OUTPUT ${FMHA_GEN_BLOBS}
|
||||
COMMAND ${PYTHON_EXECUTABLE} ${FMHA_SRC_FOLDER}/generate.py
|
||||
--output_dir ${FMHA_CPP_FOLDER}
|
||||
--api ${FMHA_KNOWN_APIS}
|
||||
--receipt 3
|
||||
COMMENT "Generating mha kernel (cpp) files now ..."
|
||||
VERBATIM
|
||||
)
|
||||
@@ -57,12 +67,12 @@ add_custom_command(
|
||||
# have filename. Since, it was cauing the cmake
|
||||
# to throw "File name too long"
|
||||
set(device_files)
|
||||
foreach(filepath IN LISTS FMHA_FWD_GEN_BLOBS)
|
||||
foreach(filepath IN LISTS FMHA_GEN_BLOBS)
|
||||
get_filename_component(filename ${filepath} NAME)
|
||||
# Append the filename to the device_files list
|
||||
list(APPEND device_files ${filename})
|
||||
endforeach()
|
||||
add_custom_target(generate_cpp_files DEPENDS ${FMHA_FWD_GEN_BLOBS})
|
||||
add_custom_target(generate_cpp_files DEPENDS ${FMHA_GEN_BLOBS})
|
||||
|
||||
add_instance_library(device_mha_instance ${device_files})
|
||||
|
||||
|
||||
@@ -148,6 +148,11 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
bool pass = true;
|
||||
|
||||
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
|
||||
// workspace_sz will be equal to 0 for other layout than NGCHW
|
||||
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
// re-init output to zero before profiling next kernel
|
||||
|
||||
@@ -45,6 +45,8 @@ static void print_helper_msg()
|
||||
"N, Ho, Wo, K]\n"
|
||||
<< " 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, "
|
||||
"Ho, Wo, G, K]\n"
|
||||
<< " 3: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
|
||||
"G, K, Ho, Wo]\n"
|
||||
<< "arg4: verification (0: no, 1: yes)\n"
|
||||
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
|
||||
@@ -15,6 +15,7 @@ enum struct ConvLayout
|
||||
{
|
||||
GNHWC_GKYXC_GNHWK, // 0
|
||||
NHWGC_GKYXC_NHWGK, // 1
|
||||
NGCHW_GKYXC_NGKHW, // 2
|
||||
};
|
||||
|
||||
enum struct ConvDataType
|
||||
@@ -54,6 +55,8 @@ static void print_helper_msg()
|
||||
<< "arg3: indexing data type (0: 32-bit, 1: 64-bit)\n"
|
||||
<< "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
|
||||
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
|
||||
<< " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
|
||||
"G, K, Ho, Wo]\n"
|
||||
<< "arg5: verification (0: no, 1: yes)\n"
|
||||
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg7: print tensor value (0: no; 1: yes)\n"
|
||||
@@ -111,6 +114,11 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
|
||||
|
||||
//
|
||||
using NGCHW = ck::tensor_layout::convolution::NGCHW;
|
||||
|
||||
using NGKHW = ck::tensor_layout::convolution::NGKHW;
|
||||
|
||||
//
|
||||
using NWGC = ck::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck::tensor_layout::convolution::NHWGC;
|
||||
@@ -284,6 +292,17 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
|
||||
@@ -28,6 +28,8 @@ def parse_layouts(args):
|
||||
args.in_layout == "NCDHW":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight":
|
||||
args.layout = 3
|
||||
elif args.ck_profier_op == "grouped_conv_fwd":
|
||||
args.layout = 2
|
||||
else:
|
||||
print('Not supported layout for this op')
|
||||
exit(1)
|
||||
|
||||
@@ -62,7 +62,9 @@ using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
|
||||
std::tuple<float, NHWGC, GKYXC, NHWGK>,
|
||||
std::tuple<ck::half_t, NHWGC, GKYXC, NHWGK>,
|
||||
std::tuple<ck::bhalf_t, NHWGC, GKYXC, NHWGK>,
|
||||
std::tuple<int8_t, NHWGC, GKYXC, NHWGK>>;
|
||||
std::tuple<int8_t, NHWGC, GKYXC, NHWGK>,
|
||||
std::tuple<float, NGCHW, GKYXC, NGKHW>,
|
||||
std::tuple<ck::half_t, NGCHW, GKYXC, NGKHW>>;
|
||||
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWC, GKZYXC, GNDHWK>,
|
||||
std::tuple<ck::half_t, GNDHWC, GKZYXC, GNDHWK>,
|
||||
|
||||
Reference in New Issue
Block a user