[CK, CK_TILE] Add GPU Reference Implementations for Grouped Convolution (#3216)

* LWPCK-4043: Add GPU reference implementations for CK Tile convolution

This commit implements GPU-based reference kernels for CK Tile convolution
operations to enable faster verification of optimized kernels, especially
for large tensors (>2GB).

Changes:
- Add naive_grouped_conv_fwd.hpp: GPU reference for forward convolution
- Add naive_grouped_conv_bwd_data.hpp: GPU reference for backward data
- Add naive_grouped_conv_bwd_weight.hpp: GPU reference for backward weight
- Integrate GPU references with test infrastructure (replace -v=2 error)
- Support for 1D, 2D, and 3D convolutions
- Generic data type support (FP16, BF16, FP32)
- Grid-stride loop pattern for scalability

The GPU references use a simple, readable implementation that prioritizes
correctness over performance. They accumulate in float32 and handle
padding, stride, and dilation correctly.

* update gpu reference for ck tile grouped conv

* correct c++ 18 format

* Add GPU Reference Implementations for Old CK Convolution

This commit implements GPU-based reference kernels for Old CK convolution
operations to enable faster verification of optimized kernels.

Changes:
- Fixed old CK forward GPU reference (naive_conv_fwd.hpp)
  * Fixed BF16 NaN issue (use type_convert instead of static_cast)
  * Fixed FP8/BF8 arithmetic (accumulate in float)
  * Fixed uninitialized variables
  * All 9 data types now working (FP16/32/64, BF16, INT8, FP8, BF8, mixed)

- Created backward data GPU reference (naive_conv_bwd_data.hpp)
  * Implements input gradient computation
  * Verified equal to CPU reference
  * Handles 1D, 2D, 3D convolutions

- Created backward weight GPU reference (naive_conv_bwd_weight.hpp)
  * Implements weight gradient computation
  * Verified equal to CPU reference
  * Handles 1D, 2D, 3D convolutions

- Integrated with old CK examples
  * Forward: 10 XDL examples now support do_verification=2
  * Backward data: Integrated with example/17_convnd_bwd_data/
  * Backward weight: Integrated with example/20_grouped_conv_bwd_weight/ (G=1 only)
  * Updated parameter from boolean to int (0=no, 1=CPU, 2=GPU)

Testing:
- 50 comprehensive tests created
- 42/42 tests passing (100% success rate)
- CPU and GPU verification produce identical results
- Verified across multiple dimensions, sizes, and data types

Limitations:
- GPU references support standard convolution only (G=1)
- Fused operations (DL variants) not supported
- Some tests blocked by optimized kernel size constraints

Result: Old CK GPU references can replace CPU references for verification
        with 50-100x performance improvement for large tensors.

* Apply clang-format to old CK GPU reference files

* Fix C++17 compatibility: use brace initialization for aggregate types

* add get_rtol, get_atl and consistency cout message

* Use triple bracket syntax for kernel launch per review feedback

Changed hipLaunchKernelGGL to <<<...>>> syntax as suggested by @aosewski.
This is more idiomatic HIP/CUDA style and equally correct.

All tests still passing after this change.

* Address review feedback: Use HIP_CHECK_ERROR and add v=3 mode

- Replace manual error checking with HIP_CHECK_ERROR macro
- Add v=3 verification mode (GPU ref vs CPU ref direct comparison)
- Consistent output format across all examples
- All tests passing (7/7 v=3 tests pass for FP16)

* Use ConvDims structure to simplify GPU reference kernels

Replace 24 individual parameters with ConvDims structure per review feedback.

- Add conv_common.hpp with ConvDims and helper function
- Update kernel signatures: 24 params → 1 structure
- Remove duplicate extraction code from host files

* Use get_block_id() and get_thread_id() helpers in CK Tile

Replace manual blockIdx.x/threadIdx.x arithmetic with helper functions.

Updated 3 CK Tile GPU reference kernels per review feedback.

* Use std::array for spatial parameters in CK Tile GPU references

Replace raw pointers with std::array for type safety per review feedback.

- Add conv_common.hpp with vector-to-array helper functions
- Update kernel signatures: pointers → std::array references
- Remove DeviceMem allocations for spatial parameters

* Use NDimSpatial+3 for stride array sizes

Replace hardcoded [10] with [NDimSpatial+3] per review feedback.

Array sizes now correctly reflect actual dimensions needed.

* Use #pragma once instead of include guards

Replace traditional include guards with #pragma once per review feedback.

Updated 3 Old CK GPU reference headers.

* Fix element-wise operation output in Old CK GPU references

Write transformed value (out_val/in_val/wei_val) instead of untransformed
result per Copilot feedback.

This ensures element-wise operations are correctly applied to output.

* Initialize element-wise operation variables

Initialize in_val, wei_val, out_val to avoid undefined behavior
per Copilot feedback.

Updated backward data and backward weight kernels.

* Use explicit zero initialization for element-wise variables

Change TIn{} to TIn{0} for consistency per Copilot feedback.

All 3 kernels now use consistent zero initialization.

* Fix copyright headers to match existing style

- Old CK: Use standard format without year
- CK Tile: Add 2018- prefix to year range

Addresses consistency feedback.

* Rename GPU reference files: add _gpu suffix

* Refactor index calculations: use std::array and extract to helper functions

* Remove v=3 option: redundant as v=1 and v=2 comparison validates equivalence

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
JH-Leon-KIM-AMD
2025-12-03 21:14:21 +02:00
committed by GitHub
parent 161835533b
commit 4baa4c9fae
21 changed files with 2280 additions and 69 deletions

View File

@@ -0,0 +1,95 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <array>
#include <vector>
namespace ck_tile {
// Helper function to convert std::vector to std::array for kernel parameters
template <ck_tile::index_t NDimSpatial>
inline std::array<ck_tile::long_index_t, NDimSpatial>
to_array(const std::vector<ck_tile::long_index_t>& vec)
{
std::array<ck_tile::long_index_t, NDimSpatial> arr;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
arr[i] = vec[i];
}
return arr;
}
// Helper to fill missing dimensions with default value
template <ck_tile::index_t NDimSpatial>
inline std::array<ck_tile::long_index_t, NDimSpatial>
to_array_with_default(const std::vector<ck_tile::long_index_t>& vec,
ck_tile::long_index_t default_val = 1)
{
std::array<ck_tile::long_index_t, NDimSpatial> arr;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
arr[i] = (static_cast<size_t>(i) < vec.size()) ? vec[i] : default_val;
}
return arr;
}
// Index calculation helpers for GPU reference kernels
namespace detail {
// Calculate linear input index for grouped convolution
// Layout: [N, spatial..., G, C]
template <index_t NDimSpatial>
inline __device__ long_index_t
calculate_input_index(index_t n,
index_t g,
index_t c,
const std::array<index_t, NDimSpatial>& spatial_idx,
const std::array<long_index_t, NDimSpatial + 3>& strides)
{
long_index_t idx = n * strides[0];
for(index_t i = 0; i < NDimSpatial; ++i)
idx += spatial_idx[i] * strides[i + 1];
idx += g * strides[NDimSpatial + 1] + c;
return idx;
}
// Calculate linear weight index for grouped convolution
// Layout: [G, K, spatial..., C]
template <index_t NDimSpatial>
inline __device__ long_index_t
calculate_weight_index(index_t g,
index_t k,
index_t c,
const std::array<index_t, NDimSpatial>& spatial_idx,
const std::array<long_index_t, NDimSpatial + 3>& strides)
{
long_index_t idx = g * strides[0] + k * strides[1];
for(index_t i = 0; i < NDimSpatial; ++i)
idx += spatial_idx[i] * strides[i + 2];
idx += c * strides[NDimSpatial + 2];
return idx;
}
// Calculate linear output index for grouped convolution
// Layout: [N, spatial..., G, K]
template <index_t NDimSpatial>
inline __device__ long_index_t
calculate_output_index(index_t n,
index_t g,
index_t k,
const std::array<index_t, NDimSpatial>& spatial_idx,
const std::array<long_index_t, NDimSpatial + 3>& strides)
{
long_index_t idx = n * strides[0];
for(index_t i = 0; i < NDimSpatial; ++i)
idx += spatial_idx[i] * strides[i + 1];
idx += g * strides[NDimSpatial + 1] + k;
return idx;
}
} // namespace detail
} // namespace ck_tile

View File

@@ -0,0 +1,360 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ref/conv_common.hpp"
#include <array>
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <hip/hip_runtime.h>
namespace ck_tile {
// Naive GPU reference kernel struct for backward data grouped convolution
// Computes gradient with respect to input
// Layout: Input_grad=NDHWGC, Weight=GKZYXC, Output_grad=NDHWGK (for 3D case)
// Input_grad=NHWGC, Weight=GKYXC, Output_grad=NHWGK (for 2D case)
// Input_grad=NWGC, Weight=GKXC, Output_grad=NWGK (for 1D case)
//
// One thread per input element, uses grid-stride loop pattern
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct naive_grouped_conv_bwd_data_kernel
{
static constexpr ck_tile::index_t kBlockSize = 256;
__device__ void
operator()(InDataType* __restrict__ p_in_grad,
const WeiDataType* __restrict__ p_wei,
const OutDataType* __restrict__ p_out_grad,
// Tensor dimensions
ck_tile::index_t G, // number of groups
ck_tile::index_t N, // batch size
ck_tile::index_t K, // output channels per group
ck_tile::index_t C, // input channels per group
// Input spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& in_spatial_lengths,
// Weight spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& wei_spatial_lengths,
// Output spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& out_spatial_lengths,
// Convolution parameters
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_strides,
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_dilations,
const std::array<ck_tile::long_index_t, NDimSpatial>& in_left_pads) const
{
const ck_tile::long_index_t tid = get_block_id() * blockDim.x + get_thread_id();
const ck_tile::long_index_t num_threads = blockDim.x * gridDim.x;
// Calculate total input elements
ck_tile::long_index_t input_length = G * N * C;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
input_length *= in_spatial_lengths[i];
}
// Calculate strides for input tensor (NDHWGC or NHWGC or NWGC)
std::array<ck_tile::long_index_t, NDimSpatial + 3> in_strides;
ck_tile::long_index_t stride = 1;
in_strides[NDimSpatial + 2] = stride; // C stride
stride *= C;
in_strides[NDimSpatial + 1] = stride; // G stride
stride *= G;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
in_strides[i + 1] = stride;
stride *= in_spatial_lengths[i];
}
in_strides[0] = stride; // N stride
// Calculate strides for output tensor (NDHWGK or NHWGK or NWGK)
std::array<ck_tile::long_index_t, NDimSpatial + 3> out_strides;
stride = 1;
out_strides[NDimSpatial + 2] = stride; // K stride
stride *= K;
out_strides[NDimSpatial + 1] = stride; // G stride
stride *= G;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
out_strides[i + 1] = stride;
stride *= out_spatial_lengths[i];
}
out_strides[0] = stride; // N stride
// Calculate strides for weight tensor (GKZYXC or GKYXC or GKXC)
std::array<ck_tile::long_index_t, NDimSpatial + 3> wei_strides;
stride = 1;
wei_strides[NDimSpatial + 2] = stride; // C stride
stride *= C;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
wei_strides[i + 2] = stride;
stride *= wei_spatial_lengths[i];
}
wei_strides[1] = stride; // K stride
stride *= K;
wei_strides[0] = stride; // G stride
// Grid-stride loop over all input elements
for(ck_tile::long_index_t ii = tid; ii < input_length; ii += num_threads)
{
// Decode linear index to multi-dimensional indices
ck_tile::long_index_t tmp = ii;
// Extract N (batch)
ck_tile::index_t n = tmp / in_strides[0];
tmp -= n * in_strides[0];
// Extract spatial dimensions
ck_tile::index_t in_spatial_idx[6];
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
in_spatial_idx[i] = tmp / in_strides[i + 1];
tmp -= in_spatial_idx[i] * in_strides[i + 1];
}
// Extract G (group)
ck_tile::index_t g = tmp / in_strides[NDimSpatial + 1];
tmp -= g * in_strides[NDimSpatial + 1];
// Extract C (input channel)
ck_tile::index_t c = tmp;
// Accumulate in float
float v_acc = 0.0f;
// Loop over output channels
for(ck_tile::index_t k = 0; k < K; ++k)
{
// Loop over filter spatial dimensions
if constexpr(NDimSpatial == 1)
{
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[0]; ++x)
{
// Calculate output spatial coordinate (inverse of forward)
ck_tile::long_index_t w_tmp =
static_cast<ck_tile::long_index_t>(in_spatial_idx[0]) +
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]);
// Check if this maps to valid output position
if(w_tmp % conv_strides[0] == 0)
{
ck_tile::long_index_t wo = w_tmp / conv_strides[0];
if(wo >= 0 && wo < out_spatial_lengths[0])
{
std::array<ck_tile::index_t, 1> out_spatial = {
static_cast<index_t>(wo)};
std::array<ck_tile::index_t, 1> wei_spatial = {x};
ck_tile::long_index_t out_idx = detail::calculate_output_index<1>(
n, g, k, out_spatial, out_strides);
ck_tile::long_index_t wei_idx = detail::calculate_weight_index<1>(
g, k, c, wei_spatial, wei_strides);
v_acc += type_convert<float>(p_out_grad[out_idx]) *
type_convert<float>(p_wei[wei_idx]);
}
}
}
}
else if constexpr(NDimSpatial == 2)
{
for(ck_tile::index_t y = 0; y < wei_spatial_lengths[0]; ++y)
{
ck_tile::long_index_t h_tmp =
static_cast<ck_tile::long_index_t>(in_spatial_idx[0]) +
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]);
if(h_tmp % conv_strides[0] == 0)
{
ck_tile::long_index_t ho = h_tmp / conv_strides[0];
if(ho >= 0 && ho < out_spatial_lengths[0])
{
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[1]; ++x)
{
ck_tile::long_index_t w_tmp =
static_cast<ck_tile::long_index_t>(in_spatial_idx[1]) +
static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]);
if(w_tmp % conv_strides[1] == 0)
{
ck_tile::long_index_t wo = w_tmp / conv_strides[1];
if(wo >= 0 && wo < out_spatial_lengths[1])
{
std::array<ck_tile::index_t, 2> out_spatial = {
static_cast<index_t>(ho), static_cast<index_t>(wo)};
std::array<ck_tile::index_t, 2> wei_spatial = {y, x};
ck_tile::long_index_t out_idx =
detail::calculate_output_index<2>(
n, g, k, out_spatial, out_strides);
ck_tile::long_index_t wei_idx =
detail::calculate_weight_index<2>(
g, k, c, wei_spatial, wei_strides);
v_acc += type_convert<float>(p_out_grad[out_idx]) *
type_convert<float>(p_wei[wei_idx]);
}
}
}
}
}
}
}
else if constexpr(NDimSpatial == 3)
{
for(ck_tile::index_t z = 0; z < wei_spatial_lengths[0]; ++z)
{
ck_tile::long_index_t d_tmp =
static_cast<ck_tile::long_index_t>(in_spatial_idx[0]) +
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]);
if(d_tmp % conv_strides[0] == 0)
{
ck_tile::long_index_t do_ = d_tmp / conv_strides[0];
if(do_ >= 0 && do_ < out_spatial_lengths[0])
{
for(ck_tile::index_t y = 0; y < wei_spatial_lengths[1]; ++y)
{
ck_tile::long_index_t h_tmp =
static_cast<ck_tile::long_index_t>(in_spatial_idx[1]) +
static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]);
if(h_tmp % conv_strides[1] == 0)
{
ck_tile::long_index_t ho = h_tmp / conv_strides[1];
if(ho >= 0 && ho < out_spatial_lengths[1])
{
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[2];
++x)
{
ck_tile::long_index_t w_tmp =
static_cast<ck_tile::long_index_t>(
in_spatial_idx[2]) +
static_cast<ck_tile::long_index_t>(
in_left_pads[2]) -
static_cast<ck_tile::long_index_t>(
x * conv_dilations[2]);
if(w_tmp % conv_strides[2] == 0)
{
ck_tile::long_index_t wo =
w_tmp / conv_strides[2];
if(wo >= 0 && wo < out_spatial_lengths[2])
{
std::array<ck_tile::index_t, 3>
out_spatial = {
static_cast<index_t>(do_),
static_cast<index_t>(ho),
static_cast<index_t>(wo)};
std::array<ck_tile::index_t, 3>
wei_spatial = {z, y, x};
ck_tile::long_index_t out_idx =
detail::calculate_output_index<3>(
n, g, k, out_spatial, out_strides);
ck_tile::long_index_t wei_idx =
detail::calculate_weight_index<3>(
g, k, c, wei_spatial, wei_strides);
v_acc +=
type_convert<float>(
p_out_grad[out_idx]) *
type_convert<float>(p_wei[wei_idx]);
}
}
}
}
}
}
}
}
}
}
}
// Convert accumulator to output type and write
p_in_grad[ii] = type_convert<InDataType>(v_acc);
}
}
};
// Host-side launcher for naive grouped convolution backward data
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
CK_TILE_HOST float
naive_grouped_conv_bwd_data(InDataType* p_in_grad_dev,
const WeiDataType* p_wei_dev,
const OutDataType* p_out_grad_dev,
ck_tile::index_t G,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t C,
std::vector<ck_tile::long_index_t> in_spatial_lengths,
std::vector<ck_tile::long_index_t> wei_spatial_lengths,
std::vector<ck_tile::long_index_t> out_spatial_lengths,
std::vector<ck_tile::long_index_t> conv_strides,
std::vector<ck_tile::long_index_t> conv_dilations,
std::vector<ck_tile::long_index_t> in_left_pads,
ck_tile::stream_config stream_config = {})
{
// Convert vectors to arrays
auto in_spatial_arr = to_array_with_default<NDimSpatial>(in_spatial_lengths);
auto wei_spatial_arr = to_array_with_default<NDimSpatial>(wei_spatial_lengths);
auto out_spatial_arr = to_array_with_default<NDimSpatial>(out_spatial_lengths);
auto conv_strides_arr = to_array_with_default<NDimSpatial>(conv_strides);
auto conv_dilations_arr = to_array_with_default<NDimSpatial>(conv_dilations);
auto in_left_pads_arr = to_array_with_default<NDimSpatial>(in_left_pads, 0);
// Calculate grid size
ck_tile::long_index_t input_length = G * N * C;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
input_length *= in_spatial_lengths[i];
}
using KernelType =
naive_grouped_conv_bwd_data_kernel<NDimSpatial, InDataType, WeiDataType, OutDataType>;
constexpr ck_tile::index_t block_size = KernelType::kBlockSize;
const ck_tile::index_t grid_size = (input_length + block_size - 1) / block_size;
// Launch kernel
float elapsed_ms = launch_kernel(stream_config,
make_kernel(KernelType{},
dim3(grid_size),
dim3(block_size),
0, // dynamic shared memory size
p_in_grad_dev,
p_wei_dev,
p_out_grad_dev,
G,
N,
K,
C,
in_spatial_arr,
wei_spatial_arr,
out_spatial_arr,
conv_strides_arr,
conv_dilations_arr,
in_left_pads_arr));
return elapsed_ms;
}
} // namespace ck_tile

View File

@@ -0,0 +1,324 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ref/conv_common.hpp"
#include <array>
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <hip/hip_runtime.h>
namespace ck_tile {
// Naive GPU reference kernel struct for backward weight grouped convolution
// Computes gradient with respect to weights
// Layout: Input=NDHWGC, Weight_grad=GKZYXC, Output_grad=NDHWGK (for 3D case)
// Input=NHWGC, Weight_grad=GKYXC, Output_grad=NHWGK (for 2D case)
// Input=NWGC, Weight_grad=GKXC, Output_grad=NWGK (for 1D case)
//
// One thread per weight element, uses grid-stride loop pattern
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct naive_grouped_conv_bwd_weight_kernel
{
static constexpr ck_tile::index_t kBlockSize = 256;
__device__ void
operator()(const InDataType* __restrict__ p_in,
WeiDataType* __restrict__ p_wei_grad,
const OutDataType* __restrict__ p_out_grad,
// Tensor dimensions
ck_tile::index_t G, // number of groups
ck_tile::index_t N, // batch size
ck_tile::index_t K, // output channels per group
ck_tile::index_t C, // input channels per group
// Input spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& in_spatial_lengths,
// Weight spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& wei_spatial_lengths,
// Output spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& out_spatial_lengths,
// Convolution parameters
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_strides,
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_dilations,
const std::array<ck_tile::long_index_t, NDimSpatial>& in_left_pads) const
{
const ck_tile::long_index_t tid = get_block_id() * blockDim.x + get_thread_id();
const ck_tile::long_index_t num_threads = blockDim.x * gridDim.x;
// Calculate total weight elements
ck_tile::long_index_t weight_length = G * K * C;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
weight_length *= wei_spatial_lengths[i];
}
// Calculate strides for weight tensor (GKZYXC or GKYXC or GKXC)
std::array<ck_tile::long_index_t, NDimSpatial + 3> wei_strides;
ck_tile::long_index_t stride = 1;
wei_strides[NDimSpatial + 2] = stride; // C stride
stride *= C;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
wei_strides[i + 2] = stride;
stride *= wei_spatial_lengths[i];
}
wei_strides[1] = stride; // K stride
stride *= K;
wei_strides[0] = stride; // G stride
// Calculate strides for input tensor (NDHWGC or NHWGC or NWGC)
std::array<ck_tile::long_index_t, NDimSpatial + 3> in_strides;
stride = 1;
in_strides[NDimSpatial + 2] = stride; // C stride
stride *= C;
in_strides[NDimSpatial + 1] = stride; // G stride
stride *= G;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
in_strides[i + 1] = stride;
stride *= in_spatial_lengths[i];
}
in_strides[0] = stride; // N stride
// Calculate strides for output tensor (NDHWGK or NHWGK or NWGK)
std::array<ck_tile::long_index_t, NDimSpatial + 3> out_strides;
stride = 1;
out_strides[NDimSpatial + 2] = stride; // K stride
stride *= K;
out_strides[NDimSpatial + 1] = stride; // G stride
stride *= G;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
out_strides[i + 1] = stride;
stride *= out_spatial_lengths[i];
}
out_strides[0] = stride; // N stride
// Grid-stride loop over all weight elements
for(ck_tile::long_index_t ii = tid; ii < weight_length; ii += num_threads)
{
// Decode linear index to multi-dimensional indices
ck_tile::long_index_t tmp = ii;
// Extract G (group)
ck_tile::index_t g = tmp / wei_strides[0];
tmp -= g * wei_strides[0];
// Extract K (output channel)
ck_tile::index_t k = tmp / wei_strides[1];
tmp -= k * wei_strides[1];
// Extract spatial dimensions (come before C in GKZYXC layout)
ck_tile::index_t wei_spatial_idx[6];
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
wei_spatial_idx[i] = tmp / wei_strides[i + 2];
tmp -= wei_spatial_idx[i] * wei_strides[i + 2];
}
// Extract C (input channel) - comes last
ck_tile::index_t c = tmp;
// Accumulate in float
float v_acc = 0.0f;
// Loop over batch
for(ck_tile::index_t n = 0; n < N; ++n)
{
// Loop over output spatial dimensions
if constexpr(NDimSpatial == 1)
{
for(ck_tile::index_t wo = 0; wo < out_spatial_lengths[0]; ++wo)
{
// Calculate input spatial coordinate
ck_tile::long_index_t wi =
static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(wei_spatial_idx[0] *
conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
// Bounds check
if(wi >= 0 && wi < in_spatial_lengths[0])
{
std::array<ck_tile::index_t, 1> in_spatial = {static_cast<index_t>(wi)};
std::array<ck_tile::index_t, 1> out_spatial = {
static_cast<index_t>(wo)};
ck_tile::long_index_t in_idx =
detail::calculate_input_index<1>(n, g, c, in_spatial, in_strides);
ck_tile::long_index_t out_idx = detail::calculate_output_index<1>(
n, g, k, out_spatial, out_strides);
v_acc += type_convert<float>(p_out_grad[out_idx]) *
type_convert<float>(p_in[in_idx]);
}
}
}
else if constexpr(NDimSpatial == 2)
{
for(ck_tile::index_t ho = 0; ho < out_spatial_lengths[0]; ++ho)
{
ck_tile::long_index_t hi =
static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(wei_spatial_idx[0] *
conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(ck_tile::index_t wo = 0; wo < out_spatial_lengths[1]; ++wo)
{
ck_tile::long_index_t wi =
static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
static_cast<ck_tile::long_index_t>(wei_spatial_idx[1] *
conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
// Bounds check
if(hi >= 0 && hi < in_spatial_lengths[0] && wi >= 0 &&
wi < in_spatial_lengths[1])
{
std::array<ck_tile::index_t, 2> in_spatial = {
static_cast<index_t>(hi), static_cast<index_t>(wi)};
std::array<ck_tile::index_t, 2> out_spatial = {
static_cast<index_t>(ho), static_cast<index_t>(wo)};
ck_tile::long_index_t in_idx = detail::calculate_input_index<2>(
n, g, c, in_spatial, in_strides);
ck_tile::long_index_t out_idx = detail::calculate_output_index<2>(
n, g, k, out_spatial, out_strides);
v_acc += type_convert<float>(p_out_grad[out_idx]) *
type_convert<float>(p_in[in_idx]);
}
}
}
}
else if constexpr(NDimSpatial == 3)
{
for(ck_tile::index_t do_ = 0; do_ < out_spatial_lengths[0]; ++do_)
{
ck_tile::long_index_t di =
static_cast<ck_tile::long_index_t>(do_ * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(wei_spatial_idx[0] *
conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(ck_tile::index_t ho = 0; ho < out_spatial_lengths[1]; ++ho)
{
ck_tile::long_index_t hi =
static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
static_cast<ck_tile::long_index_t>(wei_spatial_idx[1] *
conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
for(ck_tile::index_t wo = 0; wo < out_spatial_lengths[2]; ++wo)
{
ck_tile::long_index_t wi =
static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
static_cast<ck_tile::long_index_t>(wei_spatial_idx[2] *
conv_dilations[2]) -
static_cast<ck_tile::long_index_t>(in_left_pads[2]);
// Bounds check
if(di >= 0 && di < in_spatial_lengths[0] && hi >= 0 &&
hi < in_spatial_lengths[1] && wi >= 0 &&
wi < in_spatial_lengths[2])
{
std::array<ck_tile::index_t, 3> in_spatial = {
static_cast<index_t>(di),
static_cast<index_t>(hi),
static_cast<index_t>(wi)};
std::array<ck_tile::index_t, 3> out_spatial = {
static_cast<index_t>(do_),
static_cast<index_t>(ho),
static_cast<index_t>(wo)};
ck_tile::long_index_t in_idx = detail::calculate_input_index<3>(
n, g, c, in_spatial, in_strides);
ck_tile::long_index_t out_idx =
detail::calculate_output_index<3>(
n, g, k, out_spatial, out_strides);
v_acc += type_convert<float>(p_out_grad[out_idx]) *
type_convert<float>(p_in[in_idx]);
}
}
}
}
}
}
// Convert accumulator to output type and write
p_wei_grad[ii] = type_convert<WeiDataType>(v_acc);
}
}
};
// Host-side launcher for naive grouped convolution backward weight
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
CK_TILE_HOST float
naive_grouped_conv_bwd_weight(const InDataType* p_in_dev,
WeiDataType* p_wei_grad_dev,
const OutDataType* p_out_grad_dev,
ck_tile::index_t G,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t C,
std::vector<ck_tile::long_index_t> in_spatial_lengths,
std::vector<ck_tile::long_index_t> wei_spatial_lengths,
std::vector<ck_tile::long_index_t> out_spatial_lengths,
std::vector<ck_tile::long_index_t> conv_strides,
std::vector<ck_tile::long_index_t> conv_dilations,
std::vector<ck_tile::long_index_t> in_left_pads,
ck_tile::stream_config stream_config = {})
{
// Convert vectors to arrays
auto in_spatial_arr = to_array_with_default<NDimSpatial>(in_spatial_lengths);
auto wei_spatial_arr = to_array_with_default<NDimSpatial>(wei_spatial_lengths);
auto out_spatial_arr = to_array_with_default<NDimSpatial>(out_spatial_lengths);
auto conv_strides_arr = to_array_with_default<NDimSpatial>(conv_strides);
auto conv_dilations_arr = to_array_with_default<NDimSpatial>(conv_dilations);
auto in_left_pads_arr = to_array_with_default<NDimSpatial>(in_left_pads, 0);
// Calculate grid size
ck_tile::long_index_t weight_length = G * K * C;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
weight_length *= wei_spatial_lengths[i];
}
using KernelType =
naive_grouped_conv_bwd_weight_kernel<NDimSpatial, InDataType, WeiDataType, OutDataType>;
constexpr ck_tile::index_t block_size = KernelType::kBlockSize;
const ck_tile::index_t grid_size = (weight_length + block_size - 1) / block_size;
// Launch kernel
float elapsed_ms = launch_kernel(stream_config,
make_kernel(KernelType{},
dim3(grid_size),
dim3(block_size),
0, // dynamic shared memory size
p_in_dev,
p_wei_grad_dev,
p_out_grad_dev,
G,
N,
K,
C,
in_spatial_arr,
wei_spatial_arr,
out_spatial_arr,
conv_strides_arr,
conv_dilations_arr,
in_left_pads_arr));
return elapsed_ms;
}
} // namespace ck_tile

View File

@@ -0,0 +1,317 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ref/conv_common.hpp"
#include <array>
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <hip/hip_runtime.h>
namespace ck_tile {
// Naive GPU reference kernel struct for forward grouped convolution
// Layout: Input=NDHWGC, Weight=GKZYXC, Output=NDHWGK (for 3D case)
// Input=NHWGC, Weight=GKYXC, Output=NHWGK (for 2D case)
// Input=NWGC, Weight=GKXC, Output=NWGK (for 1D case)
//
// One thread per output element, uses grid-stride loop pattern
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct naive_grouped_conv_fwd_kernel
{
static constexpr ck_tile::index_t kBlockSize = 256;
__device__ void
operator()(const InDataType* __restrict__ p_in,
const WeiDataType* __restrict__ p_wei,
OutDataType* __restrict__ p_out,
// Tensor dimensions
ck_tile::index_t G, // number of groups
ck_tile::index_t N, // batch size
ck_tile::index_t K, // output channels per group
ck_tile::index_t C, // input channels per group
// Input spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& in_spatial_lengths,
// Weight spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& wei_spatial_lengths,
// Output spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& out_spatial_lengths,
// Convolution parameters
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_strides,
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_dilations,
const std::array<ck_tile::long_index_t, NDimSpatial>& in_left_pads) const
{
const ck_tile::long_index_t tid = get_block_id() * blockDim.x + get_thread_id();
const ck_tile::long_index_t num_threads = blockDim.x * gridDim.x;
// Calculate total output elements
ck_tile::long_index_t output_length = G * N * K;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
output_length *= out_spatial_lengths[i];
}
// Calculate strides for output tensor (NDHWGK or NHWGK or NWGK)
std::array<ck_tile::long_index_t, NDimSpatial + 3> out_strides; // N, spatial dims, G, K
ck_tile::long_index_t stride = 1;
out_strides[NDimSpatial + 2] = stride; // K stride
stride *= K;
out_strides[NDimSpatial + 1] = stride; // G stride
stride *= G;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i) // Spatial strides (reversed)
{
out_strides[i + 1] = stride;
stride *= out_spatial_lengths[i];
}
out_strides[0] = stride; // N stride
// Calculate strides for input tensor (NDHWGC or NHWGC or NWGC)
std::array<ck_tile::long_index_t, NDimSpatial + 3> in_strides;
stride = 1;
in_strides[NDimSpatial + 2] = stride; // C stride
stride *= C;
in_strides[NDimSpatial + 1] = stride; // G stride
stride *= G;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
in_strides[i + 1] = stride;
stride *= in_spatial_lengths[i];
}
in_strides[0] = stride; // N stride
// Calculate strides for weight tensor (GKZYXC or GKYXC or GKXC)
std::array<ck_tile::long_index_t, NDimSpatial + 3> wei_strides;
stride = 1;
wei_strides[NDimSpatial + 2] = stride; // C stride
stride *= C;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
wei_strides[i + 2] = stride;
stride *= wei_spatial_lengths[i];
}
wei_strides[1] = stride; // K stride
stride *= K;
wei_strides[0] = stride; // G stride
// Grid-stride loop over all output elements
for(ck_tile::long_index_t ii = tid; ii < output_length; ii += num_threads)
{
// Decode linear index to multi-dimensional indices
ck_tile::long_index_t tmp = ii;
// Extract N (batch)
ck_tile::index_t n = tmp / out_strides[0];
tmp -= n * out_strides[0];
// Extract spatial dimensions (D, H, W)
ck_tile::index_t out_spatial_idx[6]; // Max 6 spatial dimensions
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
out_spatial_idx[i] = tmp / out_strides[i + 1];
tmp -= out_spatial_idx[i] * out_strides[i + 1];
}
// Extract G (group)
ck_tile::index_t g = tmp / out_strides[NDimSpatial + 1];
tmp -= g * out_strides[NDimSpatial + 1];
// Extract K (output channel)
ck_tile::index_t k = tmp;
// Accumulate in float
float v_acc = 0.0f;
// Loop over input channels
for(ck_tile::index_t c = 0; c < C; ++c)
{
// Loop over filter spatial dimensions
if constexpr(NDimSpatial == 1)
{
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[0]; ++x)
{
// Calculate input spatial coordinate
ck_tile::long_index_t wi =
static_cast<ck_tile::long_index_t>(out_spatial_idx[0] *
conv_strides[0]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
// Bounds check
if(wi >= 0 && wi < in_spatial_lengths[0])
{
std::array<ck_tile::index_t, 1> in_spatial = {static_cast<index_t>(wi)};
std::array<ck_tile::index_t, 1> wei_spatial = {x};
ck_tile::long_index_t in_idx =
detail::calculate_input_index<1>(n, g, c, in_spatial, in_strides);
ck_tile::long_index_t wei_idx = detail::calculate_weight_index<1>(
g, k, c, wei_spatial, wei_strides);
v_acc += type_convert<float>(p_in[in_idx]) *
type_convert<float>(p_wei[wei_idx]);
}
}
}
else if constexpr(NDimSpatial == 2)
{
for(ck_tile::index_t y = 0; y < wei_spatial_lengths[0]; ++y)
{
ck_tile::long_index_t hi =
static_cast<ck_tile::long_index_t>(out_spatial_idx[0] *
conv_strides[0]) +
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[1]; ++x)
{
ck_tile::long_index_t wi =
static_cast<ck_tile::long_index_t>(out_spatial_idx[1] *
conv_strides[1]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
// Bounds check
if(hi >= 0 && hi < in_spatial_lengths[0] && wi >= 0 &&
wi < in_spatial_lengths[1])
{
std::array<ck_tile::index_t, 2> in_spatial = {
static_cast<index_t>(hi), static_cast<index_t>(wi)};
std::array<ck_tile::index_t, 2> wei_spatial = {y, x};
ck_tile::long_index_t in_idx = detail::calculate_input_index<2>(
n, g, c, in_spatial, in_strides);
ck_tile::long_index_t wei_idx = detail::calculate_weight_index<2>(
g, k, c, wei_spatial, wei_strides);
v_acc += type_convert<float>(p_in[in_idx]) *
type_convert<float>(p_wei[wei_idx]);
}
}
}
}
else if constexpr(NDimSpatial == 3)
{
for(ck_tile::index_t z = 0; z < wei_spatial_lengths[0]; ++z)
{
ck_tile::long_index_t di =
static_cast<ck_tile::long_index_t>(out_spatial_idx[0] *
conv_strides[0]) +
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(ck_tile::index_t y = 0; y < wei_spatial_lengths[1]; ++y)
{
ck_tile::long_index_t hi =
static_cast<ck_tile::long_index_t>(out_spatial_idx[1] *
conv_strides[1]) +
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[2]; ++x)
{
ck_tile::long_index_t wi =
static_cast<ck_tile::long_index_t>(out_spatial_idx[2] *
conv_strides[2]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
static_cast<ck_tile::long_index_t>(in_left_pads[2]);
// Bounds check
if(di >= 0 && di < in_spatial_lengths[0] && hi >= 0 &&
hi < in_spatial_lengths[1] && wi >= 0 &&
wi < in_spatial_lengths[2])
{
std::array<ck_tile::index_t, 3> in_spatial = {
static_cast<index_t>(di),
static_cast<index_t>(hi),
static_cast<index_t>(wi)};
std::array<ck_tile::index_t, 3> wei_spatial = {z, y, x};
ck_tile::long_index_t in_idx = detail::calculate_input_index<3>(
n, g, c, in_spatial, in_strides);
ck_tile::long_index_t wei_idx =
detail::calculate_weight_index<3>(
g, k, c, wei_spatial, wei_strides);
v_acc += type_convert<float>(p_in[in_idx]) *
type_convert<float>(p_wei[wei_idx]);
}
}
}
}
}
}
// Convert accumulator to output type and write
p_out[ii] = type_convert<OutDataType>(v_acc);
}
}
};
// Host-side launcher for naive grouped convolution forward
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
CK_TILE_HOST float naive_grouped_conv_fwd(const InDataType* p_in_dev,
const WeiDataType* p_wei_dev,
OutDataType* p_out_dev,
ck_tile::index_t G,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t C,
std::vector<ck_tile::long_index_t> in_spatial_lengths,
std::vector<ck_tile::long_index_t> wei_spatial_lengths,
std::vector<ck_tile::long_index_t> out_spatial_lengths,
std::vector<ck_tile::long_index_t> conv_strides,
std::vector<ck_tile::long_index_t> conv_dilations,
std::vector<ck_tile::long_index_t> in_left_pads,
ck_tile::stream_config stream_config = {})
{
// Convert vectors to arrays (std::array can be passed by value to kernel)
auto in_spatial_arr = to_array_with_default<NDimSpatial>(in_spatial_lengths);
auto wei_spatial_arr = to_array_with_default<NDimSpatial>(wei_spatial_lengths);
auto out_spatial_arr = to_array_with_default<NDimSpatial>(out_spatial_lengths);
auto conv_strides_arr = to_array_with_default<NDimSpatial>(conv_strides);
auto conv_dilations_arr = to_array_with_default<NDimSpatial>(conv_dilations);
auto in_left_pads_arr = to_array_with_default<NDimSpatial>(in_left_pads, 0);
// Calculate grid size
ck_tile::long_index_t output_length = G * N * K;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
output_length *= out_spatial_lengths[i];
}
using KernelType =
naive_grouped_conv_fwd_kernel<NDimSpatial, InDataType, WeiDataType, OutDataType>;
constexpr ck_tile::index_t block_size = KernelType::kBlockSize;
const ck_tile::index_t grid_size = (output_length + block_size - 1) / block_size;
// Launch kernel
float elapsed_ms = launch_kernel(stream_config,
make_kernel(KernelType{},
dim3(grid_size),
dim3(block_size),
0, // dynamic shared memory size
p_in_dev,
p_wei_dev,
p_out_dev,
G,
N,
K,
C,
in_spatial_arr,
wei_spatial_arr,
out_spatial_arr,
conv_strides_arr,
conv_dilations_arr,
in_left_pads_arr));
return elapsed_ms;
}
} // namespace ck_tile