mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
merge develop
This commit is contained in:
@@ -4,11 +4,11 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
|
||||
## Composable Kernel 1.1.0 for ROCm 6.5.0
|
||||
|
||||
### Additions
|
||||
### Added
|
||||
|
||||
None
|
||||
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
|
||||
|
||||
### Optimizations
|
||||
### Optimized
|
||||
|
||||
None
|
||||
|
||||
|
||||
59
Jenkinsfile
vendored
59
Jenkinsfile
vendored
@@ -229,8 +229,11 @@ def cmake_build(Map conf=[:]){
|
||||
if (setup_args.contains("gfx10")){
|
||||
invocation_tag="gfx10"
|
||||
}
|
||||
if (setup_args.contains("gfx90")){
|
||||
invocation_tag="gfx90"
|
||||
if (setup_args.contains("gfx908")){
|
||||
invocation_tag="gfx908"
|
||||
}
|
||||
if (setup_args.contains("gfx90a")){
|
||||
invocation_tag="gfx90a"
|
||||
}
|
||||
if (setup_args.contains("gfx94")){
|
||||
invocation_tag="gfx94"
|
||||
@@ -314,9 +317,13 @@ def cmake_build(Map conf=[:]){
|
||||
if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
|
||||
sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json"
|
||||
archiveArtifacts "ck_build_trace.json"
|
||||
sh "ninja test"
|
||||
// do not run unit tests when building instances only
|
||||
if(!params.BUILD_INSTANCES_ONLY){
|
||||
sh "ninja test"
|
||||
}
|
||||
}
|
||||
else{
|
||||
// run unit tests
|
||||
sh "make check"
|
||||
}
|
||||
}
|
||||
@@ -511,6 +518,9 @@ def Build_CK(Map conf=[:]){
|
||||
else if ( runShell('grep -n "gfx1201" rocminfo.log') ) {
|
||||
arch_type = 5
|
||||
}
|
||||
else if ( runShell('grep -n "gfx908" rocminfo.log') ) {
|
||||
arch_type = 6
|
||||
}
|
||||
cmake_build(conf)
|
||||
if ( !params.BUILD_LEGACY_OS && arch_type == 1 ){
|
||||
echo "Run inductor codegen tests"
|
||||
@@ -582,7 +592,17 @@ def Build_CK(Map conf=[:]){
|
||||
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx12"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx12.log"
|
||||
stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12"
|
||||
}
|
||||
}
|
||||
else if ( arch_type == 6 ){
|
||||
// run standard tests on gfx908
|
||||
echo "Run performance tests"
|
||||
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
|
||||
archiveArtifacts "perf_gemm_gfx908.log"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx908.log"
|
||||
archiveArtifacts "perf_resnet50_N256_gfx908.log"
|
||||
archiveArtifacts "perf_resnet50_N4_gfx908.log"
|
||||
stash includes: "perf_**.log", name: "perf_log_gfx908"
|
||||
}
|
||||
}
|
||||
}
|
||||
if (params.hipTensor_test && arch_type == 1 ){
|
||||
@@ -718,11 +738,12 @@ def process_results(Map conf=[:]){
|
||||
|
||||
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.3;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
|
||||
0 22 * * * % ROCMVERSION=6.3;BUILD_GFX908=true;BUILD_GFX12=false;RUN_PERFORMANCE_TESTS=false
|
||||
0 21 * * * % ROCMVERSION=6.3;hipTensor_test=true;RUN_CODEGEN_TESTS=true
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false
|
||||
0 13 * * * % BUILD_LEGACY_OS=true''' : ""
|
||||
0 13 * * * % BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : ""
|
||||
|
||||
pipeline {
|
||||
agent none
|
||||
@@ -805,6 +826,10 @@ pipeline {
|
||||
name: "BUILD_INSTANCES_ONLY",
|
||||
defaultValue: false,
|
||||
description: "Test building instances for various architectures simultaneously (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX908",
|
||||
defaultValue: false,
|
||||
description: "Build CK and run tests on gfx908 (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX12",
|
||||
defaultValue: true,
|
||||
@@ -1002,7 +1027,7 @@ pipeline {
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
|
||||
make -j64 tile_example_gemm_basic tile_example_gemm_universal && \
|
||||
make -j64 tile_example_gemm_universal && \
|
||||
cd ../ &&
|
||||
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """
|
||||
}
|
||||
@@ -1021,7 +1046,7 @@ pipeline {
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \
|
||||
make -j64 tile_example_gemm_basic tile_example_gemm_universal && \
|
||||
make -j64 tile_example_gemm_universal && \
|
||||
cd ../ &&
|
||||
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
|
||||
}
|
||||
@@ -1117,6 +1142,26 @@ pipeline {
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK and run Tests on gfx908")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.BUILD_GFX908.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx908") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx908" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK and run Tests on gfx90a")
|
||||
{
|
||||
when {
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
add_executable(client_grouped_conv2d_bwd_data grouped_conv2d_bwd_data.cpp)
|
||||
target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
add_executable(client_grouped_conv2d_bwd_data_ngchw grouped_conv2d_bwd_data_ngchw.cpp)
|
||||
target_link_libraries(client_grouped_conv2d_bwd_data_ngchw PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
add_executable(client_grouped_conv3d_bwd_data grouped_conv3d_bwd_data.cpp)
|
||||
target_link_libraries(client_grouped_conv3d_bwd_data PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
|
||||
@@ -31,9 +31,9 @@ Table of supported cases by instance factory with XDL instruction:
|
||||
|
||||
| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK|
|
||||
|-------|---|---|---|
|
||||
|bf16|2D, 3D|✗|2D, 3D|
|
||||
|fp16 |2D, 3D|✗|2D, 3D|
|
||||
|fp32 |2D, 3D|✗|2D, 3D|
|
||||
|bf16|2D, 3D|2D, 3D|2D, 3D|
|
||||
|fp16 |2D, 3D|2D, 3D|2D, 3D|
|
||||
|fp32 |2D, 3D|2D, 3D|2D, 3D|
|
||||
|
||||
Table of supported cases by instance factory with WMMA instruction:
|
||||
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
using InDataType = ck::half_t;
|
||||
using WeiDataType = ck::half_t;
|
||||
using OutDataType = ck::half_t;
|
||||
|
||||
using InLayout = ck::tensor_layout::convolution::NGCHW;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NGKHW;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr ck::index_t NumDimSpatial = 2;
|
||||
static constexpr ck::index_t G = 32;
|
||||
static constexpr ck::index_t N = 256;
|
||||
static constexpr ck::index_t K = 192;
|
||||
static constexpr ck::index_t C = 192;
|
||||
static constexpr ck::index_t Y = 3;
|
||||
static constexpr ck::index_t X = 3;
|
||||
static constexpr ck::index_t Hi = 28;
|
||||
static constexpr ck::index_t Wi = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 28;
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
int main()
|
||||
{
|
||||
std::array<ck::index_t, NumDimSpatial + 3> in_lengths{G, N, Hi, Wi, C};
|
||||
std::array<ck::index_t, NumDimSpatial + 3> in_strides{
|
||||
C * Hi * Wi, G * C * Hi * Wi, Wi, 1, Hi * Wi};
|
||||
|
||||
std::array<ck::index_t, NumDimSpatial + 3> wei_lengths{G, K, Y, X, C};
|
||||
std::array<ck::index_t, NumDimSpatial + 3> wei_strides{K * Y * X * C, Y * X * C, X * C, C, 1};
|
||||
|
||||
std::array<ck::index_t, NumDimSpatial + 3> out_lengths{G, N, Ho, Wo, K};
|
||||
std::array<ck::index_t, NumDimSpatial + 3> out_strides{
|
||||
K * Ho * Wo, G * K * Ho * Wo, Wo, 1, Ho * Wo};
|
||||
|
||||
std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1};
|
||||
std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1};
|
||||
std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
|
||||
std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1};
|
||||
|
||||
SimpleDeviceMem in(sizeof(InDataType) * G * N * Hi * Wi * C);
|
||||
SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C);
|
||||
SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * K);
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<NumDimSpatial,
|
||||
OutLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>,
|
||||
InLayout,
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
ck::Tuple<>,
|
||||
InDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
int best_op_id = -1;
|
||||
float best_avg_time = std::numeric_limits<float>::max();
|
||||
float best_gb_per_sec = 0;
|
||||
float best_tflops = 0;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
SimpleDeviceMem workspace_dev(workspace_sz);
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X;
|
||||
std::size_t num_bytes = sizeof(InDataType) * G * N * Hi * Wi * C +
|
||||
sizeof(WeiDataType) * G * K * Y * X * C +
|
||||
sizeof(OutDataType) * G * N * Ho * Wo * K;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_bytes / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_op_id = i;
|
||||
best_op_name = op_name;
|
||||
best_avg_time = avg_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
best_tflops = tflops;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(best_op_id < 0)
|
||||
{
|
||||
std::cerr << "no suitable instance" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
|
||||
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
// run the best intance
|
||||
{
|
||||
auto& op_ptr = op_ptrs[best_op_id];
|
||||
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
|
||||
<< std::endl;
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
}
|
||||
|
||||
std::cout << "Done" << std::endl;
|
||||
}
|
||||
}
|
||||
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core==1.17.1
|
||||
rocm-docs-core==1.18.1
|
||||
sphinxcontrib-bibtex==2.6.3
|
||||
|
||||
@@ -199,7 +199,7 @@ requests==2.32.3
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core==1.17.1
|
||||
rocm-docs-core==1.18.1
|
||||
# via -r requirements.in
|
||||
rpds-py==0.22.3
|
||||
# via
|
||||
|
||||
@@ -46,9 +46,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
add_example_executable(example_gemm_xdl_bf16_streamk_v3 gemm_xdl_bf16_streamk_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_streamk_v3)
|
||||
|
||||
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
|
||||
|
||||
@@ -80,6 +77,12 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
|
||||
add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16)
|
||||
|
||||
add_example_executable(example_gemm_xdl_bf16_streamk_v3 gemm_xdl_bf16_streamk_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_streamk_v3)
|
||||
|
||||
add_example_executable(example_gemm_xdl_fp8_streamk_v3 gemm_xdl_fp8_streamk_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_streamk_v3)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
@@ -90,9 +93,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
|
||||
add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
|
||||
|
||||
add_example_executable(example_gemm_xdl_fp8_streamk_v3 gemm_xdl_fp8_streamk_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_streamk_v3)
|
||||
|
||||
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp16_comp_fp8 convnd_fwd_xdl_fp16_comp_fp8.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp)
|
||||
@@ -11,3 +10,13 @@ add_example_executable(example_convnd_fwd_xdl_bf8_fp8 convnd_fwd_xdl_bf8_fp8.cpp
|
||||
add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
|
||||
add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
|
||||
add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)
|
||||
|
||||
# only build fp64 example for the following targets
|
||||
list(APPEND gpu_list gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -173,8 +173,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
|
||||
std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
|
||||
std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument);
|
||||
std::size_t hargs_size = gemm.GetHostKernelArgSize(&argument);
|
||||
|
||||
DeviceMem gemm_workspace, gemm_kargs;
|
||||
void* gemm_hargs;
|
||||
|
||||
// The following is necessary since TwoStage kernel is using additional memory both
|
||||
// for Workspace and kernel arguments.
|
||||
@@ -188,6 +190,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
gemm_workspace.Realloc(workspace_size);
|
||||
gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
if(hargs_size > 0)
|
||||
{
|
||||
hip_check_error(hipHostMalloc(&gemm_hargs, hargs_size));
|
||||
gemm.SetHostKernelArgs(&argument, gemm_hargs);
|
||||
}
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
@@ -196,7 +203,16 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
hipStream_t stream0 = nullptr;
|
||||
hip_check_error(hipStreamCreate(&stream0));
|
||||
|
||||
hipEvent_t event0 = nullptr;
|
||||
hip_check_error(hipEventCreate(&event0));
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0);
|
||||
|
||||
hip_check_error(hipEventSynchronize(event0));
|
||||
hip_check_error(hipStreamSynchronize(stream0));
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
|
||||
@@ -104,6 +104,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
# Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply_xdl_fp8_bpreshuffle")
|
||||
message("Skipping ${source} example for current target")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(FILE_NAME)
|
||||
if(FILE_NAME MATCHES "_xdl")
|
||||
|
||||
@@ -118,7 +118,7 @@ FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
@@ -288,7 +288,7 @@ class FmhaFwdApiPool:
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
@@ -417,6 +417,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -489,6 +490,10 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't' or (pipeline.F_mask not in ['no', 's_no']):
|
||||
continue
|
||||
k = FmhaFwdKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
|
||||
@@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool:
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@@ -476,7 +476,7 @@ class FmhaFwdSplitKVApiPool:
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
|
||||
@@ -72,8 +72,14 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
|
||||
float r = ck_tile::launch_kernel(
|
||||
s,
|
||||
[=, &r0](const ck_tile::stream_config&) { r0 = fused_moesorting(t0, a0, s_sub); },
|
||||
[=, &r1](const ck_tile::stream_config&) { r1 = fused_moegemm(t1, a1, s_sub); });
|
||||
[=, &r0](const ck_tile::stream_config&) {
|
||||
r0 = fused_moesorting(t0, a0, s_sub);
|
||||
return hipPeekAtLastError() == hipSuccess;
|
||||
},
|
||||
[=, &r1](const ck_tile::stream_config&) {
|
||||
r1 = fused_moegemm(t1, a1, s_sub);
|
||||
return hipPeekAtLastError() == hipSuccess;
|
||||
});
|
||||
|
||||
// keep unsupported case return negative
|
||||
if(r0 < 0 || r1 < 0)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
@@ -13,7 +14,9 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -202,9 +205,11 @@ template <index_t NDimSpatial,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename AComputeType = ADataType,
|
||||
typename BComputeType = AComputeType>
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename AComputeType = ADataType,
|
||||
typename BComputeType = AComputeType,
|
||||
index_t MaxTransposeTransferInScalarPerVector = 1,
|
||||
index_t MaxTransposeTransferOutScalarPerVector = 1>
|
||||
struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
: public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
|
||||
ALayout, // output image
|
||||
@@ -237,6 +242,19 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using ALayoutAfterTranspose =
|
||||
std::conditional_t<is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NDHWGK,
|
||||
ALayout>>;
|
||||
using ELayoutAfterTranspose =
|
||||
std::conditional_t<is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NDHWGC,
|
||||
ELayout>>;
|
||||
|
||||
using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
AK1,
|
||||
@@ -246,9 +264,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN,
|
||||
ALayout,
|
||||
ALayoutAfterTranspose,
|
||||
BLayout,
|
||||
ELayout,
|
||||
ELayoutAfterTranspose,
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
EDataType>;
|
||||
@@ -274,7 +292,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN,
|
||||
ALayout,
|
||||
ALayoutAfterTranspose,
|
||||
BLayout,
|
||||
DLayout,
|
||||
true, /*SplitConvN*/
|
||||
@@ -374,7 +392,70 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, MPerBlock>;
|
||||
|
||||
static constexpr index_t ClusterLengthMPerBlock =
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
|
||||
static constexpr auto conv_ngchw_to_nhwgc_transformer =
|
||||
TransformConvNGCHWToNHWGC<ELayout,
|
||||
BLayout,
|
||||
ALayout,
|
||||
NDimSpatial,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
MPerBlock / ClusterLengthMPerBlock>{};
|
||||
|
||||
static constexpr index_t TransposeTransferInScalarPerVectorAligned =
|
||||
std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferInScalarPerVector);
|
||||
static constexpr index_t TransposeTransferOutScalarPerVectorAligned =
|
||||
std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferOutScalarPerVector);
|
||||
|
||||
using NGCHWTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using NHWGCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
static constexpr index_t ElementwiseBlocksize = ClusterLengthMPerBlock * ClusterLengthNPerBlock;
|
||||
|
||||
using GridwiseElementwiseInputTranspose =
|
||||
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<const ADataType*>,
|
||||
Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
ElementwiseBlocksize,
|
||||
NPerBlock,
|
||||
MPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<TransposeTransferInScalarPerVectorAligned>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
using GridwiseElementwiseOutputTranspose =
|
||||
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<const EDataType*>,
|
||||
Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
ElementwiseBlocksize,
|
||||
NPerBlock,
|
||||
MPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
Sequence<TransposeTransferOutScalarPerVectorAligned>,
|
||||
I0,
|
||||
I1>;
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
@@ -409,10 +490,18 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
cde_element_op_{cde_element_op},
|
||||
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeStrides(e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides);
|
||||
|
||||
// populate Ds pointer
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
@@ -491,17 +580,18 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
|
||||
}
|
||||
|
||||
ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes};
|
||||
ConvToGemmBwdDataTransform conv_to_gemm_transform_{
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides_transposed,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes};
|
||||
|
||||
conv_N_per_block_ = conv_to_gemm_transform_.N_;
|
||||
|
||||
@@ -527,7 +617,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN,
|
||||
ALayout,
|
||||
ALayoutAfterTranspose,
|
||||
BLayout,
|
||||
DLayout,
|
||||
true, /*SplitConvN*/
|
||||
@@ -535,7 +625,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
DDataType>;
|
||||
ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
ds_g_n_c_wis_lengths[i],
|
||||
@@ -591,12 +681,73 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
}
|
||||
// A/B/Ds/E Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides_transposed[0];
|
||||
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_k_wos_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ =
|
||||
a_g_n_k_wos_strides_transposed[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ =
|
||||
e_g_n_c_wis_strides_transposed[1] * conv_N_per_block_;
|
||||
|
||||
num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
// Use not modified base strides
|
||||
a_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_);
|
||||
a_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_);
|
||||
|
||||
e_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_);
|
||||
e_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
|
||||
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
|
||||
|
||||
compute_ptr_offset_of_workspace_n_.BatchStrideA_ =
|
||||
a_g_n_k_wos_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_workspace_n_.BatchStrideE_ =
|
||||
e_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(ADataType) * a_acum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(EDataType) * e_accum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
// Transpose require workspace for A and B
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
@@ -645,10 +796,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
// block-to-e-tile map
|
||||
std::vector<Block2ETileMap> block_2_etile_map_container_;
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_e_;
|
||||
|
||||
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
|
||||
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_workspace_n_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOp a_element_op_;
|
||||
@@ -657,9 +814,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
|
||||
index_t num_workgroups_per_Conv_N_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -667,19 +827,24 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
float ave_time = 0;
|
||||
|
||||
const index_t gdy = arg.num_group_;
|
||||
const index_t num_workgroups_per_Conv_N =
|
||||
arg.a_g_n_k_wos_lengths_[I1] / arg.conv_N_per_block_;
|
||||
const index_t gdz = num_workgroups_per_Conv_N;
|
||||
const index_t gdz = arg.num_workgroups_per_Conv_N_;
|
||||
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
EDataType* p_e_grid = arg.p_e_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
|
||||
@@ -722,10 +887,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
p_a_grid,
|
||||
arg.p_b_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
@@ -751,6 +916,114 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float ave_time = 0;
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
// Transpose from NGKHW to NHWGK
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
EDataType* p_e_in_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
hip_check_error(hipMemsetAsync(p_e_in_grid,
|
||||
0,
|
||||
arg.GetWorkspaceETensorSizeBytes(),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
arg.a_in_transpose_desc_) *
|
||||
arg.num_workgroups_per_Conv_N_;
|
||||
|
||||
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
|
||||
|
||||
auto kernel_transpose =
|
||||
kernel_batched_elementwise<GridwiseElementwiseInputTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
I1,
|
||||
I1>;
|
||||
|
||||
ave_time += launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.a_in_transpose_desc_),
|
||||
make_tuple(arg.a_out_transpose_desc_),
|
||||
make_tuple(arg.p_a_grid_),
|
||||
make_tuple(p_a_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_,
|
||||
element_wise::PassThrough{},
|
||||
arg.num_workgroups_per_Conv_N_,
|
||||
std::array<index_t, I1>{
|
||||
static_cast<index_t>(arg.compute_ptr_offset_of_workspace_n_.BatchStrideA_)},
|
||||
std::array<index_t, I1>{
|
||||
static_cast<index_t>(arg.compute_ptr_offset_of_n_.BatchStrideA_)});
|
||||
}
|
||||
ave_time += RunGemm(arg, stream_config);
|
||||
// Transpose from NHWGC to NGCHW
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
arg.e_in_transpose_desc_) *
|
||||
arg.num_workgroups_per_Conv_N_;
|
||||
|
||||
const EDataType* p_e_in_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
|
||||
EDataType* p_e_out_grid = arg.p_e_grid_;
|
||||
|
||||
auto kernel_transpose =
|
||||
kernel_batched_elementwise<GridwiseElementwiseOutputTranspose,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<const EDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
I1,
|
||||
I1>;
|
||||
|
||||
ave_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.e_in_transpose_desc_),
|
||||
make_tuple(arg.e_out_transpose_desc_),
|
||||
make_tuple(p_e_in_grid),
|
||||
make_tuple(p_e_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_,
|
||||
element_wise::PassThrough{},
|
||||
arg.num_workgroups_per_Conv_N_,
|
||||
std::array<index_t, I1>{
|
||||
static_cast<index_t>(arg.compute_ptr_offset_of_n_.BatchStrideE_)},
|
||||
std::array<index_t, I1>{static_cast<index_t>(
|
||||
arg.compute_ptr_offset_of_workspace_n_.BatchStrideE_)});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
@@ -765,6 +1038,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
|
||||
const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
|
||||
const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
|
||||
|
||||
@@ -787,7 +1061,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
|
||||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGKHW> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGKDHW>)
|
||||
{
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
@@ -848,7 +1124,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::GNDHWC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NHWGC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NDHWGC>)
|
||||
is_same_v<ELayout, tensor_layout::convolution::NDHWGC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NGCDHW>)
|
||||
{
|
||||
// vector store C matrix into global memory
|
||||
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
@@ -874,6 +1152,48 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
if((ConvG * ConvC) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((ConvG * ConvK) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t a_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t e_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(a_spatial_acum % TransposeTransferInScalarPerVectorAligned != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(e_spatial_acum % TransposeTransferOutScalarPerVectorAligned != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!arg.p_workspace_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout
|
||||
<< "Warning: Workspace for "
|
||||
"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument is not "
|
||||
"allocated, use SetWorkSpacePointer."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -998,11 +1318,48 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle
|
||||
<< ">";
|
||||
<< CShuffleNXdlPerWavePerShuffle;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>()) {
|
||||
str << ", TransposeTransferInScalarPerVectorAligned: "
|
||||
<< TransposeTransferInScalarPerVectorAligned <<", "
|
||||
<< "TransposeTransferOutScalarPerVectorAligned: " << TransposeTransferOutScalarPerVectorAligned;
|
||||
}
|
||||
|
||||
|
||||
str << ">";
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = dynamic_cast<const Argument*>(p_arg);
|
||||
if(arg)
|
||||
{
|
||||
return arg->GetWorkspaceSizeBytes();
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument structure!");
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* p_arg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(p_arg_)
|
||||
{
|
||||
p_arg_->p_workspace_ = p_workspace;
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument structure!");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -1621,6 +1621,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.b_out_transpose_desc_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
@@ -834,6 +834,25 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!arg.p_workspace_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Warning: Workspace for "
|
||||
"DeviceGroupedConvBwdWeight_Xdl_CShuffle::Argument is not "
|
||||
"allocated, use SetWorkSpacePointer."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.b_out_transpose_desc_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
|
||||
@@ -771,12 +771,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(ADataType) * a_acum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(EDataType) * e_accum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
@@ -1293,6 +1297,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!arg.p_workspace_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Warning: Workspace for "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument is not "
|
||||
"allocated, use SetWorkSpacePointer."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.e_in_transpose_desc_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(!valid)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -586,12 +586,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(ADataType) * a_acum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(EDataType) * e_accum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
@@ -1207,6 +1211,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!arg.p_workspace_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Warning: Workspace for "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3::Argument is not "
|
||||
"allocated, use SetWorkSpacePointer."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.e_in_transpose_desc_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// check vector access of E
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -420,7 +420,8 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
cde_element_op_{cde_element_op},
|
||||
gemm_kernel_host_args_{nullptr}
|
||||
{
|
||||
grid_size_ = 0;
|
||||
|
||||
@@ -538,6 +539,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
|
||||
|
||||
index_t grid_size_;
|
||||
void* gemm_kernel_host_args_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -545,7 +547,10 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
float Run(const Argument& arg,
|
||||
const StreamConfig& stream_config = StreamConfig{},
|
||||
hipStream_t cpy_stream = nullptr,
|
||||
hipEvent_t cpy_event = nullptr)
|
||||
{
|
||||
auto K0 = arg.gemm_desc_kernel_arg_[0].a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
bool all_has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
|
||||
@@ -602,12 +607,33 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
}
|
||||
}
|
||||
|
||||
hipGetErrorString(
|
||||
hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_desc_kernel_arg_.data(),
|
||||
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
if(cpy_stream && cpy_event)
|
||||
{
|
||||
if(arg.gemm_kernel_host_args_ == nullptr)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "No memory has been allocated for gemm kernel host args "
|
||||
<< "when providing the copy stream and copy event! In " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_kernel_host_args_,
|
||||
arg.group_count_ * sizeof(GemmKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
cpy_stream));
|
||||
hipGetErrorString(hipEventRecord(cpy_event, cpy_stream));
|
||||
hipGetErrorString(hipEventSynchronize(cpy_event));
|
||||
}
|
||||
else
|
||||
{
|
||||
hipGetErrorString(
|
||||
hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_desc_kernel_arg_.data(),
|
||||
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop,
|
||||
auto has_double_tail_k_block_loop) {
|
||||
@@ -762,6 +788,32 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
{
|
||||
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmKernelArg);
|
||||
}
|
||||
|
||||
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
return GetWorkSpaceSize(p_arg);
|
||||
}
|
||||
|
||||
size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
|
||||
|
||||
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
|
||||
{
|
||||
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
|
||||
}
|
||||
|
||||
void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(!pArg_)
|
||||
{
|
||||
throw std::runtime_error("Failed to cast argument pointer!");
|
||||
}
|
||||
|
||||
pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
|
||||
std::copy(pArg_->gemm_desc_kernel_arg_.begin(),
|
||||
pArg_->gemm_desc_kernel_arg_.end(),
|
||||
static_cast<GemmKernelArg*>(pArg_->gemm_kernel_host_args_));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -500,6 +500,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
|
||||
|
||||
index_t grid_size_;
|
||||
void* gemm_kernel_host_args_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -507,7 +508,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
float Run(const Argument& arg,
|
||||
const StreamConfig& stream_config = StreamConfig{},
|
||||
hipStream_t cpy_stream = nullptr,
|
||||
hipEvent_t cpy_event = nullptr)
|
||||
{
|
||||
bool has_main_k_block_loop = true;
|
||||
|
||||
@@ -556,12 +560,33 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
}
|
||||
}
|
||||
|
||||
hipGetErrorString(
|
||||
hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_desc_kernel_arg_.data(),
|
||||
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
if(cpy_stream && cpy_event)
|
||||
{
|
||||
if(arg.gemm_kernel_host_args_ == nullptr)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "No memory has been allocated for gemm kernel host args "
|
||||
<< "when providing the copy stream and copy event! In " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_kernel_host_args_,
|
||||
arg.group_count_ * sizeof(GemmBiasTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
cpy_stream));
|
||||
hipGetErrorString(hipEventRecord(cpy_event, cpy_stream));
|
||||
hipGetErrorString(hipEventSynchronize(cpy_event));
|
||||
}
|
||||
else
|
||||
{
|
||||
hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_desc_kernel_arg_.data(),
|
||||
arg.gemm_desc_kernel_arg_.size() *
|
||||
sizeof(GemmBiasTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
@@ -735,6 +760,22 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
{
|
||||
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
|
||||
}
|
||||
|
||||
size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
|
||||
|
||||
void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(!pArg_)
|
||||
{
|
||||
throw std::runtime_error("Failed to cast argument pointer!");
|
||||
}
|
||||
|
||||
pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
|
||||
std::copy(pArg_->gemm_desc_kernel_arg_.begin(),
|
||||
pArg_->gemm_desc_kernel_arg_.end(),
|
||||
static_cast<GemmBiasTransKernelArg*>(pArg_->gemm_kernel_host_args_));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -244,7 +244,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
index_t kbatch)
|
||||
: K_BATCH{kbatch}
|
||||
: K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr}
|
||||
{
|
||||
grid_size_ = 0;
|
||||
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
|
||||
@@ -365,13 +365,17 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
index_t skipped_group_count_;
|
||||
|
||||
std::vector<GemmTransKernelArg> gemm_kernel_args_;
|
||||
void* gemm_kernel_host_args_;
|
||||
index_t grid_size_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
float Run(const Argument& arg,
|
||||
const StreamConfig& stream_config = StreamConfig{},
|
||||
hipStream_t cpy_stream = nullptr,
|
||||
hipEvent_t cpy_event = nullptr)
|
||||
{
|
||||
index_t K0 = arg.gemm_kernel_args_[0].karg_.K0Padded;
|
||||
bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
|
||||
@@ -419,12 +423,34 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
}
|
||||
}
|
||||
|
||||
hip_check_error(
|
||||
hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_kernel_args_.data(),
|
||||
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
if(cpy_stream && cpy_event)
|
||||
{
|
||||
if(arg.gemm_kernel_host_args_ == nullptr)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "No memory has been allocated for gemm kernel host args "
|
||||
<< "when providing the copy stream and copy event! In " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
hip_check_error(hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_kernel_host_args_,
|
||||
arg.group_count_ * sizeof(GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
cpy_stream));
|
||||
hip_check_error(hipEventRecord(cpy_event, cpy_stream));
|
||||
hip_check_error(hipEventSynchronize(cpy_event));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
hip_check_error(
|
||||
hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_kernel_args_.data(),
|
||||
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
@@ -652,6 +678,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
return GetWorkSpaceSize(p_arg);
|
||||
}
|
||||
|
||||
size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
|
||||
|
||||
// TODO: deperecation notice.
|
||||
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
|
||||
|
||||
@@ -673,6 +701,20 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
{
|
||||
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
|
||||
}
|
||||
|
||||
void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(!pArg_)
|
||||
{
|
||||
throw std::runtime_error("Failed to cast argument pointer!");
|
||||
}
|
||||
|
||||
pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
|
||||
std::copy(pArg_->gemm_kernel_args_.begin(),
|
||||
pArg_->gemm_kernel_args_.end(),
|
||||
static_cast<GemmTransKernelArg*>(pArg_->gemm_kernel_host_args_));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#if __clang_major__ >= 20
|
||||
#if __clang_major__ == 20
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp"
|
||||
#else
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -126,12 +126,13 @@ __global__ void
|
||||
OutDataTypePointerTuple p_out_global_with_offset_tuple;
|
||||
|
||||
static_for<0, InDataTypePointerTuple::Size(), 1>{}([&](auto i) {
|
||||
p_in_global_with_offset_tuple(i) = p_in_global_tuple.At(i) + input_batch_strides[i] * g_idx;
|
||||
p_in_global_with_offset_tuple(i) =
|
||||
p_in_global_tuple.At(i) + type_convert<long_index_t>(input_batch_strides[i]) * g_idx;
|
||||
});
|
||||
|
||||
static_for<0, OutDataTypePointerTuple::Size(), 1>{}([&](auto i) {
|
||||
p_out_global_with_offset_tuple(i) =
|
||||
p_out_global_tuple.At(i) + output_batch_strides[i] * g_idx;
|
||||
p_out_global_tuple.At(i) + type_convert<long_index_t>(output_batch_strides[i]) * g_idx;
|
||||
});
|
||||
|
||||
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
|
||||
|
||||
@@ -431,6 +431,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
}
|
||||
// copy data from buf_vectors into dst_bufs
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
<<<<<<< HEAD
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
|
||||
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
|
||||
@@ -440,6 +441,15 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
|
||||
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_offset, is_dst_valid, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
|
||||
=======
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
|
||||
constexpr InMemoryDataOperationEnum DstInMemOp =
|
||||
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
|
||||
|
||||
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_offset, true, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
|
||||
>>>>>>> origin/develop
|
||||
});
|
||||
|
||||
// move coordinate
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -29,10 +29,11 @@ struct TransformConvNGCHWToNHWGC
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t& N = g_n_c_wis_lengths[I1];
|
||||
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
|
||||
const index_t& C = g_n_c_wis_lengths[I2];
|
||||
const index_t& Wi = g_n_c_wis_lengths[I3];
|
||||
|
||||
@@ -55,10 +56,11 @@ struct TransformConvNGCHWToNHWGC
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t& N = g_n_c_wis_lengths[I1];
|
||||
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
|
||||
const index_t& C = g_n_c_wis_lengths[I2];
|
||||
const index_t& Wi = g_n_c_wis_lengths[I3];
|
||||
|
||||
@@ -81,10 +83,11 @@ struct TransformConvNGCHWToNHWGC
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t& N = g_n_c_wis_lengths[I1];
|
||||
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
|
||||
const index_t& C = g_n_c_wis_lengths[I2];
|
||||
const index_t& Hi = g_n_c_wis_lengths[I3];
|
||||
const index_t& Wi = g_n_c_wis_lengths[I4];
|
||||
@@ -109,10 +112,11 @@ struct TransformConvNGCHWToNHWGC
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t& N = g_n_c_wis_lengths[I1];
|
||||
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
|
||||
const index_t& C = g_n_c_wis_lengths[I2];
|
||||
const index_t& Hi = g_n_c_wis_lengths[I3];
|
||||
const index_t& Wi = g_n_c_wis_lengths[I4];
|
||||
@@ -137,10 +141,11 @@ struct TransformConvNGCHWToNHWGC
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeNGCHWTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t& N = g_n_c_wis_lengths[I1];
|
||||
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
|
||||
const index_t& C = g_n_c_wis_lengths[I2];
|
||||
const index_t& Di = g_n_c_wis_lengths[I3];
|
||||
const index_t& Hi = g_n_c_wis_lengths[I4];
|
||||
@@ -168,10 +173,11 @@ struct TransformConvNGCHWToNHWGC
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeNHWGCTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
|
||||
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides,
|
||||
const index_t split_n_size = 1)
|
||||
{
|
||||
const index_t& G = g_n_c_wis_lengths[I0];
|
||||
const index_t& N = g_n_c_wis_lengths[I1];
|
||||
const index_t N = g_n_c_wis_lengths[I1] / split_n_size;
|
||||
const index_t& C = g_n_c_wis_lengths[I2];
|
||||
const index_t& Di = g_n_c_wis_lengths[I3];
|
||||
const index_t& Hi = g_n_c_wis_lengths[I4];
|
||||
|
||||
@@ -33,7 +33,7 @@
|
||||
#include "ck/utility/thread_group.hpp"
|
||||
#include "ck/utility/debug.hpp"
|
||||
|
||||
#if __clang_major__ >= 20
|
||||
#if __clang_major__ == 20
|
||||
#include "amd_buffer_addressing_builtins.hpp"
|
||||
#else
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "enable_if.hpp"
|
||||
#include "c_style_pointer_cast.hpp"
|
||||
#if __clang_major__ >= 20
|
||||
#if __clang_major__ == 20
|
||||
#include "amd_buffer_addressing_builtins.hpp"
|
||||
#else
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
|
||||
@@ -8,11 +8,8 @@
|
||||
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
|
||||
#if __clang_major__ >= 20
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#else
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#endif
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
@@ -2553,3 +2555,5 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif // !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
@@ -2553,3 +2555,5 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif // CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
@@ -252,3 +252,11 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING)
|
||||
#else // for GPU code
|
||||
#define CK_TILE_USE_OCP_FP8 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
#if __clang_major__ == 20
|
||||
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 1
|
||||
#else
|
||||
#define CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#if __clang_major__ >= 20
|
||||
#if __clang_major__ == 20
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#else
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
|
||||
@@ -33,12 +33,12 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
|
||||
@@ -112,6 +112,13 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 192)
|
||||
{
|
||||
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim <= 256)
|
||||
{
|
||||
return 1;
|
||||
|
||||
@@ -13,6 +13,8 @@ static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index
|
||||
return 128;
|
||||
if(len == 160)
|
||||
return 256;
|
||||
if(len == 192)
|
||||
return 192;
|
||||
|
||||
// only length of 96, 160 and power-of-two is supported
|
||||
if(!(len & (len - 1)))
|
||||
|
||||
@@ -0,0 +1,144 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using BF8 = ck::bf8_t;
|
||||
using F8 = ck::f8_t;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// f16_f16_f32_f16
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_transpose_xdl_f16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| LoopSched| AComputeType| BComputeType| MaxTranspose| MaxTranspose|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| | | | TransferIn| TransferOut|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | | ScalarPer| ScalarPer|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Vector| Vector|
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, make_default_loop_scheduler(), F16, F16, 2, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), F16, F16, 2, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), F16, F16, 2, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), F16, F16, 2, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, make_default_loop_scheduler(), F16, F16, 4, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), F16, F16, 4, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), F16, F16, 4, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), F16, F16, 4, 4>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, make_default_loop_scheduler(), F16, F16, 1, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), F16, F16, 1, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), F16, F16, 1, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), F16, F16, 1, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, make_default_loop_scheduler(), F16, F16, 2, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), F16, F16, 2, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), F16, F16, 2, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), F16, F16, 2, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// bf16_bf16_f32_bf16
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_transpose_xdl_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| LoopSched| AComputeType| BComputeType| MaxTranspose| MaxTranspose|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| | | | TransferIn| TransferOut|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | | ScalarPer| ScalarPer|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Vector| Vector|
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, make_default_loop_scheduler(), BF16, BF16, 4, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 4, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 4, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 4, 4>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, make_default_loop_scheduler(), BF16, BF16, 1, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 1, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 1, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 1, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, make_default_loop_scheduler(), BF16, BF16, 2, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// f32_f32_f32_f32
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_transpose_xdl_f32_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| LoopSched| AComputeType| BComputeType| MaxTranspose| MaxTranspose|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| | | | TransferIn| TransferOut|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | | ScalarPer| ScalarPer|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Vector| Vector|
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), F32, F32, 2, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), F32, F32, 2, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), F32, F32, 2, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), F32, F32, 2, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), F32, F32, 4, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), F32, F32, 4, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), F32, F32, 4, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), F32, F32, 4, 4>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), F32, F32, 1, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), F32, F32, 1, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), F32, F32, 1, 2>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), F32, F32, 1, 2>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, make_default_loop_scheduler(), F32, F32, 2, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4, make_default_loop_scheduler(), F32, F32, 2, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), F32, F32, 2, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4, make_default_loop_scheduler(), F32, F32, 2, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -127,6 +127,35 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<InLayout, NGCHW> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, NGKHW>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -201,6 +230,37 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<InLayout, NGCDHW> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NGKDHW>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,6 +101,52 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// conv3d backward data
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
@@ -209,6 +255,51 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_
|
||||
BF8,
|
||||
F8>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
|
||||
@@ -7,6 +7,9 @@ add_instance_library(
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
|
||||
|
||||
wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_instances<2,
|
||||
NGKHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
ConvBwdDataDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_transpose_xdl_bf16_instances<2,
|
||||
NGKHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_instances<2,
|
||||
NGKHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
ConvBwdDataDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_transpose_xdl_f16_instances<2,
|
||||
NGKHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, hi, wi, g, c] * wei[g, k, y, x, c] = in[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_instances<2,
|
||||
NGKHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
ConvBwdDataDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_transpose_xdl_f32_instances<2,
|
||||
NGKHW,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -6,6 +6,9 @@ set(GROUPED_CONV3D_BWD_DATA
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
|
||||
NGKDHW,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
ConvBwdDataDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_transpose_xdl_bf16_instances<3,
|
||||
NGKDHW,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_instances<3,
|
||||
NGKDHW,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
ConvBwdDataDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_transpose_xdl_f16_instances<3,
|
||||
NGKDHW,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_instances<3,
|
||||
NGKDHW,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
ConvBwdDataDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_transpose_xdl_f32_instances<3,
|
||||
NGKDHW,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -125,6 +125,11 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
bool pass = true;
|
||||
|
||||
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
|
||||
// workspace_sz will be equal to 0 for other layout than NGCHW
|
||||
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
// re-init output to zero before profiling next kernel
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -15,6 +15,7 @@ enum struct ConvLayout
|
||||
{
|
||||
GNHWC_GKYXC_GNHWK, // 0
|
||||
NHWGC_GKYXC_NHWGK, // 1
|
||||
NGCHW_GKYXC_NGKHW, // 2
|
||||
};
|
||||
|
||||
enum struct ConvDataType
|
||||
@@ -37,6 +38,7 @@ static void print_helper_msg()
|
||||
<< " 2: Output bf16, Weight bf16, Input bf16\n"
|
||||
<< "arg3: tensor layout (0: Output[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Input[G, N, Ho, Wo, K]\n"
|
||||
<< " 1: Output[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Input[N, Ho, Wo, G, K])\n"
|
||||
<< " 2: Output[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Input[N, G, K, Ho, Wo])\n"
|
||||
<< "arg4: verification (0: no, 1: yes)\n"
|
||||
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
@@ -143,6 +145,21 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3)
|
||||
{
|
||||
@@ -176,6 +193,21 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
@@ -28,7 +28,8 @@ def parse_layouts(args):
|
||||
args.in_layout == "NCDHW":
|
||||
if args.ck_profier_op == "grouped_conv_bwd_weight":
|
||||
args.layout = 3
|
||||
elif args.ck_profier_op == "grouped_conv_fwd":
|
||||
elif args.ck_profier_op == "grouped_conv_bwd_data" or \
|
||||
args.ck_profier_op == "grouped_conv_fwd":
|
||||
args.layout = 2
|
||||
else:
|
||||
print('Not supported layout for this op')
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
# Currently ck_tile is only built on gfx94/gfx95
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp)
|
||||
else()
|
||||
message("Skipping ck_tile_gemm tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
@@ -51,6 +51,9 @@ using namespace ck::tensor_layout::convolution;
|
||||
using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>,
|
||||
std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
|
||||
std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>,
|
||||
std::tuple<float, NGKHW, GKYXC, NGCHW>,
|
||||
std::tuple<ck::half_t, NGKHW, GKYXC, NGCHW>,
|
||||
std::tuple<ck::bhalf_t, NGKHW, GKYXC, NGCHW>,
|
||||
std::tuple<float, NHWGK, GKYXC, NHWGC>,
|
||||
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
|
||||
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>>;
|
||||
@@ -58,6 +61,9 @@ using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>,
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWK, GKZYXC, GNDHWC>,
|
||||
std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
|
||||
std::tuple<ck::bhalf_t, GNDHWK, GKZYXC, GNDHWC>,
|
||||
std::tuple<float, NGKDHW, GKZYXC, NGCDHW>,
|
||||
std::tuple<ck::half_t, NGKDHW, GKZYXC, NGCDHW>,
|
||||
std::tuple<ck::bhalf_t, NGKDHW, GKZYXC, NGCDHW>,
|
||||
std::tuple<float, NDHWGK, GKZYXC, NDHWGC>,
|
||||
std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
|
||||
std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>>;
|
||||
|
||||
Reference in New Issue
Block a user