// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"

#include <array>
#include <cstddef>
#include <stdexcept>
#include <vector>

namespace ck_tile {

// Copy the first NDimSpatial entries of a host-side std::vector into a
// fixed-size std::array so they can be passed by value as kernel parameters.
//
// Entries of vec beyond index NDimSpatial - 1 are ignored.
// Throws std::invalid_argument if vec holds fewer than NDimSpatial entries,
// instead of reading past the end of the vector.
template <ck_tile::index_t NDimSpatial>
inline std::array<ck_tile::long_index_t, NDimSpatial>
to_array(const std::vector<ck_tile::long_index_t>& vec)
{
    if(vec.size() < static_cast<std::size_t>(NDimSpatial))
    {
        throw std::invalid_argument(
            "to_array: input vector has fewer elements than NDimSpatial");
    }

    std::array<ck_tile::long_index_t, NDimSpatial> arr{};
    for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
    {
        arr[i] = vec[i];
    }
    return arr;
}

// Same conversion as to_array, but tolerant of short input: any position
// i >= vec.size() is filled with default_val (1 unless specified), which lets
// callers omit trailing dimensions. Entries of vec beyond index
// NDimSpatial - 1 are ignored.
template <ck_tile::index_t NDimSpatial>
inline std::array<ck_tile::long_index_t, NDimSpatial>
to_array_with_default(const std::vector<ck_tile::long_index_t>& vec,
                      ck_tile::long_index_t default_val = 1)
{
    std::array<ck_tile::long_index_t, NDimSpatial> arr{};
    for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
    {
        arr[i] = (static_cast<std::size_t>(i) < vec.size()) ? vec[i] : default_val;
    }
    return arr;
}

// Index calculation helpers for GPU reference kernels
namespace detail {

// Common conventions for the helpers below:
//   spatial_idx : indexable container whose first NDimSpatial entries are the
//                 coordinates along each spatial dimension.
//   strides     : indexable container of element strides, one per layout
//                 dimension, in the order given by each helper's "Layout" line.
// The container types are template parameters so that both 32-bit and 64-bit
// index arrays are accepted. Every factor is widened to long_index_t before it
// is multiplied, so the offset cannot overflow 32-bit arithmetic even when the
// containers hold index_t values.
// NDimSpatial is not deducible from the arguments and must be given
// explicitly, e.g. detail::calculate_input_index<NDimSpatial>(n, g, c, idx, st).

// Linear input offset for grouped convolution.
// Layout: [N, spatial..., G, C]
//   strides[0]                 -> N
//   strides[1 .. NDimSpatial]  -> spatial dimensions
//   strides[NDimSpatial + 1]   -> G
// NOTE(review): c is added without a stride, i.e. C is assumed to be the
// contiguous (unit-stride) dimension; any strides[NDimSpatial + 2] supplied by
// the caller is not read. calculate_weight_index, by contrast, scales c by
// strides[NDimSpatial + 2]. Confirm callers only pass C-packed inputs.
template <index_t NDimSpatial, typename SpatialIdx, typename Strides>
inline __device__ long_index_t calculate_input_index(index_t n,
                                                     index_t g,
                                                     index_t c,
                                                     const SpatialIdx& spatial_idx,
                                                     const Strides& strides)
{
    long_index_t idx = static_cast<long_index_t>(n) * strides[0];
    for(index_t i = 0; i < NDimSpatial; ++i)
        idx += static_cast<long_index_t>(spatial_idx[i]) * strides[i + 1];
    idx += static_cast<long_index_t>(g) * strides[NDimSpatial + 1] + c;
    return idx;
}

// Linear weight offset for grouped convolution.
// Layout: [G, K, spatial..., C]
//   strides[0]                     -> G
//   strides[1]                     -> K
//   strides[2 .. NDimSpatial + 1]  -> spatial dimensions
//   strides[NDimSpatial + 2]       -> C (explicit stride, unlike input/output)
template <index_t NDimSpatial, typename SpatialIdx, typename Strides>
inline __device__ long_index_t calculate_weight_index(index_t g,
                                                      index_t k,
                                                      index_t c,
                                                      const SpatialIdx& spatial_idx,
                                                      const Strides& strides)
{
    long_index_t idx = static_cast<long_index_t>(g) * strides[0] +
                       static_cast<long_index_t>(k) * strides[1];
    for(index_t i = 0; i < NDimSpatial; ++i)
        idx += static_cast<long_index_t>(spatial_idx[i]) * strides[i + 2];
    idx += static_cast<long_index_t>(c) * strides[NDimSpatial + 2];
    return idx;
}

// Linear output offset for grouped convolution.
// Layout: [N, spatial..., G, K]
//   strides[0]                 -> N
//   strides[1 .. NDimSpatial]  -> spatial dimensions
//   strides[NDimSpatial + 1]   -> G
// NOTE(review): k is added without a stride, i.e. K is assumed to be the
// contiguous (unit-stride) dimension; any strides[NDimSpatial + 2] supplied by
// the caller is not read. Confirm callers only pass K-packed outputs.
template <index_t NDimSpatial, typename SpatialIdx, typename Strides>
inline __device__ long_index_t calculate_output_index(index_t n,
                                                      index_t g,
                                                      index_t k,
                                                      const SpatialIdx& spatial_idx,
                                                      const Strides& strides)
{
    long_index_t idx = static_cast<long_index_t>(n) * strides[0];
    for(index_t i = 0; i < NDimSpatial; ++i)
        idx += static_cast<long_index_t>(spatial_idx[i]) * strides[i + 1];
    idx += static_cast<long_index_t>(g) * strides[NDimSpatial + 1] + k;
    return idx;
}

} // namespace detail

} // namespace ck_tile