mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Merge branch 'develop' into zan_vllm_layout
This commit is contained in:
135
Jenkinsfile
vendored
135
Jenkinsfile
vendored
@@ -93,6 +93,30 @@ def build_compiler(){
|
||||
return compiler
|
||||
}
|
||||
|
||||
def check_arch(){
|
||||
def arch_type = 0
|
||||
sh 'rocminfo | tee rocminfo.log'
|
||||
if ( runShell('grep -n "gfx90a" rocminfo.log') ){
|
||||
arch_type = 1
|
||||
}
|
||||
else if ( runShell('grep -n "gfx942" rocminfo.log') ) {
|
||||
arch_type = 2
|
||||
}
|
||||
else if ( runShell('grep -n "gfx10" rocminfo.log') ) {
|
||||
arch_type = 3
|
||||
}
|
||||
else if ( runShell('grep -n "gfx11" rocminfo.log') ) {
|
||||
arch_type = 4
|
||||
}
|
||||
else if ( runShell('grep -n "gfx12" rocminfo.log') ) {
|
||||
arch_type = 5
|
||||
}
|
||||
else if ( runShell('grep -n "gfx908" rocminfo.log') ) {
|
||||
arch_type = 6
|
||||
}
|
||||
return arch_type
|
||||
}
|
||||
|
||||
def getDockerImage(Map conf=[:]){
|
||||
env.DOCKER_BUILDKIT=1
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
@@ -287,7 +311,7 @@ def cmake_build(Map conf=[:]){
|
||||
def build_cmd
|
||||
def execute_cmd = conf.get("execute_cmd", "")
|
||||
if(!setup_args.contains("NO_CK_BUILD")){
|
||||
if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
|
||||
if (setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE){
|
||||
echo "running ninja build trace"
|
||||
setup_cmd = conf.get("setup_cmd", """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" -O3 -ftime-trace " .. """)
|
||||
build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}")
|
||||
@@ -315,7 +339,7 @@ def cmake_build(Map conf=[:]){
|
||||
sh cmd
|
||||
//run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set
|
||||
if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){
|
||||
if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
|
||||
if ((setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE) || params.BUILD_INSTANCES_ONLY){
|
||||
sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis.log"
|
||||
@@ -323,7 +347,15 @@ def cmake_build(Map conf=[:]){
|
||||
archiveArtifacts "clang_build_analysis.log"
|
||||
// do not run unit tests when building instances only
|
||||
if(!params.BUILD_INSTANCES_ONLY){
|
||||
sh "ninja test"
|
||||
sh "ninja check"
|
||||
}
|
||||
if(params.BUILD_INSTANCES_ONLY){
|
||||
// build deb packages
|
||||
echo "Build packages"
|
||||
sh 'ninja -j64 package'
|
||||
archiveArtifacts artifacts: 'composablekernel-dev*.deb'
|
||||
sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.1.0_amd64.deb'
|
||||
stash includes: "composablekernel-dev_all_targets_1.1.0_amd64.deb", name: "packages"
|
||||
}
|
||||
}
|
||||
else{
|
||||
@@ -340,21 +372,14 @@ def cmake_build(Map conf=[:]){
|
||||
archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true
|
||||
}
|
||||
//check the node gpu architecture
|
||||
def arch_type = 0
|
||||
sh 'rocminfo | tee rocminfo.log'
|
||||
if ( runShell('grep -n "gfx90a" rocminfo.log') ){
|
||||
arch_type = 1
|
||||
}
|
||||
else if ( runShell('grep -n "gfx942" rocminfo.log') ) {
|
||||
arch_type = 2
|
||||
}
|
||||
def arch = check_arch()
|
||||
if (params.RUN_CK_TILE_FMHA_TESTS){
|
||||
try{
|
||||
archiveArtifacts "perf_fmha_*.log"
|
||||
if (arch_type == 1){
|
||||
if (arch == 1){
|
||||
stash includes: "perf_fmha_**_gfx90a.log", name: "perf_fmha_log_gfx90a"
|
||||
}
|
||||
else if (arch_type == 2){
|
||||
else if (arch == 2){
|
||||
stash includes: "perf_fmha_**_gfx942.log", name: "perf_fmha_log_gfx942"
|
||||
}
|
||||
}
|
||||
@@ -379,10 +404,10 @@ def cmake_build(Map conf=[:]){
|
||||
if (params.RUN_CK_TILE_GEMM_TESTS){
|
||||
try{
|
||||
archiveArtifacts "perf_tile_gemm_**.log"
|
||||
if (arch_type == 1){
|
||||
if (arch == 1){
|
||||
stash includes: "perf_tile_gemm_**_gfx90a.log", name: "perf_tile_gemm_log_gfx90a"
|
||||
}
|
||||
else if (arch_type == 2){
|
||||
else if (arch == 2){
|
||||
stash includes: "perf_tile_gemm_**_gfx942.log", name: "perf_tile_gemm_log_gfx942"
|
||||
}
|
||||
}
|
||||
@@ -410,7 +435,13 @@ def buildHipClangJob(Map conf=[:]){
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
def dockerOpts
|
||||
if ( params.BUILD_INSTANCES_ONLY ){
|
||||
dockerOpts = "--group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
}
|
||||
else{
|
||||
dockerOpts = "--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
}
|
||||
if (conf.get("enforce_xnack_on", false)) {
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
}
|
||||
@@ -521,28 +552,9 @@ def Build_CK(Map conf=[:]){
|
||||
timeout(time: 20, unit: 'HOURS')
|
||||
{
|
||||
//check whether to run performance tests on this node
|
||||
def arch_type = 0
|
||||
sh 'rocminfo | tee rocminfo.log'
|
||||
if ( runShell('grep -n "gfx90a" rocminfo.log') ){
|
||||
arch_type = 1
|
||||
}
|
||||
else if ( runShell('grep -n "gfx942" rocminfo.log') ) {
|
||||
arch_type = 2
|
||||
}
|
||||
else if ( runShell('grep -n "gfx10" rocminfo.log') ) {
|
||||
arch_type = 3
|
||||
}
|
||||
else if ( runShell('grep -n "gfx11" rocminfo.log') ) {
|
||||
arch_type = 4
|
||||
}
|
||||
else if ( runShell('grep -n "gfx12" rocminfo.log') ) {
|
||||
arch_type = 5
|
||||
}
|
||||
else if ( runShell('grep -n "gfx908" rocminfo.log') ) {
|
||||
arch_type = 6
|
||||
}
|
||||
def arch = check_arch()
|
||||
cmake_build(conf)
|
||||
if ( params.RUN_INDUCTOR_TESTS && !params.BUILD_LEGACY_OS && arch_type == 1 ){
|
||||
if ( params.RUN_INDUCTOR_TESTS && !params.BUILD_LEGACY_OS && arch == 1 ){
|
||||
echo "Run inductor codegen tests"
|
||||
sh """
|
||||
python3 -m venv ${env.WORKSPACE}
|
||||
@@ -553,9 +565,9 @@ def Build_CK(Map conf=[:]){
|
||||
"""
|
||||
}
|
||||
dir("build"){
|
||||
if (params.RUN_FULL_QA && arch_type == 2 ){
|
||||
// build deb packages for all gfx9 targets on gfx90a system and prepare to export
|
||||
echo "Build ckProfiler package"
|
||||
if (params.RUN_FULL_QA && arch == 2 ){
|
||||
// build deb packages
|
||||
echo "Build packages"
|
||||
sh 'make -j package'
|
||||
archiveArtifacts artifacts: 'composablekernel*.deb'
|
||||
sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb'
|
||||
@@ -568,7 +580,7 @@ def Build_CK(Map conf=[:]){
|
||||
// run performance tests, stash the logs, results will be processed on the master node
|
||||
dir("script"){
|
||||
if (params.RUN_PERFORMANCE_TESTS){
|
||||
if (params.RUN_FULL_QA && arch_type == 1){
|
||||
if (params.RUN_FULL_QA && arch == 1){
|
||||
// run full tests on gfx90a
|
||||
echo "Run full performance tests"
|
||||
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
|
||||
@@ -587,7 +599,7 @@ def Build_CK(Map conf=[:]){
|
||||
archiveArtifacts "perf_mixed_gemm.log"
|
||||
stash includes: "perf_**.log", name: "perf_log"
|
||||
}
|
||||
else if ( arch_type == 1 ){
|
||||
else if ( arch == 1 ){
|
||||
// run standard tests on gfx90a
|
||||
echo "Run performance tests"
|
||||
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
|
||||
@@ -598,28 +610,28 @@ def Build_CK(Map conf=[:]){
|
||||
stash includes: "perf_**.log", name: "perf_log"
|
||||
}
|
||||
// disable performance tests on gfx1030 for now.
|
||||
//else if ( arch_type == 3){
|
||||
//else if ( arch == 3){
|
||||
// run basic tests on gfx1030
|
||||
// echo "Run gemm performance tests"
|
||||
// sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx10"
|
||||
// archiveArtifacts "perf_onnx_gemm_gfx10.log"
|
||||
// stash includes: "perf_onnx_gemm_gfx10.log", name: "perf_log_gfx10"
|
||||
//}
|
||||
else if ( arch_type == 4){
|
||||
else if ( arch == 4){
|
||||
// run basic tests on gfx11
|
||||
echo "Run gemm performance tests"
|
||||
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx11"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx11.log"
|
||||
stash includes: "perf_onnx_gemm_gfx11.log", name: "perf_log_gfx11"
|
||||
}
|
||||
else if ( arch_type == 5 ){
|
||||
else if ( arch == 5 ){
|
||||
// run basic tests on gfx12
|
||||
echo "Run gemm performance tests"
|
||||
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx12"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx12.log"
|
||||
stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12"
|
||||
}
|
||||
else if ( arch_type == 6 ){
|
||||
else if ( arch == 6 ){
|
||||
// run basic tests on gfx908
|
||||
echo "Run performance tests"
|
||||
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx908"
|
||||
@@ -628,7 +640,7 @@ def Build_CK(Map conf=[:]){
|
||||
}
|
||||
}
|
||||
}
|
||||
if (params.hipTensor_test && arch_type == 1 ){
|
||||
if (params.hipTensor_test && arch == 1 ){
|
||||
// build and test hipTensor on gfx90a node
|
||||
sh """#!/bin/bash
|
||||
rm -rf "${params.hipTensor_branch}".zip
|
||||
@@ -730,24 +742,10 @@ def process_results(Map conf=[:]){
|
||||
echo "could not locate the GEMM performance logs: ${err.getMessage()}."
|
||||
}
|
||||
}
|
||||
if (params.RUN_FULL_QA){
|
||||
// unstash perf files to master
|
||||
if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){
|
||||
// unstash deb packages
|
||||
unstash "packages"
|
||||
sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/"
|
||||
try{
|
||||
unstash "perf_log"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate perf_log: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
unstash "perf_log_gfx11"
|
||||
unstash "perf_log_gfx12"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}."
|
||||
}
|
||||
sh "./process_qa_data.sh"
|
||||
}
|
||||
else{
|
||||
// unstash perf files to master
|
||||
@@ -775,12 +773,12 @@ def process_results(Map conf=[:]){
|
||||
}
|
||||
}
|
||||
|
||||
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.4;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
|
||||
0 21 * * * % ROCMVERSION=6.4;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true
|
||||
//launch develop branch daily jobs
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 13 * * * % BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : ""
|
||||
|
||||
pipeline {
|
||||
@@ -1263,8 +1261,7 @@ pipeline {
|
||||
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1151;gfx1201" \
|
||||
-D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """
|
||||
-D CMAKE_CXX_FLAGS=" -O3 -ftime-trace" .. && ninja -j64 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
|
||||
@@ -222,6 +222,9 @@
|
||||
// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
|
||||
#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
|
||||
|
||||
// workaround: conv crash when K, C is even
|
||||
#define CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN 1
|
||||
|
||||
// workaround: compiler crash when compiling recursive lambda
|
||||
#define CK_WORKAROUND_SWDEV_275126 1
|
||||
|
||||
|
||||
@@ -381,7 +381,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
static_for<0, DS_READ_A_PREFETCH_STAGES, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, KGroup, 1>{}([&](auto kg0) {
|
||||
// K = k0 × KGroup × k1 = k0 × kg0 × A_K1
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
|
||||
a_block_buf.At(I0),
|
||||
|
||||
@@ -860,35 +860,37 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
#endif
|
||||
if(desc.has_main_k_block_loop)
|
||||
{
|
||||
GridwiseGemm::template Run<true>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
desc.cde_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_etile_map);
|
||||
GridwiseGemm::template Run<true, InMemoryDataOperationEnum::Set>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
desc.cde_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_etile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<false>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
desc.cde_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_etile_map);
|
||||
GridwiseGemm::template Run<false, InMemoryDataOperationEnum::Set>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
desc.cde_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_etile_map);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1206,6 +1206,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// workaround: disable when K, C is even
|
||||
#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN
|
||||
if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
// check if it's 1x1, stride=1 pad = 0 conv
|
||||
for(int i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
|
||||
@@ -841,14 +841,21 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
// make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q),
|
||||
make_tuple(kargs.num_total_pages, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
make_tuple(kargs.num_total_pages / kargs.page_block_size, kargs.hdim_q / 8, kargs.page_block_size, 8),
|
||||
make_tuple(kargs.stride_k * kargs.page_block_size, kargs.page_block_size * 8, 8, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
const auto k_dram_transposed = transform_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(make_merge_transform(ck_tile::make_tuple(kargs.seqlen_k /16, 16)),
|
||||
make_merge_transform(ck_tile::make_tuple(kargs.hdim_q / 8, 8))),
|
||||
ck_tile::make_tuple(ck_tile::sequence<0, 2>{}, ck_tile::sequence<1, 3>{}),
|
||||
ck_tile::make_tuple(ck_tile::sequence<0>{}, ck_tile::sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
k_dram_naive,
|
||||
k_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}();
|
||||
@@ -858,19 +865,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
// make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v),
|
||||
make_tuple(kargs.num_total_pages, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
make_tuple(kargs.num_total_pages / kargs.page_block_size, kargs.hdim_v, kargs.page_block_size),
|
||||
make_tuple(kargs.stride_k * kargs.page_block_size, kargs.page_block_size, 1),
|
||||
number<FmhaPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_transposed = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(
|
||||
make_pass_through_transform(kargs.hdim_v),
|
||||
// make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
|
||||
make_pass_through_transform(kargs.num_total_pages)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
const auto v_dram_transposed =
|
||||
transform_tensor_view(v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_merge_transform(
|
||||
ck_tile::make_tuple(kargs.num_total_pages / kargs.page_block_size,
|
||||
kargs.page_block_size))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
|
||||
return pad_tensor_view(
|
||||
|
||||
@@ -613,6 +613,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
@@ -623,6 +624,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -638,6 +640,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
@@ -648,6 +651,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -665,6 +669,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
@@ -676,6 +681,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -693,6 +699,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
@@ -704,6 +711,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -51,6 +51,20 @@ void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
@@ -93,6 +107,20 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
@@ -135,6 +163,20 @@ void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
@@ -177,6 +219,19 @@ void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
@@ -236,6 +291,20 @@ void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
@@ -278,6 +347,20 @@ void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instanc
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
@@ -320,6 +403,20 @@ void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instanc
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
@@ -361,6 +458,20 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -103,7 +103,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
message("remaining instances: ${ARGN}")
|
||||
#message("remaining instances: ${ARGN}")
|
||||
#only continue if there are some source files left on the list
|
||||
if(ARGN)
|
||||
set(INST_OBJ)
|
||||
|
||||
@@ -93,6 +93,8 @@ add_instance_library(device_grouped_conv2d_fwd_instance
|
||||
wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instance.cpp
|
||||
wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp
|
||||
wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp
|
||||
## NHWGC, GKYXC, NHWGK
|
||||
wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp
|
||||
@@ -100,4 +102,6 @@ add_instance_library(device_grouped_conv2d_fwd_instance
|
||||
wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instance.cpp
|
||||
wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp
|
||||
wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_wmma_f16_instances<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_wmma_i8_instances<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_wmma_f16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_wmma_i8_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -52,6 +52,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -49,6 +49,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -52,6 +52,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instanc
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -52,6 +52,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances(
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -49,6 +49,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -52,6 +52,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
@@ -48,6 +48,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
@@ -50,6 +50,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances(
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_int8_comp_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
|
||||
if(ck::get_device_name() != "gfx950")
|
||||
{
|
||||
add_device_operation_instances(
|
||||
@@ -78,6 +86,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances(
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_int8_comp_instances_part2<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
@@ -108,6 +125,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances(
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_int8_comp_instances_2x<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
@@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
@@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f16_instances<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
@@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_instances<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
@@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
@@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
@@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp"
|
||||
@@ -46,6 +46,14 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instances(
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_int8_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
@@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
@@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC,
|
||||
Intrawave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
@@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
@@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f16_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC,
|
||||
Intrawave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
@@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
@@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_f32_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC,
|
||||
Intrawave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
@@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_int8_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp"
|
||||
@@ -49,6 +49,15 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_int8_mem_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
ConvFwdOddC,
|
||||
Intrawave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -66,6 +66,10 @@ set(GROUPED_CONV3D_FWD
|
||||
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp
|
||||
)
|
||||
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho,
|
||||
// wo, k]
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_wmma_f16_instances<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho,
|
||||
// wo, k]
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_wmma_i8_instances<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo,
|
||||
// g, k]
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_wmma_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo,
|
||||
// g, k]
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_wmma_i8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -188,6 +188,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D)
|
||||
TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back({2, 2, 64, 4, 4, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
|
||||
Reference in New Issue
Block a user