Merge branch 'vpietila/retina-net-fwd-convs' into vpietila/retina-net-training-perf

This commit is contained in:
Ville Pietilä
2026-02-09 06:54:16 -05:00
44 changed files with 2777 additions and 458 deletions

3
.gitignore vendored
View File

@@ -112,3 +112,6 @@ experimental/grouped_convolution_tile_instances/instances/*
!experimental/grouped_convolution_tile_instances/instances/*.in
!experimental/grouped_convolution_tile_instances/instances/*.inc
experimental/grouped_convolution_tile_instances/*.inc
benchmarking/workloads/
benchmarking/.rocprofv3/

View File

@@ -0,0 +1,5 @@
1 1 1 0 1 0 1 2 32 32 4 4 3 3 200 200 1 1 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Default, 32, 32, 2, 2, 4, 4, 1, 1, 1, 8>; 4.84353
1 1 1 0 1 0 1 2 32 32 8 8 3 3 200 200 2 2 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Default, 32, 32, 4, 2, 8, 8, 1, 1, 1, 8>; 19.7032
1 1 1 0 1 0 1 2 32 32 8 8 3 3 100 100 1 2 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Default, 32, 32, 4, 2, 8, 8, 1, 1, 1, 8>; 20.432
1 1 1 0 1 0 1 2 1 32 2376 256 3 3 100 100 1 1 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1>; 866.878
1 1 1 0 1 0 1 2 1 32 256 256 3 3 100 100 1 1 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1>; 978.148
1 1 1 1 0 1 0 1 2 32 32 4 4 3 3 200 200 1 1 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256 128 64 32 Default 32 32 2 2 4 4 1 1 1 8>; 4.84353
2 1 1 1 0 1 0 1 2 32 32 8 8 3 3 200 200 2 2 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256 128 64 32 Default 32 32 4 2 8 8 1 1 1 8>; 19.7032
3 1 1 1 0 1 0 1 2 32 32 8 8 3 3 100 100 1 2 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256 128 64 32 Default 32 32 4 2 8 8 1 1 1 8>; 20.432
4 1 1 1 0 1 0 1 2 1 32 2376 256 3 3 100 100 1 1 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256 256 128 32 Default 32 32 4 2 8 8 8 1 1 1>; 866.878
5 1 1 1 0 1 0 1 2 1 32 256 256 3 3 100 100 1 1 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256 256 128 32 Default 32 32 4 2 8 8 8 1 1 1>; 978.148

View File

@@ -0,0 +1,5 @@
1 1 1 0 1 0 1 2 32 32 4 4 3 3 200 200 1 1 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1>; 2.85567
1 1 1 0 1 0 1 2 32 32 8 8 3 3 200 200 2 2 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>; 10.3518
1 1 1 0 1 0 1 2 32 32 8 8 3 3 100 100 1 2 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1>; 15.4196
1 1 1 0 1 0 1 2 1 32 2376 256 3 3 100 100 1 1 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1>; 849.773
1 1 1 0 1 0 1 2 1 32 256 256 3 3 100 100 1 1 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, OddC, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1>; 990.991
1 1 1 1 0 1 0 1 2 32 32 4 4 3 3 200 200 1 1 1 1 1 1 1 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1> 2.85567
2 1 1 1 0 1 0 1 2 32 32 8 8 3 3 200 200 2 2 1 1 1 1 1 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> 10.3518
3 1 1 1 0 1 0 1 2 32 32 8 8 3 3 100 100 1 2 1 1 1 1 1 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1> 15.4196
4 1 1 1 0 1 0 1 2 1 32 2376 256 3 3 100 100 1 1 1 1 1 1 1 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> 849.773
5 1 1 1 0 1 0 1 2 1 32 256 256 3 3 100 100 1 1 1 1 1 1 1 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, OddC, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1> 990.991

View File

@@ -0,0 +1,5 @@
1 1 1 0 1 0 1 2 32 32 4 4 3 3 200 200 1 1 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 32, 32, Default, 32, 32, 2, 1, 4, 4, 1, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1, 8>; 8.09047
1 1 1 0 1 0 1 2 32 32 8 8 3 3 200 200 2 2 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Default, 32, 32, 4, 2, 8, 8, 1, 1, 1, 8>; 19.8867
1 1 1 0 1 0 1 2 32 32 8 8 3 3 100 100 1 2 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Default, 32, 32, 4, 2, 8, 8, 1, 1, 1, 8>; 20.4604
1 1 1 0 1 0 1 2 1 32 2376 256 3 3 100 100 1 1 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 64, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, 1>; 952.806
1 1 1 0 1 0 1 2 1 32 256 256 3 3 100 100 1 1 1 1 1 1 1 1; DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 64, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, 1>; 1091.89
1 1 1 1 0 1 0 1 2 32 32 4 4 3 3 200 200 1 1 1 1 1 1 1 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 32, 32, Default, 32, 32, 2, 1, 4, 4, 1, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1, 8> 8.09047
2 1 1 1 0 1 0 1 2 32 32 8 8 3 3 200 200 2 2 1 1 1 1 1 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Default, 32, 32, 4, 2, 8, 8, 1, 1, 1, 8> 19.8867
3 1 1 1 0 1 0 1 2 32 32 8 8 3 3 100 100 1 2 1 1 1 1 1 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Default, 32, 32, 4, 2, 8, 8, 1, 1, 1, 8> 20.4604
4 1 1 1 0 1 0 1 2 1 32 2376 256 3 3 100 100 1 1 1 1 1 1 1 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 64, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, 1> 952.806
5 1 1 1 0 1 0 1 2 1 32 256 256 3 3 100 100 1 1 1 1 1 1 1 1 DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 64, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, 1> 1091.89

View File

@@ -0,0 +1,118 @@
#!/usr/bin/env python3
import subprocess
import sys
from bs4 import BeautifulSoup
import re
import json
def get_mangled_kernel_names_from_html(html_file):
    """Scan a profiler HTML report for the <script> tag embedding the
    kernel data array and return the list of mangled kernel names in it.

    Returns an empty list when no script tag contains parseable data.
    """
    with open(html_file, 'r') as f:
        page = f.read()

    # Walk every <script> tag; the kernel table lives in one of them.
    soup = BeautifulSoup(page, 'html.parser')
    for idx, tag in enumerate(soup.find_all('script')):
        body = tag.string
        if not body:
            continue
        # Heuristic: the interesting script mentions both namespace parts
        # of the CK kernel signatures.
        if 'tensor_operation' not in body or 'device' not in body:
            continue
        print(f"Found kernel data in script tag {idx}")
        # Pull the literal `var data = [...];` array out of this script.
        found = re.search(r'var\s+data\s*=\s*(\[.*?\]);', body, re.DOTALL)
        if not found:
            continue
        try:
            records = json.loads(found.group(1))
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON: {e}")
            continue  # keep looking in the remaining script tags
        names = [record['name'] for record in records]
        print(f"Found {len(names)} kernel names:")
        for name in names:
            print(f" {name}")
        return names
    print("Could not find kernel data in any script tag")
    return []
def de_mangle_name(mangled):
    """Demangle a C++ symbol via ROCm's llvm-cxxfilt.

    Returns the demangled name, or None when the tool is missing, exits
    non-zero, or produces no change (i.e. the input was not mangled).
    """
    demangler = '/opt/rocm/llvm/bin/llvm-cxxfilt'
    try:
        proc = subprocess.run([demangler],
                              input=mangled,
                              text=True,
                              capture_output=True,
                              check=True)
    except (FileNotFoundError, subprocess.CalledProcessError) as e:
        print(f"Error using {demangler}: {e}")
        return None
    output = proc.stdout.strip()
    # An unchanged result means the input was not actually mangled.
    return output if output and output != mangled else None
def extract_instance_name(demangled, prefix, prefix_instance):
    """Extract a template-instance name from a demangled kernel symbol.

    The symbol is expected to start with ``prefix + prefix_instance``,
    where ``prefix_instance`` ends with the opening '<' of the instance's
    template argument list.  The function scans forward, balancing angle
    brackets, and returns ``prefix_instance`` plus the argument list up to
    and including the matching '>'.

    Returns None when ``demangled`` is None/empty (e.g. a failed
    de_mangle result), does not start with the expected prefixes, or the
    brackets never balance.
    """
    # Guard: callers may pass the result of a failed demangle (None).
    if not demangled or not demangled.startswith(prefix + prefix_instance):
        return None
    # Start scanning just after the '<' contributed by prefix_instance.
    start = len(prefix + prefix_instance)
    depth = 1  # that opening '<' is already counted
    i = start
    while i < len(demangled) and depth > 0:
        if demangled[i] == '<':
            depth += 1
        elif demangled[i] == '>':
            depth -= 1
        i += 1
    if depth == 0:
        # demangled[i-1] is the matching '>'; include it in the result.
        return prefix_instance + demangled[start:i]
    return None
def extract_GridwiseGemmMultiD_xdl_cshuffle_v3(demangled):
    """Pull the GridwiseGemmMultiD_xdl_cshuffle_v3 instance name out of a
    demangled kernel_grouped_conv_fwd_xdl_cshuffle_v3 symbol (or None)."""
    return extract_instance_name(
        demangled,
        "void ck::tensor_operation::device::(anonymous namespace)::kernel_grouped_conv_fwd_xdl_cshuffle_v3<",
        "ck::GridwiseGemmMultiD_xdl_cshuffle_v3<")
def extract_GridwiseGemmMultiD_xdl_cshuffle(demangled):
    """Pull the GridwiseGemmMultiD_xdl_cshuffle instance name out of a
    demangled kernel_grouped_conv_fwd_xdl_cshuffle symbol (or None)."""
    return extract_instance_name(
        demangled,
        "void ck::tensor_operation::device::(anonymous namespace)::kernel_grouped_conv_fwd_xdl_cshuffle<",
        "ck::GridwiseGemmMultiD_xdl_cshuffle<")
if __name__ == "__main__":
    # get_mangled_kernel_names_from_html(sys.argv[1])
    # Mangled name is the first argument
    if len(sys.argv) > 1:
        mangled = sys.argv[1]
        demangled = de_mangle_name(mangled)
        if demangled is None:
            # Demangling failed (tool missing, non-zero exit, or the input
            # was not a mangled name); bail out instead of crashing below.
            print("Could not demangle the provided name.")
            sys.exit(1)
        print()
        print("Demangled name:")
        print(demangled)
        v3_instance_name = extract_GridwiseGemmMultiD_xdl_cshuffle_v3(demangled)
        if v3_instance_name:
            print()
            print("Extracted GridwiseGemmMultiD_xdl_cshuffle_v3 instance name:")
            print(v3_instance_name)
        v1_instance_name = extract_GridwiseGemmMultiD_xdl_cshuffle(demangled)
        if v1_instance_name:
            print()
            print("Extracted GridwiseGemmMultiD_xdl_cshuffle instance name:")
            print(v1_instance_name)
    else:
        print("Please provide a mangled name as an argument.")

View File

@@ -0,0 +1,40 @@
#!/usr/bin/env python3
import numpy as np
import matplotlib
matplotlib.use('Agg') # Set non-interactive backend before importing pyplot
import matplotlib.pyplot as plt
# Measured TFLOPS per test case (from the accompanying result files).
tflops_baseline = [2.85567, 10.3518, 15.4196, 849.773, 990.991]
tflops_improved = [8.09047, 19.8867, 20.4604, 952.806, 1091.89]

case_labels = ['Case 1', 'Case 2', 'Case 3', 'Case 4', 'Case 5']
indices = np.arange(len(tflops_baseline))
bar_width = 0.35

# --- Grouped bar chart: absolute TFLOPS, baseline vs improved ---
plt.figure(figsize=(10, 6))
plt.bar(indices, tflops_baseline, bar_width, label='Baseline', color='b')
plt.bar(indices + bar_width, tflops_improved, bar_width, label='Improved', color='g')
plt.xlabel('Test Cases')
plt.ylabel('TFLOPS')
plt.title('TFLOPS Comparison: Baseline vs Improved')
plt.xticks(indices + bar_width / 2, case_labels)
plt.yscale('log')  # cases span roughly 3 to 1000 TFLOPS
plt.legend()
plt.grid(axis='y')
plt.tight_layout()
plt.savefig('tflops_comparison.png')
plt.close()

# --- Relative improvement in percent: 100 * (improved / baseline) - 100 ---
improvement_factor = [100 * (imp / base) - 100
                      for imp, base in zip(tflops_improved, tflops_baseline)]

plt.figure(figsize=(10, 6))
plt.bar(indices, improvement_factor, color='orange')
plt.xlabel('Test Cases')
plt.ylabel('Improvement (%)')
plt.title('Improvement')
plt.xticks(indices, case_labels)
plt.grid(axis='y')
plt.tight_layout()
plt.savefig('improvement.png')
plt.close()

View File

@@ -0,0 +1,70 @@
#!/usr/bin/env python3
import os
import subprocess
import sys
import argparse
# Run under rocprof-compute
# HIP_VISIBLE_DEVICES=7 rocprof-compute profile -n grouped_conv_fwd --roofline-data-type FP16 -- ./run-best-instances.py --profiler-path ../build-improved-convs/bin/ckProfiler
# Set no verify, and timing of the kernel. Multiple calls of the same kernel inside the script will confuse the profiler.
# ckProfiler `grouped_conv_fwd` argument strings, one per conv shape.
# NOTE: the three lists below are parallel -- entry i of
# `profiler_commands` is run with entry i of the chosen instance list,
# so commented-out rows must be kept in sync across all three.
profiler_commands = [
    "1 1 1 0 1 0 0 2 32 32 4 4 3 3 200 200 1 1 1 1 1 1 1 1",
    "1 1 1 0 1 0 0 2 32 32 8 8 3 3 200 200 2 2 1 1 1 1 1 1",
    #"1 1 1 0 1 0 0 2 32 32 8 8 3 3 100 100 1 2 1 1 1 1 1 1",
    "1 1 1 0 1 0 0 2 1 32 2376 256 3 3 100 100 1 1 1 1 1 1 1 1",
    #"1 1 1 0 1 0 0 2 1 32 256 256 3 3 100 100 1 1 1 1 1 1 1 1"
]
# Best-performing instance names from the baseline build, one per shape
# above (same ordering as `profiler_commands`).
baseline_instances = [
    "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 64, 64, 32, Default, 16, 16, 2, 2, 4, 4, 4, 1, 1, 1>",
    "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<128, 16, 64, 64, Default, 16, 16, 1, 2, 8, 8, 4, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1, 1>",
    #"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_DirectLoad<256, 256, 32, 64, Default, 32, 32, 2, 1, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v1, 1>",
    "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, Default, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1>",
    #"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 256, 128, 32, OddC, 32, 32, 4, 2, 8, 8, 8, 1, 1, 1>"
]
# Best-performing instance names from the improved build, one per shape
# above (same ordering as `profiler_commands`).
improved_instances = [
    "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 32, 32, Default, 32, 32, 2, 1, 4, 4, 1, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1, 8>",
    "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Default, 32, 32, 4, 2, 8, 8, 1, 1, 1, 8>",
    #"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 64, 32, Default, 32, 32, 4, 2, 8, 8, 1, 1, 1, 8>",
    "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 64, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, 1>",
    #"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 64, Default, 32, 32, 4, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, 1>"
]
def main():
    """Run ckProfiler's grouped_conv_fwd for every shape in
    `profiler_commands`, pairing each with its best instance (baseline
    or improved, selected via --baseline).

    Exits with status 1 when the profiler binary does not exist.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Run CK profiler with best instances for given conv shapes.')
    parser.add_argument('--profiler-path', type=str, required=True, help='Path to the profiler binary')
    parser.add_argument('--baseline', action='store_true',
                        help='Run baseline instances (default: run improved instances)')
    parser.add_argument("--print-stdout", action='store_true', help='Print CK profiler output to stdout')
    args = parser.parse_args()

    instances_to_run = baseline_instances if args.baseline else improved_instances
    instance_type = "baseline" if args.baseline else "improved"
    print(f"Running {instance_type} instances...\n")

    ck_profiler_path = args.profiler_path
    if not os.path.isfile(ck_profiler_path):
        print(f"Error: Profiler binary not found at {ck_profiler_path}")
        sys.exit(1)

    # The command and instance lists are parallel; fail loudly if a
    # commented-out entry left them different lengths instead of raising
    # IndexError (or silently dropping commands) mid-run.
    assert len(profiler_commands) == len(instances_to_run), \
        "profiler_commands and the instance list must have the same length"

    for i, (command, instance) in enumerate(zip(profiler_commands, instances_to_run), start=1):
        # Split the shape string into argv entries and append the instance
        # name so the profiler times only that kernel instance.
        profiler_args = command.split() + [instance]
        print(f"Running profiler for {instance_type} instance {i}/{len(profiler_commands)}:")
        print(instance)
        res = subprocess.run([ck_profiler_path, "grouped_conv_fwd"] + profiler_args,
                             check=True, timeout=300,
                             capture_output=True, text=True)
        if args.print_stdout:
            print(res.stdout)
        print()


if __name__ == "__main__":
    main()

View File

@@ -8,6 +8,9 @@ add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_con
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
add_example_executable(example_grouped_conv_fwd_xdl_bf16 grouped_conv_fwd_xdl_bf16.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_bf16)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)

View File

@@ -0,0 +1,24 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "common.hpp"
// kernel data types
using InKernelDataType = BF16;
using WeiKernelDataType = BF16;
using AccDataType = FP32;
using CShuffleDataType = BF16;
using OutKernelDataType = BF16;
// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = PassThrough;
#include "run_grouped_conv_fwd_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_example(argc, argv); }

View File

@@ -179,7 +179,7 @@ ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimS
/*
template <ck::index_t NDimSpatial>
using DeviceConvFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
NDimSpatial,
InputLayout<NDimSpatial>,
WeightLayout<NDimSpatial>,
@@ -196,7 +196,6 @@ using DeviceConvFwdInstance =
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock

View File

@@ -30,7 +30,9 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false>
bool TransposeC = false,
bool ALdsScalarLoadToVgpr = false,
bool BLdsScalarLoadToVgpr = false>
struct BlockwiseGemmXdlops_pipeline_base
{
static constexpr auto I0 = Number<0>{};
@@ -385,7 +387,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
ALdsScalarLoadToVgpr ? 1 : A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
@@ -395,7 +397,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
BLdsScalarLoadToVgpr ? 1 : B_K1,
B_K1>;
AThreadCopy a_thread_copy_;

View File

@@ -32,9 +32,16 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool DirectLoad = false>
bool DirectLoad = false,
bool ALdsScalarLoadToVgpr = false,
bool BLdsScalarLoadToVgpr = false>
constexpr auto BlockGemmPipeline_Selector()
{
// Supported for Direct Load and V1
if constexpr(ALdsScalarLoadToVgpr || BLdsScalarLoadToVgpr)
{
static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1);
}
if constexpr(DirectLoad)
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
@@ -58,7 +65,9 @@ constexpr auto BlockGemmPipeline_Selector()
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
KPack,
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{

View File

@@ -758,7 +758,9 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPacks>
index_t KPacks,
bool ALdsScalarLoadToVgpr = false,
bool BLdsScalarLoadToVgpr = false>
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1
{
};
@@ -781,9 +783,10 @@ template <index_t BlockSize,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack
index_t KPack,
// ,bool TransposeC //disable transposec right now...
>
bool ALdsScalarLoadToVgpr,
bool BLdsScalarLoadToVgpr>
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
@@ -803,7 +806,9 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
@@ -822,7 +827,10 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
false /*TransposeC*/,
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>
{
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -843,7 +851,10 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NPerXDL,
MRepeat,
NRepeat,
KPack>;
KPack,
false /*TransposeC*/,
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>;
using Base::I0;
using Base::KRepeat;
using Base::xdlops_gemm;

View File

@@ -140,10 +140,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
"Direct load transfer does not support datatypes conversion. Source and "
"destination data types must be the same.");
static_assert(
DstVectorDim == nDim - 1,
"Direct load transfer requires the destination vector dimension to be the last one.");
static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
"When loading more than one element per thread at once, the contiguous "
"dimension must be the same between source and destination.");

View File

@@ -11,8 +11,6 @@ namespace ck {
namespace tensor_operation {
namespace device {
#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,

View File

@@ -11,8 +11,6 @@ namespace ck {
namespace tensor_operation {
namespace device {
#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,

View File

@@ -162,7 +162,6 @@ struct DeviceGroupedConvBwdWeight_Explicit
}
else
{
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy();
@@ -171,9 +170,11 @@ struct DeviceGroupedConvBwdWeight_Explicit
DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize);
const index_t grid_size = gdx * gdy * gdz;
k_batch_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
}
else
#endif
{
k_batch_ = split_k;
}
@@ -338,16 +339,6 @@ struct DeviceGroupedConvBwdWeight_Explicit
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if constexpr(!IsTwoStageNeeded)
{
if(arg.k_batch_ < 0)
{
return false;
}
}
#endif
if constexpr(NDimSpatial == 2)
{
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())

View File

@@ -22,6 +22,7 @@
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -524,6 +525,44 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}, 1, 1));
struct ActiveWorkgroupsPerCU
{
ActiveWorkgroupsPerCU()
{
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
{
return;
}
constexpr int dynamic_smem_size = 0;
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
int max_occupancy = 0;
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
// TODO: implement
}
else
{
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&max_occupancy,
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>,
BlockSize,
dynamic_smem_size));
}
max_occupancy_ = std::max(1, max_occupancy);
}
int max_occupancy_;
};
struct Argument : public BaseArgument, public ArgumentSplitK
{
Argument(
@@ -574,6 +613,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
static ActiveWorkgroupsPerCU active_workgroups_per_cu;
constexpr index_t spatial_offset = 3;
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
end(b_g_n_c_wis_lengths),
@@ -585,7 +626,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
@@ -602,6 +642,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
@@ -611,7 +654,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
}
}
else
#endif
{
k_batch_ = split_k;
}
@@ -988,13 +1030,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *

View File

@@ -677,7 +677,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN;
@@ -688,9 +687,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
}
else
#endif
{
k_batch_ = split_k;
}
@@ -947,12 +948,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
{
return false;

View File

@@ -511,7 +511,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset,
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
@@ -528,6 +528,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
@@ -537,7 +540,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
}
}
else
#endif
{
k_batch_ = split_k;
}
@@ -1040,12 +1042,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *

View File

@@ -651,7 +651,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
e_g_k_c_xs_strides);
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN;
@@ -662,9 +661,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) * Conv_G_;
k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_,
grid_size);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
}
else
#endif
{
k_batch_ = split_k;
}
@@ -1083,12 +1084,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
{
return false;

View File

@@ -236,7 +236,9 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
typename ComputeTypeB = ComputeTypeA,
bool DirectLoad = false,
index_t NumGroupsToMerge = 1>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
: public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
@@ -287,7 +289,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
NPerBlock,
K1Number,
K0PerBlock / K1Number,
1 /*NumGroupsToMerge*/,
NumGroupsToMerge,
ConvBackwardWeightSpecialization>{};
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
@@ -371,10 +373,30 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
static constexpr index_t ABlockTransferSrcScalarPerVectorAligned =
ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8
? 4 / sizeof(ADataType)
: ABlockTransferSrcScalarPerVector;
static constexpr index_t BBlockTransferSrcScalarPerVectorAligned =
BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8
? 4 / sizeof(BDataType)
: BBlockTransferSrcScalarPerVector;
static constexpr bool ALdsScalarLoadToVgpr = (DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false);
static constexpr bool BLdsScalarLoadToVgpr = (DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false);
// Note: Direct load use layout to create proper block and mmtile descriptor
// TODO: Fix and verify RC layout for not direct load (currently it returns wrong results)
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3<
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
std::conditional_t<
DirectLoad,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor>,
std::conditional_t<
DirectLoad,
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor>,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
@@ -399,7 +421,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
DirectLoad ? ABlockTransferSrcScalarPerVectorAligned : ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
@@ -407,7 +429,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
DirectLoad ? BBlockTransferSrcScalarPerVectorAligned : BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
@@ -418,7 +440,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
ComputeTypeB,
DirectLoad,
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
@@ -556,7 +581,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(split_k < 0)
{
ck::index_t gemmM, gemmN, gemmK;
@@ -573,6 +597,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
// Cap k_batch_ to 128 to avoid accuracy issues
k_batch_ = std::min(k_batch_, 128);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
@@ -582,7 +609,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
}
else
#endif
{
k_batch_ = split_k;
}
@@ -653,15 +679,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
if(split_k_offset_hack_)
split_k_stride_b_ /= k_batch_;
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
// A/B/C Batch Stride (multiply by NumGroupsToMerge for group merging)
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
std::multiplies<>{}) *
NumGroupsToMerge;
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
@@ -743,7 +770,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_);
gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge);
float ave_time = 0;
@@ -1360,12 +1387,30 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
static bool IsSupportedArgument(const Argument& arg)
{
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
// check device
if constexpr(DirectLoad)
{
return false;
if(get_device_name() != "gfx950")
{
return false;
}
}
// Check that NumGroupsToMerge divides Conv_G evenly
if constexpr(NumGroupsToMerge > 1)
{
if(arg.Conv_G_ % NumGroupsToMerge != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Unsupported! Conv_G_ % NumGroupsToMerge != 0: Conv_G_="
<< arg.Conv_G_ << ", NumGroupsToMerge=" << NumGroupsToMerge
<< std::endl;
}
return false;
}
}
#endif
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
@@ -1617,8 +1662,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"
<< "<"
str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
if constexpr(DirectLoad) {
str << "_DirectLoad";
}
str << "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
@@ -1633,7 +1683,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
<< BBlockTransferDstScalarPerVector_K1 << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< CBlockTransferScalarPerVector_NWaveNPerXdl
<< CBlockTransferScalarPerVector_NWaveNPerXdl << ", "
<< NumGroupsToMerge
<< ">";
// clang-format on

View File

@@ -787,7 +787,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op_{cde_element_op}
{
// A/B/E Batch/N Stride
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideA_ =
a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;
@@ -799,7 +800,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
// D batch stride
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
@@ -819,7 +821,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
});
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideE_ =
e_g_n_k_wos_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
@@ -1611,10 +1614,21 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
}
}
else if constexpr (ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter3x3Stride1Pad1Dilation1)
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3)
{
if(C != 1)
{
return false;
}
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
if(filter_spatial_dim != I3)
{
return false;
}
}
}
// check vector access of A
@@ -2165,7 +2179,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< NumGroupsToMerge
<< ">";
// clang-format on

View File

@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp"
namespace ck {
@@ -61,108 +62,32 @@ template <typename ALayout,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
typename ComputeTypeB = ComputeTypeA,
bool DirectLoad = false,
bool ALdsScalarLoadToVgpr = false,
bool BLdsScalarLoadToVgpr = false>
struct GridwiseGemm_xdl_cshuffle_conv_v3
: public GridwiseGemm_xdl_cshuffle_base<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
Tuple<>,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1Value,
BK1Value,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraMCustom,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraNCustom,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
ComputeTypeA,
ComputeTypeB,
false> // ForceNaiveLayout
{
using Base = GridwiseGemm_xdl_cshuffle_base<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
Tuple<>,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1Value,
BK1Value,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraMCustom,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraNCustom,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<CShuffleBlockTransferScalarPerVector_NPerBlock>,
ComputeTypeA,
ComputeTypeB,
false>; // ForceNaiveLayout
static_assert((is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>) ||
!DirectLoad);
using Base::AK0Number;
using Base::AK1Number;
using Base::BK0Number;
using Base::BK1Number;
using Base::I0;
using Base::I1;
using Base::I2;
using ThisThreadBlock = typename Base::ThisThreadBlock;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr bool DirectLoadEnabled = DirectLoad;
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
@@ -238,19 +163,90 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
return math::integer_divide_ceil(N, NPerBlock);
}
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
template <typename GridDesc_K0_MN_K1_T, index_t K0Number, index_t K1Value>
__host__ __device__ static auto TransformGrid(const GridDesc_K0_MN_K1_T& desc)
{
if constexpr(!DirectLoad)
{
return desc;
}
else
{
const index_t K = desc.GetLength(I0) * desc.GetLength(I2);
const index_t MN = desc.GetLength(I1);
const auto desc_unmerged = transform_tensor_descriptor(
desc,
make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, K0Number)),
make_pass_through_transform(MN),
make_pass_through_transform(K1Value)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
const auto desc_permuted = transform_tensor_descriptor(
desc_unmerged,
make_tuple(make_pass_through_transform(K / KPerBlock),
make_xor_with_modulo_transform(make_tuple(MN, K0Number)),
make_pass_through_transform(K1Value)),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
return transform_tensor_descriptor(
desc_permuted,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, K0Number)),
make_pass_through_transform(MN),
make_pass_through_transform(K1Value)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
}
template <index_t MNXdlPerWave,
index_t MNWaves,
index_t MNPerXdl,
bool IsKContinous,
typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
if constexpr(DirectLoad && IsKContinous)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(make_tuple(
Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
constexpr auto desc = transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_xor_with_modulo_transform(make_tuple(Number<MN>{}, Number<K0>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
return transform_tensor_descriptor(
desc,
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
else
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
}
template <typename ABlockDesc_AK0_M_AK1>
@@ -259,7 +255,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
return MakeGemmMmaTileDescriptor<MXdlPerWave,
MWaves,
MPerXdl,
is_same<tensor_layout::gemm::RowMajor, ALayout>::value>(
ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
@@ -268,7 +268,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
return MakeGemmMmaTileDescriptor<NXdlPerWave,
NWaves,
NPerXdl,
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value>(
BBlockDesc_BK0_N_BK1{});
}
struct Problem
@@ -353,26 +357,195 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
template <typename DeviceArch>
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(DeviceArch)
{
if constexpr(is_same_v<DeviceArch, gfx950_t>)
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
#if defined(__gfx950__)
// Force use padded layout on gfx950 to reduce bank conflicts
constexpr index_t ABlockLdsExtraM = 1;
#else
constexpr index_t ABlockLdsExtraM = ABlockLdsExtraMCustom;
#endif
// A matrix in LDS memory, dst of blockwise copy
if constexpr(DirectLoad)
{
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(Number<MPerBlock * AK1Number>{}, I1, Number<MPerBlock>{}));
}
}
// A matrix in LDS memory, dst of blockwise copy
else if constexpr(ABlockLdsExtraM)
{
// Force use padded layout on gfx950 to reduce bank conflicts
constexpr index_t ABlockLdsExtraM = 1;
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
}
else
// The xor tensor transformation requires more vgprs than necessary, which can
// cause register spills in some cases.
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(DeviceArch{});
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
? 1
: 32 * 4 / KPerBlock / sizeof(ADataType);
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(make_xor_with_modulo_transform(make_tuple(
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_ak0_mldslayer_m_ak1,
make_tuple(make_pass_through_transform(AK0Number),
make_merge_transform_v3_division_mod(
make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
}
else // ColumnMajor A
{
// The kfold and mpair dimensions are not always required.
// More dimensions in the merge_transform increase the difficulty of generating
// immarg offsets for the compiler.
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto M1 = MPerBlock / M0;
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
constexpr auto KThreadRead = WaveSize / MPerXdl;
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
? 1
: 128 / (AK1Number * M0 * sizeof(ADataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mpair<=n0
constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
? 1
: ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
? M0
: 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<K0PerThreadWrite>{},
Number<KThreadReadPerm * M1>{},
Number<kfold * M0 / mpair>{},
Number<mpair>{},
AK1Number));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<M1>{})),
make_unmerge_transform(make_tuple(Number<kfold>{}, Number<M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<1>{},
Sequence<2>{},
Sequence<0, 3>{},
Sequence<4, 5>{},
Sequence<6>{},
Sequence<7>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<KThreadReadPerm>{},
Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<kfold>{},
Number<K0PerThreadWrite>{})),
make_merge_transform_v3_division_mod(
make_tuple(Number<M0 / mpair>{}, Number<mpair>{}, Number<M1>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
}
}
template <typename DeviceArch>
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch)
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
if constexpr(is_same_v<DeviceArch, gfx950_t>)
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
#if defined(__gfx950__)
// Force use padded layout on gfx950 to reduce bank conflicts
constexpr index_t BBlockLdsExtraN = 1;
#else
constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom;
#endif
// B matrix in LDS memory, dst of blockwise copy
if constexpr(DirectLoad)
{
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(BK1Number, Number<KPerBlock>{}, I1));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
}
}
else if constexpr(BBlockLdsExtraN)
{
constexpr index_t BBlockLdsExtraN = 1;
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
@@ -385,31 +558,37 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType)
using BlockwiseGemmPipe = remove_cvref_t<
decltype(BlockGemmPipeline_Selector<
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
ADataType,
BDataType,
ComputeTypeA,
AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())),
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))),
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))),
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>())>;
// Disable vector load from LDS to VGPR for direct load (backward weight store with
// a continuous M or N dimension)
//static constexpr bool LdsScalarLoadToVgpr = DirectLoad;
using BlockwiseGemmPipe = remove_cvref_t<
decltype(BlockGemmPipeline_Selector<
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
ADataType,
BDataType,
ComputeTypeA,
AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
DirectLoad,
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>())>;
template <typename DeviceArch>
__device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch)
@@ -489,8 +668,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id = 0,
const index_t k_batch = 1)
const index_t k_id = 0,
const index_t k_batch = 1,
const index_t block_idx_x = static_cast<index_t>(blockIdx.x))
{
const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
@@ -507,8 +687,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
make_multi_index(static_cast<index_t>(blockIdx.x)));
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_idx_x));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
@@ -536,70 +716,113 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch());
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch());
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
auto get_a_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
is_same<tensor_layout::gemm::RowMajor, ALayout>::value ? 2 : 1,
ABlockTransferSrcScalarPerVector >
(a_grid_desc_ak0_m_ak1,
make_multi_index(
SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
auto get_b_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value ? 2 : 1,
BBlockTransferSrcScalarPerVector >
(b_grid_desc_bk0_n_bk1,
make_multi_index(
SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
auto a_blockwise_copy = get_a_blockwise_copy();
auto b_blockwise_copy = get_b_blockwise_copy();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
@@ -670,8 +893,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id = 0,
const index_t k_batch = 1)
const index_t k_id = 0,
const index_t k_batch = 1,
const index_t block_idx_x = static_cast<index_t>(blockIdx.x))
{
const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
@@ -691,7 +915,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
make_multi_index(static_cast<index_t>(blockIdx.x)));
make_multi_index(static_cast<index_t>(block_idx_x)));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
@@ -719,70 +943,113 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch());
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch());
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
auto get_a_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType,
decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
is_same<tensor_layout::gemm::RowMajor, ALayout>::value ? 2 : 1,
ABlockTransferSrcScalarPerVector >
(a_grid_desc_ak0_m_ak1,
make_multi_index(
SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
auto get_b_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType,
decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value ? 2 : 1,
BBlockTransferSrcScalarPerVector >
(b_grid_desc_bk0_n_bk1,
make_multi_index(
SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
auto a_blockwise_copy = get_a_blockwise_copy();
auto b_blockwise_copy = get_b_blockwise_copy();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(

View File

@@ -0,0 +1,115 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Shorthand aliases for the element data types used by the instance tables below.
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using BF8 = ck::bf8_t;
using F8 = ck::f8_t;
// Empty Ds-tuple: these instances carry no extra "D" residual/bias input tensors.
using Empty_Tuple = ck::Tuple<>;
// Compile-time integer-sequence shorthand for tile / thread-cluster descriptors.
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Backward-data specializations covered by this file: the generic path and the
// 1x1-filter / stride-1 / pad-0 fast path.
static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default;
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
    ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
// Tuning-instance table: grouped convolution backward-data device operations
// (XDL matrix cores, CShuffle v3 pipeline), f16 in/weight/out with f32 accumulation.
// Each tuple element is one fully-specialized kernel configuration; the list is
// consumed generically (e.g. by add_device_operation_instances), so order only
// affects tuning-search order, not correctness.
// NOTE(review): several adjacent rows are byte-identical (the two 128/16/32/64 rows,
// the two 128/16/64/64 rows, the two 256/128/64/64 rows, and the two 256/256/128/64
// rows) — presumably a merge/copy-paste slip; duplicates only add compile time and
// duplicate registrations without widening coverage. Confirm whether one of each
// pair was meant to differ.
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_v3_f16_instances = std::tuple<
    // clang-format off
        // (column labels fixed: these are XDL instances, so the per-instruction tile
        //  columns are MPerXdl/NPerXdl, not Wmma)
        // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
        // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Xdl| Xdl| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
        // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
        // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 256, 64, 8, 8, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>
    // clang-format on
    >;
// Tuning-instance table: grouped convolution backward-data device operations
// (XDL matrix cores, CShuffle v3 pipeline), bf16 in/weight/out with f32 accumulation.
// The configurations mirror the f16 table entry-for-entry.
// NOTE(review): as in the f16 table, several adjacent rows are byte-identical
// duplicates (the two 128/16/32/64 rows, the two 128/16/64/64 rows, the two
// 256/128/64/64 rows, and the two 256/256/128/64 rows) — presumably a merge or
// copy-paste slip; confirm whether one of each pair was meant to differ.
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_v3_bf16_instances = std::tuple<
    // clang-format off
        // (column labels fixed: these are XDL instances, so the per-instruction tile
        //  columns are MPerXdl/NPerXdl, not Wmma)
        // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
        // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Xdl| Xdl| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
        // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
        // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 256, 64, 8, 8, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>
    // clang-format on
    >;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -93,14 +93,68 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances = std::tuple
// generic instance
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion, F16, F16, false, 2>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion, F16, F16, false, 4>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion, F16, F16, false, 2>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion, F16, F16, false, 4>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 80, 32, 8, 16, 16, 4, 5, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 5, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 112, 32, 8, 16, 16, 4, 7, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 7, 4, false, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>
// clang-format on
>;
// FP16 instance list for grouped convolution backward-weight using the
// XDL CShuffle V3 kernel with direct (buffer) loads. Each tuple element is one
// tuning configuration of DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; the
// aligned column headers inside the clang-format guard document the meaning of
// every positional template argument per row. Spatial dimension, layouts,
// conv specialization, and the block-GEMM pipeline scheduler/version are left
// as template parameters so the same list can be stamped out per configuration.
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec,
BlockGemmPipelineScheduler Scheduler,
BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_direct_load_instances = std::tuple<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| Compute| Compute| Direct|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| Data| Data| Load|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| Type| Type| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | |
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true, 2>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true, 2>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 32, 64, 8, 32, 32, 2, 1, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, F16, F16, true, 2>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 64, 32, 8, 32, 32, 2, 1, S<2, 16, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 16>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, F16, F16, true>
// NOTE(review): some rows carry one extra trailing argument (", 2") beyond the
// "Direct Load" column documented in the header table above — presumably a
// group/split factor of DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 with a
// defaulted value; confirm against that template's parameter list and extend
// the header table to cover it.
// clang-format on
>;
// BF16 instance list for grouped convolution backward-weight using the
// XDL CShuffle V3 kernel with direct (buffer) loads. Mirrors the FP16
// direct-load list (same tile shapes, transfer clusters and vector widths)
// with BF16 in/out/compute types; each tuple element is one tuning
// configuration of DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 and the aligned
// column headers inside the clang-format guard document every positional
// template argument per row.
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec,
BlockGemmPipelineScheduler Scheduler,
BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_direct_load_instances = std::tuple<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| Compute| Compute| Direct|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| Data| Data| Load|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| Type| Type| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | |
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 32, 64, 8, 32, 32, 2, 1, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, Scheduler, PipelineVersion, BF16, BF16, true>
// clang-format on
>;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,

View File

@@ -97,7 +97,11 @@ using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
// Instances optimized for G=1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 256, 256, 64, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 512, 128, 32, 8, 8, 32, 32, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>
// clang-format on
>;
@@ -154,8 +158,11 @@ using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple<
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>
// clang-format on
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
// Instances optimized for G=1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 256, 256, 64, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 512, 128, 32, 8, 8, 32, 32, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>
>;
// instances not working on gfx950

View File

@@ -4,6 +4,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -54,7 +55,13 @@ using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple<
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32>,
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block GEMM| Block GEMM| In| Wei| Direct| Num|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| pipeline| pipeline| compute| compute| load| merged|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| scheduler| version| type| type| | groups|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, BF16, BF16, false, 8>
// clang-format on
>;
@@ -75,7 +82,13 @@ using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x = std::tuple<
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 32, 8, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 32, 8, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 32, 8, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 32, 8, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32>,
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block GEMM| Block GEMM| In| Wei| Direct| Num|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| pipeline| pipeline| compute| compute| load| merged|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| scheduler| version| type| type| | groups|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, BF16, BF16, F32, BF16, DsDataTypes, BF16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, BF16, BF16, false, 8>
// clang-format on
>;
@@ -96,7 +109,13 @@ using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple<
// Instances with NumGroupsPerBatch > 1
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S< 4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32>,
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block GEMM| Block GEMM| In| Wei| Direct| Num|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| pipeline| pipeline| compute| compute| load| merged|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| scheduler| version| type| type| | groups|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16, F16, false, 8>
// clang-format on
>;
@@ -122,7 +141,13 @@ using device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x = std::tuple<
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>,
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block GEMM| Block GEMM| In| Wei| Direct| Num|
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| pipeline| pipeline| compute| compute| load| merged|
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| scheduler| version| type| type| | groups|
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<NDimSpatial,ALayout,BLayout, DsLayout,ELayout, F16, F16, F32, F16, DsDataTypes, F16, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmMNKPadding, 256, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, F16, F16, false, 8>
// clang-format on
>;

View File

@@ -108,6 +108,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
is_same_v<ComputeTypeB, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
op_ptrs);
@@ -148,6 +150,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
@@ -355,7 +359,7 @@ struct DeviceOperationInstanceFactory<
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances(
op_ptrs);
op_ptrs);
}
#endif
}

View File

@@ -56,6 +56,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
@@ -232,6 +246,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loa
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
GKYXC,
Empty_Tuple,
NHWGC,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16

View File

@@ -393,6 +393,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instances(
@@ -453,6 +456,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instances(

View File

@@ -184,6 +184,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pip
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
@@ -389,6 +401,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipe
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,

View File

@@ -32,6 +32,8 @@ add_instance_library(
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp
wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp

View File

@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Populates `instances` with 2-D grouped convolution backward-data device
// operations built from the XDL v3 instance template, for BF16 tensors with
// NHWGK (output-grad) / GKYXC (weight) / NHWGC (input-grad) layouts and no
// extra D tensors (Empty_Tuple) or elementwise ops (PassThrough everywhere).
// NOTE(review): `add_device_operation_instances` is declared elsewhere; it
// presumably appends the generated instance set to `instances` — confirm.
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  NHWGC,
                                                                  BF16,
                                                                  BF16,
                                                                  Empty_Tuple,
                                                                  BF16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances)
{
    // 1. Default — no restriction on filter size, stride, or padding.
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_v3_bf16_instances<2,
                                                           NHWGK,
                                                           GKYXC,
                                                           Empty_Tuple,
                                                           NHWGC,
                                                           ConvBwdDataDefault>{});
    // 2. Filter1x1Stride1Pad0 — specialization applicable only to the
    //    1x1-filter, unit-stride, zero-padding case (per the enum name).
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_v3_bf16_instances<2,
                                                           NHWGK,
                                                           GKYXC,
                                                           Empty_Tuple,
                                                           NHWGC,
                                                           ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,49 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Populates `instances` with 2-D grouped convolution backward-data device
// operations built from the XDL v3 instance template, for F16 tensors with
// NHWGK (output-grad) / GKYXC (weight) / NHWGC (input-grad) layouts and no
// extra D tensors (Empty_Tuple) or elementwise ops (PassThrough everywhere).
// F16 counterpart of the BF16 registration function in the sibling file.
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
                                                                  NHWGK,
                                                                  GKYXC,
                                                                  Empty_Tuple,
                                                                  NHWGC,
                                                                  F16,
                                                                  F16,
                                                                  Empty_Tuple,
                                                                  F16,
                                                                  PassThrough,
                                                                  PassThrough,
                                                                  PassThrough>>>& instances)
{
    // 1. Default — no restriction on filter size, stride, or padding.
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_v3_f16_instances<2,
                                                          NHWGK,
                                                          GKYXC,
                                                          Empty_Tuple,
                                                          NHWGC,
                                                          ConvBwdDataDefault>{});
    // 2. Filter1x1Stride1Pad0 — specialization applicable only to the
    //    1x1-filter, unit-stride, zero-padding case (per the enum name).
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_data_xdl_v3_f16_instances<2,
                                                          NHWGK,
                                                          GKYXC,
                                                          Empty_Tuple,
                                                          NHWGC,
                                                          ConvBwdDataFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -20,6 +20,8 @@ set(GROUPED_CONV2D_BWD_WEIGHT
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
// Populates `instances` with 2-D grouped convolution backward-weight device
// operations (XDL v3 CShuffle) for BF16 tensors with NHWGC / GKYXC / NHWGK
// layouts. Only the default conv specialization with the Intrawave scheduler
// and pipeline v1 is registered here.
// NOTE(review): the `_direct_load` suffix suggests these instances use the
// direct global-to-LDS load path — confirm against
// device_grouped_conv_bwd_weight_v3_xdl_instance.hpp.
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
                                                           NHWGC,
                                                           GKYXC,
                                                           NHWGK,
                                                           BF16,
                                                           BF16,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_direct_load_instances<
            2,
            NHWGC,
            GKYXC,
            NHWGK,
            ConvBwdWeightDefault,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
// Populates `instances` with 2-D grouped convolution backward-weight device
// operations (XDL v3 CShuffle) for F16 tensors with NHWGC / GKYXC / NHWGK
// layouts. Only the default conv specialization with the Intrawave scheduler
// and pipeline v1 is registered here. F16 counterpart of the BF16 variant.
// NOTE(review): the `_direct_load` suffix suggests these instances use the
// direct global-to-LDS load path — confirm against
// device_grouped_conv_bwd_weight_v3_xdl_instance.hpp.
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
                                                           NHWGC,
                                                           GKYXC,
                                                           NHWGK,
                                                           F16,
                                                           F16,
                                                           F16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_direct_load_instances<
            2,
            NHWGC,
            GKYXC,
            NHWGK,
            ConvBwdWeightDefault,
            BlockGemmPipelineScheduler::Intrawave,
            BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -364,26 +364,39 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
using AccDataType =
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
// Calculate number of accumulations accounting for split_k
const int num_accums =
static_cast<int>(output.GetElementSize() / conv_param.K_ / split_k_value);
// Additional tolerance for split_k accumulation if needed
int total_accums = num_accums;
if(split_k_value > 1)
{
total_accums = std::max(num_accums, static_cast<int>(split_k_value));
}
// Perform GPU verification (max value computed internally on GPU)
const index_t num_accums = output.GetElementSize() / conv_param.K_;
const index_t num_accums_split_k = split_k_value;
// Get maximum accumulated value from reference
const std::size_t tensor_size =
weight_device_result.mDesc.GetElementSpaceSize();
max_accumulated_value =
gpu_reduce_max<WeiDataType>(gpu_ref_wei_buf.GetDeviceBuffer(), tensor_size);
// Calculate thresholds
auto rtol =
ck::utils::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(
num_accums / num_accums_split_k);
auto atol =
ck::utils::get_absolute_threshold<ComputeType, WeiDataType, AccDataType>(
max_accumulated_value / num_accums_split_k,
num_accums / num_accums_split_k);
// Calculate error due to split_k accumulation
auto rtol_split_k =
ck::utils::get_relative_threshold<WeiDataType, WeiDataType, WeiDataType>(
num_accums_split_k);
auto atol_split_k =
ck::utils::get_absolute_threshold<WeiDataType, WeiDataType, WeiDataType>(
max_accumulated_value, num_accums_split_k);
// Use higher threshold
rtol = std::max(rtol, rtol_split_k);
atol = std::max(atol, atol_split_k);
// Perform GPU verification
auto gpu_result =
ck::profiler::gpu_verify<WeiDataType, ComputeType, AccDataType>(
wei_device_buf.GetDeviceBuffer(),
gpu_ref_wei_buf.GetDeviceBuffer(),
total_accums,
tensor_size);
ck::profiler::gpu_verify<WeiDataType>(wei_device_buf.GetDeviceBuffer(),
gpu_ref_wei_buf.GetDeviceBuffer(),
rtol,
atol,
tensor_size);
if(!gpu_result)
{

View File

@@ -44,6 +44,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
std::optional<std::string> run_instance = std::nullopt,
const OutElementOp out_element_op = OutElementOp{},
index_t instance_index = -1)
{
@@ -232,6 +233,12 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
return;
}
if (run_instance.has_value() && !run_instance.value().empty() && op_ptr->GetTypeString().find(run_instance.value()) == std::string::npos)
{
// skip if run_instance is specified and does not match op name
return;
}
std::string op_name = op_ptr->GetTypeString();
valids++;

View File

@@ -13,36 +13,39 @@ endif()
message(STATUS "CK_PROFILER_OP_FILTER: ${CK_PROFILER_OP_FILTER}")
message(STATUS "CK_PROFILER_INSTANCE_FILTER: ${CK_PROFILER_INSTANCE_FILTER}")
# set(PROFILER_OPS
# profile_gemm.cpp
# profile_reduce.cpp
# profile_groupnorm_bwd_data.cpp
# profile_groupnorm_fwd.cpp
# profile_layernorm_bwd_data.cpp
# profile_layernorm_bwd_gamma_beta.cpp
# profile_groupnorm_bwd_gamma_beta.cpp
# profile_layernorm_fwd.cpp
# profile_max_pool2d_fwd.cpp
# profile_pool3d_fwd.cpp
# profile_avg_pool3d_bwd.cpp
# profile_max_pool3d_bwd.cpp
# profile_avg_pool2d_bwd.cpp
# profile_max_pool2d_bwd.cpp
# profile_softmax.cpp
# profile_batchnorm_fwd.cpp
# profile_batchnorm_bwd.cpp
# profile_batchnorm_infer.cpp
# profile_conv_tensor_rearrange.cpp
# profile_transpose.cpp
# profile_permute_scale.cpp
# profile_gemm_quantization.cpp
# )
set(PROFILER_OPS
# profile_gemm.cpp
# profile_reduce.cpp
# profile_groupnorm_bwd_data.cpp
# profile_groupnorm_fwd.cpp
# profile_layernorm_bwd_data.cpp
# profile_layernorm_bwd_gamma_beta.cpp
# profile_groupnorm_bwd_gamma_beta.cpp
# profile_layernorm_fwd.cpp
# profile_max_pool2d_fwd.cpp
# profile_pool3d_fwd.cpp
# profile_avg_pool3d_bwd.cpp
# profile_max_pool3d_bwd.cpp
# profile_avg_pool2d_bwd.cpp
# profile_max_pool2d_bwd.cpp
# profile_softmax.cpp
# profile_batchnorm_fwd.cpp
# profile_batchnorm_bwd.cpp
# profile_batchnorm_infer.cpp
# profile_conv_tensor_rearrange.cpp
# profile_transpose.cpp
# profile_permute_scale.cpp
# profile_gemm_quantization.cpp
)
# if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
# if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
# list(APPEND PROFILER_OPS profile_contraction_bilinear.cpp)
# list(APPEND PROFILER_OPS profile_contraction_scale.cpp)
# endif()
# if(CK_EXPERIMENTAL_BUILDER)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_tile.cpp)
# endif()
# endif()
# if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
@@ -96,13 +99,13 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
# list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp)
# list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_bnorm_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_bnorm_clamp.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bilinear.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_fwd_outelementop.cpp)
# list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp)
# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
# list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
@@ -116,10 +119,10 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
# list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
endif()
if(DL_KERNELS)
# list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
endif()
# if(DL_KERNELS)
# list(APPEND PROFILER_OPS profile_batched_gemm_multi_d.cpp)
# list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
# endif()
# if(CK_ENABLE_INT8)
# list(APPEND PROFILER_OPS profile_gemm_quantization.cpp)
@@ -204,14 +207,14 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
# list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_bias_add_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_conv2d_fwd_instance)
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_instance)
# list(APPEND DEVICE_INSTANCES device_conv2d_fwd_bias_relu_add_instance)
# list(APPEND DEVICE_INSTANCES device_conv1d_bwd_data_instance)
# list(APPEND DEVICE_INSTANCES device_conv3d_bwd_data_instance)
# list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
endif()
# if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR
@@ -228,19 +231,19 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
# list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_scale_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_bnorm_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_bnorm_clamp_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bilinear_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance)
# if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
@@ -253,10 +256,10 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx1[12]")
# list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
# endif()
# list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
@@ -265,12 +268,12 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
endif()
endif()
if(DL_KERNELS)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
endif()
# if(DL_KERNELS)
# list(APPEND DEVICE_INSTANCES device_batched_gemm_multi_d_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
# list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
# endif()
# if(CK_ENABLE_INT8)
# list(APPEND DEVICE_INSTANCES device_quantization_instance)

View File

@@ -66,7 +66,9 @@ static void print_helper_msg()
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg7: print tensor value (0: no; 1: yes)\n"
<< "arg8: time kernel (0: no, 1: yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
<< ck::utils::conv::get_conv_param_parser_helper_msg()
<< "last arg: run only given instance (string), optional\n"
<< std::endl;
// clang-format on
}
@@ -90,14 +92,17 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
const bool time_kernel = std::stoi(argv[8]);
const int num_dim_spatial = std::stoi(argv[9]);
// 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
if(argc != 9 + 1 + 4 + 6 * num_dim_spatial)
// 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, 6 * num_dim_spatial, and optionally 1 for instance name
const int base_number_of_args = 9 + 1 + 4 + 6 * num_dim_spatial;
if(argc != base_number_of_args && argc != base_number_of_args + 1)
{
print_helper_msg();
return 1;
}
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv);
const std::string run_instance =
(argc == base_number_of_args + 1) ? std::string(argv[base_number_of_args]) : "";
using F32 = float;
using F16 = ck::half_t;
@@ -178,7 +183,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
AComputeType,
BComputeType,
ck::index_t>(
do_verification, init_method, do_log, time_kernel, params);
do_verification, init_method, do_log, time_kernel, params, run_instance);
return pass ? 0 : 1;
}
@@ -194,7 +199,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
AComputeType,
BComputeType,
ck::long_index_t>(
do_verification, init_method, do_log, time_kernel, params);
do_verification, init_method, do_log, time_kernel, params, run_instance);
return pass ? 0 : 1;
}

View File

@@ -184,5 +184,5 @@ TYPED_TEST(TestGroupedConvndBwdWeightDefault, SingleStageAutoDeduce)
this->conv_param = {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
this->split_k_ = -1;
bool is_supported = this->template Run<2>();
EXPECT_FALSE(is_supported);
EXPECT_TRUE(is_supported);
}