mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK Tile] batched contraction kernel generalizing (#3126)
* Add help for example * Refactor the compute reference batched contraction to manage stride-aware calculation and some code cleanups * Add stride-aware reference for batched contraction with independent D tensor layouts * Add -num_d argument for runtime D tensor count selection in batched contraction * Add stride vector arguments in example code for testing non-contiguous batched contraction inputs * Add descriptor-based architecture for batched contraction multi-dimensional stride support * Add multi-dimensional non-contiguous stride support to batched contraction, num_d = 0 * Add complete multi-dimensional stride support via descriptors * Enable vectorization in descriptor-based batched contraction. Add pad_tensor_view to local RunGemm * Clean up batched contraction: remove old UniversalGemmKernel path * Clean up batched contraction: remove legacy paths and finalize docs * Optimize batched contraction example: pass dimension sizes not vectors * Correct the reference calculation, unsigned int to int * Fix batched_contraction C++17 build errors for gfx90a CI
This commit is contained in:
@@ -4,8 +4,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
@@ -13,110 +11,259 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Helper to apply elementwise operation with variable number of D tensors
|
||||
template <typename EDataType, typename AccDataType, typename CDEElementWise>
|
||||
struct ApplyCDEElementWise
|
||||
{
|
||||
template <typename... DValues>
|
||||
CK_TILE_HOST_DEVICE static void apply(EDataType& result,
|
||||
AccDataType sum,
|
||||
const CDEElementWise& cde_elementwise,
|
||||
DValues... d_vals)
|
||||
{
|
||||
if constexpr(sizeof...(DValues) == 0)
|
||||
{
|
||||
result = static_cast<EDataType>(sum);
|
||||
}
|
||||
else
|
||||
{
|
||||
cde_elementwise(
|
||||
result, ck_tile::type_convert<float>(sum), ck_tile::type_convert<float>(d_vals)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Helper to extract D values at a given offset using index sequence
|
||||
template <typename DDataType,
|
||||
ck_tile::index_t NumDTensor,
|
||||
typename Indices = std::make_index_sequence<NumDTensor>>
|
||||
struct ExtractDValues;
|
||||
|
||||
template <typename DDataType, ck_tile::index_t NumDTensor, std::size_t... Is>
|
||||
struct ExtractDValues<DDataType, NumDTensor, std::index_sequence<Is...>>
|
||||
{
|
||||
template <typename EDataType, typename AccDataType, typename CDEElementWise>
|
||||
CK_TILE_HOST static void
|
||||
apply_at_offsets(EDataType& result,
|
||||
AccDataType sum,
|
||||
const CDEElementWise& cde_elementwise,
|
||||
const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_tensors,
|
||||
const std::array<std::size_t, NumDTensor>& d_offsets)
|
||||
{
|
||||
ApplyCDEElementWise<EDataType, AccDataType, CDEElementWise>::apply(
|
||||
result, sum, cde_elementwise, ds_tensors[Is].mData[d_offsets[Is]]...);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CDEElementWise>
|
||||
typename CDEElementWise,
|
||||
ck_tile::index_t NumDTensor>
|
||||
|
||||
void calculate_reference_flat_indexing(
|
||||
void compute_reference_batched_contraction(
|
||||
const ck_tile::HostTensor<ADataType>& a_full_dims,
|
||||
const ck_tile::HostTensor<BDataType>& b_full_dims,
|
||||
const std::vector<ck_tile::HostTensor<DDataType>>& ds_full_dims_host,
|
||||
const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_full_dims_host,
|
||||
ck_tile::HostTensor<EDataType>& e_full_dims_host_ref,
|
||||
ck_tile::index_t G_total,
|
||||
ck_tile::index_t M_total,
|
||||
ck_tile::index_t N_total,
|
||||
ck_tile::index_t K_total,
|
||||
const CDEElementWise& cde_elementwise)
|
||||
const CDEElementWise& cde_elementwise,
|
||||
const std::vector<ck_tile::index_t>& G_dims,
|
||||
const std::vector<ck_tile::index_t>& M_dims,
|
||||
const std::vector<ck_tile::index_t>& N_dims,
|
||||
const std::vector<ck_tile::index_t>& K_dims)
|
||||
{
|
||||
std::cout << "Calculating reference using optimized flat indexing with parallel processing..."
|
||||
std::cout << "Calculating reference using stride-aware indexing with parallel processing..."
|
||||
<< std::endl;
|
||||
|
||||
// Parallel computation over G and M dimensions using pattern from reference_batched_gemm.hpp
|
||||
// Extract stride information from tensor descriptors
|
||||
const auto a_strides = a_full_dims.get_strides();
|
||||
const auto b_strides = b_full_dims.get_strides();
|
||||
const auto e_strides = e_full_dims_host_ref.get_strides();
|
||||
|
||||
// Extract D tensor strides
|
||||
std::array<std::vector<std::size_t>, NumDTensor> ds_strides;
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
ds_strides[d] = ds_full_dims_host[d].get_strides();
|
||||
}
|
||||
|
||||
const ck_tile::index_t num_g_dims = G_dims.size();
|
||||
const ck_tile::index_t num_m_dims = M_dims.size();
|
||||
const ck_tile::index_t num_n_dims = N_dims.size();
|
||||
const ck_tile::index_t num_k_dims = K_dims.size();
|
||||
|
||||
// Helper lambda to compute linear index from flat indices using strides
|
||||
auto compute_a_offset = [&](ck_tile::index_t g_flat,
|
||||
ck_tile::index_t m_flat,
|
||||
ck_tile::index_t k_flat) -> std::size_t {
|
||||
std::size_t offset = 0;
|
||||
|
||||
// Decode G dimensions
|
||||
ck_tile::index_t temp = g_flat;
|
||||
for(int i = num_g_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % G_dims[i]) * a_strides[i];
|
||||
temp /= G_dims[i];
|
||||
}
|
||||
|
||||
// Decode M dimensions
|
||||
temp = m_flat;
|
||||
for(int i = num_m_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % M_dims[i]) * a_strides[num_g_dims + i];
|
||||
temp /= M_dims[i];
|
||||
}
|
||||
|
||||
// Decode K dimensions
|
||||
temp = k_flat;
|
||||
for(int i = num_k_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % K_dims[i]) * a_strides[num_g_dims + num_m_dims + i];
|
||||
temp /= K_dims[i];
|
||||
}
|
||||
|
||||
return offset;
|
||||
};
|
||||
|
||||
auto compute_b_offset = [&](ck_tile::index_t g_flat,
|
||||
ck_tile::index_t n_flat,
|
||||
ck_tile::index_t k_flat) -> std::size_t {
|
||||
std::size_t offset = 0;
|
||||
|
||||
// Decode G dimensions
|
||||
ck_tile::index_t temp = g_flat;
|
||||
for(int i = num_g_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % G_dims[i]) * b_strides[i];
|
||||
temp /= G_dims[i];
|
||||
}
|
||||
|
||||
// Decode N dimensions
|
||||
temp = n_flat;
|
||||
for(int i = num_n_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % N_dims[i]) * b_strides[num_g_dims + i];
|
||||
temp /= N_dims[i];
|
||||
}
|
||||
|
||||
// Decode K dimensions
|
||||
temp = k_flat;
|
||||
for(int i = num_k_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % K_dims[i]) * b_strides[num_g_dims + num_n_dims + i];
|
||||
temp /= K_dims[i];
|
||||
}
|
||||
|
||||
return offset;
|
||||
};
|
||||
|
||||
auto compute_e_offset = [&](ck_tile::index_t g_flat,
|
||||
ck_tile::index_t m_flat,
|
||||
ck_tile::index_t n_flat) -> std::size_t {
|
||||
std::size_t offset = 0;
|
||||
|
||||
// Decode G dimensions
|
||||
ck_tile::index_t temp = g_flat;
|
||||
for(int i = num_g_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % G_dims[i]) * e_strides[i];
|
||||
temp /= G_dims[i];
|
||||
}
|
||||
|
||||
// Decode M dimensions
|
||||
temp = m_flat;
|
||||
for(int i = num_m_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % M_dims[i]) * e_strides[num_g_dims + i];
|
||||
temp /= M_dims[i];
|
||||
}
|
||||
|
||||
// Decode N dimensions
|
||||
temp = n_flat;
|
||||
for(int i = num_n_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % N_dims[i]) * e_strides[num_g_dims + num_m_dims + i];
|
||||
temp /= N_dims[i];
|
||||
}
|
||||
|
||||
return offset;
|
||||
};
|
||||
|
||||
// Helper to compute D tensor offset (D tensors have same shape as E: [G, M, N])
|
||||
auto compute_d_offset = [&](ck_tile::index_t g_flat,
|
||||
ck_tile::index_t m_flat,
|
||||
ck_tile::index_t n_flat,
|
||||
ck_tile::index_t d_idx) -> std::size_t {
|
||||
std::size_t offset = 0;
|
||||
const auto& d_strides = ds_strides[d_idx];
|
||||
|
||||
// Decode G dimensions
|
||||
ck_tile::index_t temp = g_flat;
|
||||
for(int i = num_g_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % G_dims[i]) * d_strides[i];
|
||||
temp /= G_dims[i];
|
||||
}
|
||||
|
||||
// Decode M dimensions
|
||||
temp = m_flat;
|
||||
for(int i = num_m_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % M_dims[i]) * d_strides[num_g_dims + i];
|
||||
temp /= M_dims[i];
|
||||
}
|
||||
|
||||
// Decode N dimensions
|
||||
temp = n_flat;
|
||||
for(int i = num_n_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % N_dims[i]) * d_strides[num_g_dims + num_m_dims + i];
|
||||
temp /= N_dims[i];
|
||||
}
|
||||
|
||||
return offset;
|
||||
};
|
||||
|
||||
// Parallel computation over G and M dimensions
|
||||
auto f_gm = [&](auto g_flat, auto m_flat) {
|
||||
for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
|
||||
{
|
||||
AccDataType sum = 0;
|
||||
|
||||
// Compute dot product over K dimension
|
||||
// Compute dot product over K dimension using stride-aware indexing
|
||||
for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat)
|
||||
{
|
||||
auto a_val =
|
||||
a_full_dims.mData[g_flat * M_total * K_total + m_flat * K_total + k_flat];
|
||||
auto b_val =
|
||||
b_full_dims.mData[g_flat * N_total * K_total + n_flat * K_total + k_flat];
|
||||
const std::size_t a_offset = compute_a_offset(g_flat, m_flat, k_flat);
|
||||
const std::size_t b_offset = compute_b_offset(g_flat, n_flat, k_flat);
|
||||
|
||||
auto a_val = a_full_dims.mData[a_offset];
|
||||
auto b_val = b_full_dims.mData[b_offset];
|
||||
sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
|
||||
}
|
||||
|
||||
// Apply elementwise operation with D tensors
|
||||
EDataType result = static_cast<EDataType>(sum);
|
||||
if(ds_full_dims_host.size() == 0)
|
||||
// Compute output offset using strides
|
||||
const std::size_t e_offset = compute_e_offset(g_flat, m_flat, n_flat);
|
||||
|
||||
// Compute individual D tensor offsets using their respective strides
|
||||
std::array<std::size_t, NumDTensor> d_offsets;
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
;
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 1)
|
||||
{
|
||||
cde_elementwise(result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[0].mData[g_flat * M_total * N_total +
|
||||
m_flat * N_total + n_flat]));
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 2)
|
||||
{
|
||||
cde_elementwise(
|
||||
result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[0]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[1]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 3)
|
||||
{
|
||||
cde_elementwise(
|
||||
result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[0]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[1]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[2]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 4)
|
||||
{
|
||||
cde_elementwise(
|
||||
result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[0]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[1]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[2]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[3]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported NumDTensor for reference calculation");
|
||||
d_offsets[d] = compute_d_offset(g_flat, m_flat, n_flat, d);
|
||||
}
|
||||
|
||||
// Store result
|
||||
e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
|
||||
static_cast<EDataType>(result);
|
||||
// Apply elementwise operation with D tensors using compile-time dispatch
|
||||
EDataType result = static_cast<EDataType>(sum);
|
||||
ExtractDValues<DDataType, NumDTensor>::apply_at_offsets(
|
||||
result, sum, cde_elementwise, ds_full_dims_host, d_offsets);
|
||||
|
||||
// Store result using stride-aware indexing
|
||||
e_full_dims_host_ref.mData[e_offset] = static_cast<EDataType>(result);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -125,147 +272,4 @@ void calculate_reference_flat_indexing(
|
||||
make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CDEElementWise>
|
||||
void calculate_reference_multi_dimensional(
|
||||
const HostTensor<ADataType>& a_full_dims,
|
||||
const HostTensor<BDataType>& b_full_dims,
|
||||
const std::vector<HostTensor<DDataType>>& ds_full_dims_host,
|
||||
HostTensor<EDataType>& e_full_dims_host_ref,
|
||||
const std::vector<index_t>& G_dims,
|
||||
const std::vector<index_t>& M_dims,
|
||||
const std::vector<index_t>& N_dims,
|
||||
const std::vector<index_t>& K_dims,
|
||||
const std::vector<index_t>& A_dims,
|
||||
const std::vector<index_t>& B_dims,
|
||||
const std::vector<index_t>& E_dims,
|
||||
const CDEElementWise& cde_elementwise)
|
||||
{
|
||||
std::cout << "Calculating reference using multi-dimensional indexing..." << std::endl;
|
||||
|
||||
std::vector<std::size_t> g_idx(G_dims.size());
|
||||
std::vector<std::size_t> m_idx(M_dims.size());
|
||||
std::vector<std::size_t> n_idx(N_dims.size());
|
||||
std::vector<std::size_t> k_idx(K_dims.size());
|
||||
std::vector<std::size_t> a_idx, b_idx, e_idx;
|
||||
|
||||
a_idx.reserve(A_dims.size());
|
||||
b_idx.reserve(B_dims.size());
|
||||
e_idx.reserve(E_dims.size());
|
||||
|
||||
auto calculate_total_elements = [](const std::vector<ck_tile::index_t>& dims) {
|
||||
return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<ck_tile::index_t>());
|
||||
};
|
||||
|
||||
for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
|
||||
{
|
||||
ck_tile::index_t temp = g_flat;
|
||||
for(int i = G_dims.size() - 1; i >= 0; --i)
|
||||
{
|
||||
g_idx[i] = temp % G_dims[i];
|
||||
temp /= G_dims[i];
|
||||
}
|
||||
|
||||
for(ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat)
|
||||
{
|
||||
temp = m_flat;
|
||||
for(int i = M_dims.size() - 1; i >= 0; --i)
|
||||
{
|
||||
m_idx[i] = temp % M_dims[i];
|
||||
temp /= M_dims[i];
|
||||
}
|
||||
|
||||
for(ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat)
|
||||
{
|
||||
temp = n_flat;
|
||||
for(int i = N_dims.size() - 1; i >= 0; --i)
|
||||
{
|
||||
n_idx[i] = temp % N_dims[i];
|
||||
temp /= N_dims[i];
|
||||
}
|
||||
|
||||
AccDataType sum = 0;
|
||||
|
||||
for(ck_tile::index_t k_flat = 0; k_flat < calculate_total_elements(K_dims);
|
||||
++k_flat)
|
||||
{
|
||||
temp = k_flat;
|
||||
for(int i = K_dims.size() - 1; i >= 0; --i)
|
||||
{
|
||||
k_idx[i] = temp % K_dims[i];
|
||||
temp /= K_dims[i];
|
||||
}
|
||||
|
||||
a_idx.clear();
|
||||
b_idx.clear();
|
||||
|
||||
a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end());
|
||||
a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end());
|
||||
a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end());
|
||||
|
||||
b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end());
|
||||
b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end());
|
||||
b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end());
|
||||
|
||||
auto a_val = a_full_dims(a_idx);
|
||||
auto b_val = b_full_dims(b_idx);
|
||||
|
||||
sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
|
||||
}
|
||||
|
||||
e_idx.clear();
|
||||
e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end());
|
||||
e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end());
|
||||
e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());
|
||||
|
||||
EDataType result = static_cast<EDataType>(sum);
|
||||
if(ds_full_dims_host.size() == 0)
|
||||
{
|
||||
;
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 1)
|
||||
{
|
||||
cde_elementwise(result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)));
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 2)
|
||||
{
|
||||
cde_elementwise(result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)));
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 3)
|
||||
{
|
||||
cde_elementwise(result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)));
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 4)
|
||||
{
|
||||
cde_elementwise(result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[3](e_idx)));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported NumDTensor for reference calculation");
|
||||
}
|
||||
|
||||
e_full_dims_host_ref(e_idx) = static_cast<EDataType>(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
|
||||
#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
|
||||
/**
|
||||
@@ -60,23 +61,19 @@
|
||||
* Rather than implementing tensor contraction from scratch, this kernel leverages the highly
|
||||
* optimized `UniversalGemmKernel` as its computational backend.
|
||||
*
|
||||
* @subsection current_limitations Current Kernel Limitations
|
||||
* @subsection implementation_features Implementation Features
|
||||
*
|
||||
* **Layout Restrictions:**
|
||||
* - **Row-Major Only**: All tensors must use row-major memory layout
|
||||
* - **Packed Tensors**: Only contiguous/packed tensor layouts supported
|
||||
* - **Hardcoded Strides**: stride_A = K_total, stride_B = K_total, stride_E = N_total
|
||||
* - **D Tensor Layout**: All D tensors must match E tensor layout (stride_Ds = N_total)
|
||||
* **Stride Support:**
|
||||
* - Supports arbitrary multi-dimensional stride patterns
|
||||
* - Handles non-contiguous and padded tensor layouts
|
||||
* - Independent strides for each auxiliary D tensor
|
||||
* - Descriptor-based architecture with vectorization
|
||||
*
|
||||
* **Implementation Constraints:**
|
||||
* - **Fixed Stride Calculation**: Strides are automatically calculated and cannot be customized
|
||||
* - **No Column-Major**: Column-major or custom stride patterns not supported
|
||||
* - **No Strided Access**: Non-contiguous tensor slicing not supported
|
||||
*
|
||||
* **Future Enhancements:**
|
||||
* - Support for arbitrary stride patterns
|
||||
* - Column-major and mixed layout support
|
||||
* - Non-contiguous tensor operation support
|
||||
* **Architecture:**
|
||||
* - Uses TensorDescriptorUtils for stride-aware descriptor creation
|
||||
* - Custom RunGemm implementation with descriptor-based tensor views
|
||||
* - Reuses GemmPipeline and EpiloguePipeline for computation
|
||||
* - Split-K support via UniversalGemmKernel utilities
|
||||
*/
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -184,7 +181,10 @@ template <ck_tile::index_t NumDimG,
|
||||
ck_tile::index_t NumDimM,
|
||||
ck_tile::index_t NumDimN,
|
||||
ck_tile::index_t NumDimK,
|
||||
ck_tile::index_t NumDTensor = 0>
|
||||
ck_tile::index_t NumDTensor = 0,
|
||||
ck_tile::index_t VectorSizeA = 1,
|
||||
ck_tile::index_t VectorSizeB = 1,
|
||||
ck_tile::index_t VectorSizeE = 1>
|
||||
struct BatchedContractionKernelArgs
|
||||
{
|
||||
const void* a_ptr; ///< Pointer to input tensor A
|
||||
@@ -210,11 +210,46 @@ struct BatchedContractionKernelArgs
|
||||
ck_tile::index_t N_total; ///< Total N dimension: N0 * N1 * ... * N_{NumDimN-1}
|
||||
ck_tile::index_t K_total; ///< Total K dimension: K0 * K1 * ... * K_{NumDimK-1}
|
||||
|
||||
ck_tile::index_t stride_A; ///< Leading dimension stride for tensor A (row-major: K_total)
|
||||
ck_tile::index_t stride_B; ///< Leading dimension stride for tensor B (row-major: K_total)
|
||||
ck_tile::index_t
|
||||
stride_A; ///< Leading dimension stride for tensor A (for backward compatibility)
|
||||
ck_tile::index_t
|
||||
stride_B; ///< Leading dimension stride for tensor B (for backward compatibility)
|
||||
std::array<ck_tile::index_t, NumDTensor>
|
||||
stride_Ds; ///< Leading dimension strides for D tensors (row-major: N_total)
|
||||
ck_tile::index_t stride_E; ///< Leading dimension stride for tensor E (row-major: N_total)
|
||||
stride_Ds; ///< Leading dimension strides for D tensors (for backward compatibility)
|
||||
ck_tile::index_t
|
||||
stride_E; ///< Leading dimension stride for tensor E (for backward compatibility)
|
||||
|
||||
// Tensor descriptors (encode full multi-dimensional stride information with vectorization)
|
||||
using AGridDesc_M_K_ =
|
||||
decltype(TensorDescriptorUtils<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeE>::Make_A_GridDescriptor_M_K({}, {}));
|
||||
using BGridDesc_N_K_ =
|
||||
decltype(TensorDescriptorUtils<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeE>::Make_B_GridDescriptor_N_K({}, {}));
|
||||
using EGridDesc_M_N_ =
|
||||
decltype(TensorDescriptorUtils<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeE>::Make_E_GridDescriptor_M_N({}, {}));
|
||||
|
||||
AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides
|
||||
BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides
|
||||
EGridDesc_M_N_ e_grid_desc_m_n; ///< Tensor descriptor for E[M, N] with actual strides
|
||||
std::array<EGridDesc_M_N_, NumDTensor>
|
||||
ds_grid_desc_m_n; ///< Descriptors for D tensors (same shape as E, independent strides)
|
||||
};
|
||||
|
||||
/// @brief GPU kernel for batched tensor contraction operations.
|
||||
@@ -274,10 +309,24 @@ struct BatchedContractionKernel
|
||||
static constexpr ck_tile::index_t kBlockSize =
|
||||
UniversalGemmKernel::kBlockSize; ///< GPU block size inherited from GEMM kernel
|
||||
|
||||
using KernelArgs =
|
||||
BatchedContractionKernelArgs<NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor>; ///< Kernel
|
||||
///< argument
|
||||
///< structure
|
||||
// Tensor descriptor utilities with vectorization support
|
||||
using DescriptorUtils = TensorDescriptorUtils<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
GemmPipeline::GetVectorSizeA(),
|
||||
GemmPipeline::GetVectorSizeB(),
|
||||
EpiloguePipeline::GetVectorSizeC()>;
|
||||
|
||||
// Kernel arguments with vectorization support
|
||||
using KernelArgs = BatchedContractionKernelArgs<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
NumDTensor,
|
||||
GemmPipeline::GetVectorSizeA(),
|
||||
GemmPipeline::GetVectorSizeB(),
|
||||
EpiloguePipeline::GetVectorSizeC()>;
|
||||
|
||||
/// @brief Returns the kernel name for debugging and profiling purposes.
|
||||
/// @return Constant string identifier for this kernel
|
||||
@@ -326,6 +375,104 @@ struct BatchedContractionKernel
|
||||
TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
|
||||
}
|
||||
|
||||
/// @brief Executes GEMM computation with descriptor-based tensor views for arbitrary stride
|
||||
/// support
|
||||
///
|
||||
/// @details This function performs the core GEMM computation using tensor descriptors to handle
|
||||
/// arbitrary multi-dimensional stride patterns. It creates tensor views from
|
||||
/// pre-computed descriptors (stored in kargs), applies padding, creates tile windows,
|
||||
/// and executes the GemmPipeline and EpiloguePipeline.
|
||||
///
|
||||
/// @param a_ptr Pointer to input tensor A data (after batch and split-K offsets applied)
|
||||
/// @param b_ptr Pointer to input tensor B data (after batch and split-K offsets applied)
|
||||
/// @param ds_ptr Array of pointers to auxiliary D tensor data
|
||||
/// @param e_ptr Pointer to output tensor E data (after batch offset applied)
|
||||
/// @param smem_ptr Pointer to shared memory for tile operations
|
||||
/// @param kargs Kernel arguments containing tensor descriptors and dimension information
|
||||
/// @param k_size Size of K dimension for this split (for split-K support)
|
||||
/// @param i_m Starting M index for this block's tile
|
||||
/// @param i_n Starting N index for this block's tile
|
||||
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Create tensor views from descriptors (supports arbitrary stride patterns)
|
||||
auto a_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
|
||||
auto b_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
|
||||
auto e_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);
|
||||
|
||||
// Pad views for boundary handling and optimization (like UniversalGemmKernel)
|
||||
auto a_pad_view = pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
|
||||
auto b_pad_view = pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
|
||||
auto e_pad_view = pad_tensor_view(
|
||||
e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
|
||||
// Create tile windows from PADDED views
|
||||
auto a_block_window = make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
|
||||
auto b_block_window = make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
|
||||
auto e_block_window = make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
// Calculate number of K loops
|
||||
const index_t num_loop =
|
||||
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(k_size));
|
||||
|
||||
// Run GEMM Pipeline (same as UniversalGemmKernel, but with descriptor-based windows)
|
||||
using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}(
|
||||
a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);
|
||||
|
||||
// Create D windows from descriptors (for each D tensor)
|
||||
auto ds_block_windows = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||
const DDataType* d_ptr = static_cast<const DDataType*>(ds_ptr[i]);
|
||||
|
||||
auto d_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(d_ptr, kargs.ds_grid_desc_m_n[i]);
|
||||
|
||||
return make_tile_window(d_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// Run Epilogue Pipeline with descriptor-based D windows
|
||||
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr KernelArgs
|
||||
MakeKernelArgs(const BatchedContractionHostArgs<NumDTensor>& host_args)
|
||||
{
|
||||
@@ -435,6 +582,22 @@ struct BatchedContractionKernel
|
||||
kargs.K_total *= kargs.K_dims[i];
|
||||
}
|
||||
|
||||
// Create tensor descriptors on host using actual dims and strides
|
||||
kargs.a_grid_desc_m_k =
|
||||
DescriptorUtils::Make_A_GridDescriptor_M_K(host_args.A_dims, host_args.A_strides);
|
||||
kargs.b_grid_desc_n_k =
|
||||
DescriptorUtils::Make_B_GridDescriptor_N_K(host_args.B_dims, host_args.B_strides);
|
||||
kargs.e_grid_desc_m_n =
|
||||
DescriptorUtils::Make_E_GridDescriptor_M_N(host_args.E_dims, host_args.E_strides);
|
||||
|
||||
// Create D descriptors with their own strides (same shape as E, independent strides)
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
kargs.ds_grid_desc_m_n[d] = DescriptorUtils::Make_E_GridDescriptor_M_N(
|
||||
host_args.Ds_dims[d], host_args.Ds_strides[d]);
|
||||
}
|
||||
|
||||
// Keep simple strides for backward compatibility
|
||||
kargs.stride_A = kargs.K_total;
|
||||
kargs.stride_B = kargs.K_total;
|
||||
kargs.stride_E = kargs.N_total;
|
||||
@@ -468,8 +631,8 @@ struct BatchedContractionKernel
|
||||
const ck_tile::index_t i_n =
|
||||
__builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
[[maybe_unused]] const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
|
||||
// Calculate batch offsets for each tensor
|
||||
const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A;
|
||||
@@ -487,6 +650,10 @@ struct BatchedContractionKernel
|
||||
ds_batch_ptr[i] = static_cast<const DDataType*>(kargs.ds_ptr[i]) + batch_offset_D;
|
||||
});
|
||||
|
||||
// Allocate shared memory
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// Use UniversalGemmKernel's SplitKBatchOffset for split-K calculation
|
||||
typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
|
||||
{b_ptr},
|
||||
ds_batch_ptr,
|
||||
@@ -503,19 +670,19 @@ struct BatchedContractionKernel
|
||||
const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
|
||||
i_splitk);
|
||||
|
||||
const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
|
||||
const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
// Apply K-split offsets and run descriptor-based RunGemm
|
||||
const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0];
|
||||
const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
|
||||
|
||||
UniversalGemmKernel::RunGemm({a_ptr_final},
|
||||
{b_ptr_final},
|
||||
ds_batch_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
gemm_kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
RunGemm(a_ptr_split,
|
||||
b_ptr_split,
|
||||
ds_batch_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
splitk_batch_offset.splitted_k,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
* dimensions for GEMM operations. These functions transform multi-dimensional tensors into
|
||||
* 2D matrix descriptors by removing batch dimensions and flattening the remaining dimensions.
|
||||
*
|
||||
* These utilities are currently not used in the main batched contraction kernel but are preserved
|
||||
* for future implementations that may require explicit tensor descriptor creation.
|
||||
* These utilities are used by BatchedContractionKernel to create stride-aware descriptors
|
||||
* that support arbitrary multi-dimensional non-contiguous tensor layouts.
|
||||
*/
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -30,7 +30,10 @@ namespace ck_tile {
|
||||
template <ck_tile::index_t NumDimG,
|
||||
ck_tile::index_t NumDimM,
|
||||
ck_tile::index_t NumDimN,
|
||||
ck_tile::index_t NumDimK>
|
||||
ck_tile::index_t NumDimK,
|
||||
ck_tile::index_t VectorSizeA,
|
||||
ck_tile::index_t VectorSizeB,
|
||||
ck_tile::index_t VectorSizeE>
|
||||
struct TensorDescriptorUtils
|
||||
{
|
||||
/// @brief Creates a tensor descriptor for input tensor A with batch dimensions removed.
|
||||
@@ -62,9 +65,9 @@ struct TensorDescriptorUtils
|
||||
const auto dims_M = get_container_subset(A_dims_M_K, A_dims_M_ids);
|
||||
const auto dims_K = get_container_subset(A_dims_M_K, A_dims_K_ids);
|
||||
|
||||
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Discriptor
|
||||
const auto A_grid_desc_Ms_Ks =
|
||||
ck_tile::make_naive_tensor_descriptor(A_dims_M_K, A_strides_M_K);
|
||||
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Descriptor with vector size
|
||||
const auto A_grid_desc_Ms_Ks = ck_tile::make_naive_tensor_descriptor(
|
||||
A_dims_M_K, A_strides_M_K, number<VectorSizeA>{}, number<1>{});
|
||||
|
||||
// transformed tensor to flatten M and K dimensions [M_total = M0 * M1 * M2 * ... , K_total
|
||||
// = K0 * K1 * K2 * ...]
|
||||
@@ -106,9 +109,9 @@ struct TensorDescriptorUtils
|
||||
const auto dims_N = get_container_subset(B_dims_N_K, B_dims_N_ids);
|
||||
const auto dims_K = get_container_subset(B_dims_N_K, B_dims_K_ids);
|
||||
|
||||
// naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Discriptor
|
||||
const auto B_grid_desc_Ns_Ks =
|
||||
ck_tile::make_naive_tensor_descriptor(B_dims_N_K, B_strides_N_K);
|
||||
// naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Descriptor with vector size
|
||||
const auto B_grid_desc_Ns_Ks = ck_tile::make_naive_tensor_descriptor(
|
||||
B_dims_N_K, B_strides_N_K, number<VectorSizeB>{}, number<1>{});
|
||||
|
||||
// transformed tensor to flatten N and K dimensions [N_total = N0 * N1 * N2 * ... , K_total
|
||||
// = K0 * K1 * K2 * ...]
|
||||
@@ -150,9 +153,9 @@ struct TensorDescriptorUtils
|
||||
const auto dims_M = get_container_subset(E_dims_M_N, E_dims_M_ids);
|
||||
const auto dims_N = get_container_subset(E_dims_M_N, E_dims_N_ids);
|
||||
|
||||
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Discriptor
|
||||
const auto E_grid_desc_Ms_Ns =
|
||||
ck_tile::make_naive_tensor_descriptor(E_dims_M_N, E_strides_M_N);
|
||||
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Descriptor with vector size
|
||||
const auto E_grid_desc_Ms_Ns = ck_tile::make_naive_tensor_descriptor(
|
||||
E_dims_M_N, E_strides_M_N, number<VectorSizeE>{}, number<1>{});
|
||||
|
||||
// transformed tensor to flatten M and N dimensions [M_total = M0 * M1 * M2 * ... ,
|
||||
// N_total = N0 * N1 * N2 * ...]
|
||||
|
||||
Reference in New Issue
Block a user