[CK Tile] batched contraction kernel generalizing (#3126)

* Add help for example

* Refactor the batched contraction reference computation to handle stride-aware calculation, plus some code cleanup

* Add stride-aware reference for batched contraction with independent D tensor layouts

* Add -num_d argument for runtime D tensor count selection in batched contraction

* Add stride vector arguments in example code for testing non-contiguous batched contraction inputs

* Add descriptor-based architecture for batched contraction multi-dimensional stride support

* Add multi-dimensional non-contiguous stride support to batched contraction, num_d = 0

* Add complete multi-dimensional stride support via descriptors

* Enable vectorization in descriptor-based batched contraction. Add pad_tensor_view to local RunGemm

* Clean up batched contraction: remove old UniversalGemmKernel path

* Clean up batched contraction: remove legacy paths and finalize docs

* Optimize batched contraction example: pass dimension sizes not vectors

* correct the reference calculation, unsigned int to int

* Fix batched_contraction C++17 build errors for gfx90a CI
This commit is contained in:
msaffari-amd
2025-12-02 13:30:27 +01:00
committed by GitHub
parent d3f37ebf6c
commit 2d3020e5b0
6 changed files with 694 additions and 313 deletions

View File

@@ -4,8 +4,6 @@
#pragma once
#include <cstdlib>
#include <functional>
#include <numeric>
#include <thread>
#include "ck_tile/core.hpp"
@@ -13,110 +11,259 @@
namespace ck_tile {
// Invokes the CDE elementwise functor for any number of D operands, including none.
template <typename EDataType, typename AccDataType, typename CDEElementWise>
struct ApplyCDEElementWise
{
    // Produces the final value of one output element in `result`.
    //  - With at least one D value: calls cde_elementwise(result, sum, d_vals...), where sum
    //    and every D value are first converted to float via ck_tile::type_convert.
    //  - With no D values: the functor is NOT called; sum is static_cast to EDataType.
    template <typename... DValues>
    CK_TILE_HOST_DEVICE static void apply(EDataType& result,
                                          AccDataType sum,
                                          const CDEElementWise& cde_elementwise,
                                          DValues... d_vals)
    {
        constexpr bool has_d_operands = (sizeof...(DValues) != 0);

        if constexpr(has_d_operands)
        {
            cde_elementwise(
                result, ck_tile::type_convert<float>(sum), ck_tile::type_convert<float>(d_vals)...);
        }
        else
        {
            result = static_cast<EDataType>(sum);
        }
    }
};
// Helper to extract D values at a given offset using index sequence
template <typename DDataType,
ck_tile::index_t NumDTensor,
typename Indices = std::make_index_sequence<NumDTensor>>
struct ExtractDValues;
template <typename DDataType, ck_tile::index_t NumDTensor, std::size_t... Is>
struct ExtractDValues<DDataType, NumDTensor, std::index_sequence<Is...>>
{
template <typename EDataType, typename AccDataType, typename CDEElementWise>
CK_TILE_HOST static void
apply_at_offsets(EDataType& result,
AccDataType sum,
const CDEElementWise& cde_elementwise,
const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_tensors,
const std::array<std::size_t, NumDTensor>& d_offsets)
{
ApplyCDEElementWise<EDataType, AccDataType, CDEElementWise>::apply(
result, sum, cde_elementwise, ds_tensors[Is].mData[d_offsets[Is]]...);
}
};
template <typename ADataType,
typename BDataType,
typename DDataType,
typename EDataType,
typename AccDataType,
typename CDEElementWise>
typename CDEElementWise,
ck_tile::index_t NumDTensor>
void calculate_reference_flat_indexing(
void compute_reference_batched_contraction(
const ck_tile::HostTensor<ADataType>& a_full_dims,
const ck_tile::HostTensor<BDataType>& b_full_dims,
const std::vector<ck_tile::HostTensor<DDataType>>& ds_full_dims_host,
const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_full_dims_host,
ck_tile::HostTensor<EDataType>& e_full_dims_host_ref,
ck_tile::index_t G_total,
ck_tile::index_t M_total,
ck_tile::index_t N_total,
ck_tile::index_t K_total,
const CDEElementWise& cde_elementwise)
const CDEElementWise& cde_elementwise,
const std::vector<ck_tile::index_t>& G_dims,
const std::vector<ck_tile::index_t>& M_dims,
const std::vector<ck_tile::index_t>& N_dims,
const std::vector<ck_tile::index_t>& K_dims)
{
std::cout << "Calculating reference using optimized flat indexing with parallel processing..."
std::cout << "Calculating reference using stride-aware indexing with parallel processing..."
<< std::endl;
// Parallel computation over G and M dimensions using pattern from reference_batched_gemm.hpp
// Extract stride information from tensor descriptors
const auto a_strides = a_full_dims.get_strides();
const auto b_strides = b_full_dims.get_strides();
const auto e_strides = e_full_dims_host_ref.get_strides();
// Extract D tensor strides
std::array<std::vector<std::size_t>, NumDTensor> ds_strides;
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
{
ds_strides[d] = ds_full_dims_host[d].get_strides();
}
const ck_tile::index_t num_g_dims = G_dims.size();
const ck_tile::index_t num_m_dims = M_dims.size();
const ck_tile::index_t num_n_dims = N_dims.size();
const ck_tile::index_t num_k_dims = K_dims.size();
// Helper lambda to compute linear index from flat indices using strides
auto compute_a_offset = [&](ck_tile::index_t g_flat,
ck_tile::index_t m_flat,
ck_tile::index_t k_flat) -> std::size_t {
std::size_t offset = 0;
// Decode G dimensions
ck_tile::index_t temp = g_flat;
for(int i = num_g_dims - 1; i >= 0; --i)
{
offset += (temp % G_dims[i]) * a_strides[i];
temp /= G_dims[i];
}
// Decode M dimensions
temp = m_flat;
for(int i = num_m_dims - 1; i >= 0; --i)
{
offset += (temp % M_dims[i]) * a_strides[num_g_dims + i];
temp /= M_dims[i];
}
// Decode K dimensions
temp = k_flat;
for(int i = num_k_dims - 1; i >= 0; --i)
{
offset += (temp % K_dims[i]) * a_strides[num_g_dims + num_m_dims + i];
temp /= K_dims[i];
}
return offset;
};
auto compute_b_offset = [&](ck_tile::index_t g_flat,
ck_tile::index_t n_flat,
ck_tile::index_t k_flat) -> std::size_t {
std::size_t offset = 0;
// Decode G dimensions
ck_tile::index_t temp = g_flat;
for(int i = num_g_dims - 1; i >= 0; --i)
{
offset += (temp % G_dims[i]) * b_strides[i];
temp /= G_dims[i];
}
// Decode N dimensions
temp = n_flat;
for(int i = num_n_dims - 1; i >= 0; --i)
{
offset += (temp % N_dims[i]) * b_strides[num_g_dims + i];
temp /= N_dims[i];
}
// Decode K dimensions
temp = k_flat;
for(int i = num_k_dims - 1; i >= 0; --i)
{
offset += (temp % K_dims[i]) * b_strides[num_g_dims + num_n_dims + i];
temp /= K_dims[i];
}
return offset;
};
auto compute_e_offset = [&](ck_tile::index_t g_flat,
ck_tile::index_t m_flat,
ck_tile::index_t n_flat) -> std::size_t {
std::size_t offset = 0;
// Decode G dimensions
ck_tile::index_t temp = g_flat;
for(int i = num_g_dims - 1; i >= 0; --i)
{
offset += (temp % G_dims[i]) * e_strides[i];
temp /= G_dims[i];
}
// Decode M dimensions
temp = m_flat;
for(int i = num_m_dims - 1; i >= 0; --i)
{
offset += (temp % M_dims[i]) * e_strides[num_g_dims + i];
temp /= M_dims[i];
}
// Decode N dimensions
temp = n_flat;
for(int i = num_n_dims - 1; i >= 0; --i)
{
offset += (temp % N_dims[i]) * e_strides[num_g_dims + num_m_dims + i];
temp /= N_dims[i];
}
return offset;
};
// Helper to compute D tensor offset (D tensors have same shape as E: [G, M, N])
auto compute_d_offset = [&](ck_tile::index_t g_flat,
ck_tile::index_t m_flat,
ck_tile::index_t n_flat,
ck_tile::index_t d_idx) -> std::size_t {
std::size_t offset = 0;
const auto& d_strides = ds_strides[d_idx];
// Decode G dimensions
ck_tile::index_t temp = g_flat;
for(int i = num_g_dims - 1; i >= 0; --i)
{
offset += (temp % G_dims[i]) * d_strides[i];
temp /= G_dims[i];
}
// Decode M dimensions
temp = m_flat;
for(int i = num_m_dims - 1; i >= 0; --i)
{
offset += (temp % M_dims[i]) * d_strides[num_g_dims + i];
temp /= M_dims[i];
}
// Decode N dimensions
temp = n_flat;
for(int i = num_n_dims - 1; i >= 0; --i)
{
offset += (temp % N_dims[i]) * d_strides[num_g_dims + num_m_dims + i];
temp /= N_dims[i];
}
return offset;
};
// Parallel computation over G and M dimensions
// f_gm processes one (g_flat, m_flat) pair: for every n_flat it accumulates the dot product
// of A and B over k_flat in AccDataType, applies the elementwise functor together with the D
// tensor values, and stores one element of the E reference tensor.
// NOTE(review): this span shows the removed contiguous-layout lines (g*M*K / g*N*K / g*M*N
// addressing and the runtime ds_full_dims_host.size() if-chain) next to their stride-aware
// replacements, so it is not compilable as displayed -- confirm against the actual header.
auto f_gm = [&](auto g_flat, auto m_flat) {
for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
{
AccDataType sum = 0;
// Compute dot product over K dimension
// Compute dot product over K dimension using stride-aware indexing
for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat)
{
auto a_val =
a_full_dims.mData[g_flat * M_total * K_total + m_flat * K_total + k_flat];
auto b_val =
b_full_dims.mData[g_flat * N_total * K_total + n_flat * K_total + k_flat];
const std::size_t a_offset = compute_a_offset(g_flat, m_flat, k_flat);
const std::size_t b_offset = compute_b_offset(g_flat, n_flat, k_flat);
auto a_val = a_full_dims.mData[a_offset];
auto b_val = b_full_dims.mData[b_offset];
// Both operands are widened to AccDataType before the multiply-accumulate.
sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
}
// Apply elementwise operation with D tensors
EDataType result = static_cast<EDataType>(sum);
if(ds_full_dims_host.size() == 0)
// Compute output offset using strides
const std::size_t e_offset = compute_e_offset(g_flat, m_flat, n_flat);
// Compute individual D tensor offsets using their respective strides
std::array<std::size_t, NumDTensor> d_offsets;
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
{
;
}
else if(ds_full_dims_host.size() == 1)
{
cde_elementwise(result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(
ds_full_dims_host[0].mData[g_flat * M_total * N_total +
m_flat * N_total + n_flat]));
}
else if(ds_full_dims_host.size() == 2)
{
cde_elementwise(
result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(
ds_full_dims_host[0]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
ck_tile::type_convert<float>(
ds_full_dims_host[1]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
}
else if(ds_full_dims_host.size() == 3)
{
cde_elementwise(
result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(
ds_full_dims_host[0]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
ck_tile::type_convert<float>(
ds_full_dims_host[1]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
ck_tile::type_convert<float>(
ds_full_dims_host[2]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
}
else if(ds_full_dims_host.size() == 4)
{
cde_elementwise(
result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(
ds_full_dims_host[0]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
ck_tile::type_convert<float>(
ds_full_dims_host[1]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
ck_tile::type_convert<float>(
ds_full_dims_host[2]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
ck_tile::type_convert<float>(
ds_full_dims_host[3]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
}
else
{
throw std::runtime_error("Unsupported NumDTensor for reference calculation");
d_offsets[d] = compute_d_offset(g_flat, m_flat, n_flat, d);
}
// Store result
e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
static_cast<EDataType>(result);
// Apply elementwise operation with D tensors using compile-time dispatch
// ExtractDValues expands over all NumDTensor tensors, reading ds_full_dims_host[i].mData
// at d_offsets[i]; `result` is passed by reference and overwritten by the call.
EDataType result = static_cast<EDataType>(sum);
ExtractDValues<DDataType, NumDTensor>::apply_at_offsets(
result, sum, cde_elementwise, ds_full_dims_host, d_offsets);
// Store result using stride-aware indexing
e_full_dims_host_ref.mData[e_offset] = static_cast<EDataType>(result);
}
};
@@ -125,147 +272,4 @@ void calculate_reference_flat_indexing(
make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename BDataType,
typename DDataType,
typename EDataType,
typename AccDataType,
typename CDEElementWise>
void calculate_reference_multi_dimensional(
const HostTensor<ADataType>& a_full_dims,
const HostTensor<BDataType>& b_full_dims,
const std::vector<HostTensor<DDataType>>& ds_full_dims_host,
HostTensor<EDataType>& e_full_dims_host_ref,
const std::vector<index_t>& G_dims,
const std::vector<index_t>& M_dims,
const std::vector<index_t>& N_dims,
const std::vector<index_t>& K_dims,
const std::vector<index_t>& A_dims,
const std::vector<index_t>& B_dims,
const std::vector<index_t>& E_dims,
const CDEElementWise& cde_elementwise)
{
std::cout << "Calculating reference using multi-dimensional indexing..." << std::endl;
std::vector<std::size_t> g_idx(G_dims.size());
std::vector<std::size_t> m_idx(M_dims.size());
std::vector<std::size_t> n_idx(N_dims.size());
std::vector<std::size_t> k_idx(K_dims.size());
std::vector<std::size_t> a_idx, b_idx, e_idx;
a_idx.reserve(A_dims.size());
b_idx.reserve(B_dims.size());
e_idx.reserve(E_dims.size());
auto calculate_total_elements = [](const std::vector<ck_tile::index_t>& dims) {
return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<ck_tile::index_t>());
};
for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
{
ck_tile::index_t temp = g_flat;
for(int i = G_dims.size() - 1; i >= 0; --i)
{
g_idx[i] = temp % G_dims[i];
temp /= G_dims[i];
}
for(ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat)
{
temp = m_flat;
for(int i = M_dims.size() - 1; i >= 0; --i)
{
m_idx[i] = temp % M_dims[i];
temp /= M_dims[i];
}
for(ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat)
{
temp = n_flat;
for(int i = N_dims.size() - 1; i >= 0; --i)
{
n_idx[i] = temp % N_dims[i];
temp /= N_dims[i];
}
AccDataType sum = 0;
for(ck_tile::index_t k_flat = 0; k_flat < calculate_total_elements(K_dims);
++k_flat)
{
temp = k_flat;
for(int i = K_dims.size() - 1; i >= 0; --i)
{
k_idx[i] = temp % K_dims[i];
temp /= K_dims[i];
}
a_idx.clear();
b_idx.clear();
a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end());
a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end());
a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end());
b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end());
b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end());
b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end());
auto a_val = a_full_dims(a_idx);
auto b_val = b_full_dims(b_idx);
sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
}
e_idx.clear();
e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end());
e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end());
e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());
EDataType result = static_cast<EDataType>(sum);
if(ds_full_dims_host.size() == 0)
{
;
}
else if(ds_full_dims_host.size() == 1)
{
cde_elementwise(result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)));
}
else if(ds_full_dims_host.size() == 2)
{
cde_elementwise(result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)));
}
else if(ds_full_dims_host.size() == 3)
{
cde_elementwise(result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)));
}
else if(ds_full_dims_host.size() == 4)
{
cde_elementwise(result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)),
ck_tile::type_convert<float>(ds_full_dims_host[3](e_idx)));
}
else
{
throw std::runtime_error("Unsupported NumDTensor for reference calculation");
}
e_full_dims_host_ref(e_idx) = static_cast<EDataType>(result);
}
}
}
}
} // namespace ck_tile