[CK Tile] batched contraction kernel generalizing (#3126)

* Add help for example

* Refactore the compute reference batched contraction to manage stride-aware calculation and some code cleanings

* Add stride-aware reference for batched contraction with independent D tensor layouts

* Add -num_d argument for runtime D tensor count selection in batched contraction

* Add stride vector arguments in example code for testing non-contiguous batched contraction inputs

* Add descriptor-based architecture for batched contraction multi-dimensional stride support

* Add multi-dimensional non-contiguous stride support to batched contraction, num_d = 0

* Add complete multi-dimensional stride support via descriptors

* Enable vectorization in descriptor-based batched contraction. Add pad_tensor_view to local RunGemm

* Clean up batched contraction: remove old UniversalGemmKernel path

* Clean up batched contraction: remove legacy paths and finalize docs

* Optimize batched contraction example: pass dimension sizes not vectors

* correct the reference calculation, unsigned int to int

* Fix batched_contraction C++17 build errors for gfx90a CI

[ROCm/composable_kernel commit: 2d3020e5b0]
This commit is contained in:
msaffari-amd
2025-12-02 13:30:27 +01:00
committed by GitHub
parent 3ba598e05d
commit b06b6e684c
6 changed files with 694 additions and 313 deletions

View File

@@ -219,9 +219,7 @@ float batched_contraction(const ck_tile::BatchedContractionHostArgs<DsDataType::
HANDLE_CASE(2, 1, 1, 1);
HANDLE_CASE(2, 2, 2, 1);
HANDLE_CASE(1, 2, 1, 1);
HANDLE_CASE(1, 1, 1, 2);
HANDLE_CASE(2, 2, 2, 2);
HANDLE_CASE(4, 4, 4, 4);
throw std::runtime_error(
"Unsupported dimension combination: G=" + std::to_string(num_g_dims) +

View File

@@ -42,17 +42,83 @@ using AccDataType = ContractionTypes::AccDataType;
using EDataType = ContractionTypes::EDataType;
using DDataType = ContractionTypes::DDataType;
void print_help(const char* program_name)
{
std::cout << "\n";
std::cout << "Batched Tensor Contraction with element-wise fusion\n";
std::cout << "E[G,M,N] = element_wise_op(contraction(A[G,M,K], B[G,N,K]), D0, D1, ...)\n";
std::cout << "(Supports multiple D tensors with configurable element-wise operations)\n\n";
std::cout << "Usage: " << program_name << " [OPTIONS]\n\n";
std::cout << "Dimension Arguments (comma-separated, no spaces):\n";
std::cout << " -g_dims=<dims> Batch dimensions (default: \"1,2\")\n";
std::cout << " -m_dims=<dims> M (row) dimensions (default: \"4,256\")\n";
std::cout << " -n_dims=<dims> N (column) dimensions (default: \"16,128\")\n";
std::cout << " -k_dims=<dims> K (contract) dims (default: \"64\")\n";
std::cout << " -num_d=<int> Number of D tensors (default: 2, range: 0-4)\n\n";
std::cout << "Custom Stride Arguments (for testing non-contiguous tensors):\n";
std::cout << " -strides_a=<s> A tensor strides (comma-separated, empty = auto)\n";
std::cout << " -strides_b=<s> B tensor strides (comma-separated, empty = auto)\n";
std::cout << " -strides_e=<s> E tensor strides (comma-separated, empty = auto)\n";
std::cout << " -strides_ds=<s> D tensors strides (semicolon-separated, empty = same as E)\n";
std::cout << " Example: -strides_a=\"32768,128,1\" -strides_ds=\"512,2,1;1024,4,1\"\n\n";
std::cout << "Layout Arguments:\n";
std::cout
<< " -a_layout=<R|C> A tensor layout (R=Row-major, C=Column-major, default: \"R\")\n";
std::cout << " -b_layout=<R|C> B tensor layout (default: \"C\")\n";
std::cout << " -e_layout=<R|C> E tensor layout (default: \"R\")\n\n";
std::cout << "Examples:\n";
std::cout << " Single batch (12 batches of 256×128):\n";
std::cout << " " << program_name
<< " -g_dims=\"12\" -m_dims=\"256\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
std::cout << " 2D batch grid (2×3=6 batches):\n";
std::cout << " " << program_name
<< " -g_dims=\"2,3\" -m_dims=\"128\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
std::cout << " Multi-dimensional (flattened to M=128, N=128, K=128):\n";
std::cout << " " << program_name
<< " -g_dims=\"4\" -m_dims=\"8,16\" -n_dims=\"32,4\" -k_dims=\"16,8\"\n\n";
std::cout << "Other Options:\n";
std::cout << " -v=<0|1> Validation (0=off, 1=on, default: 1)\n";
std::cout << " -split_k=<int> Split-K value (default: 1)\n";
std::cout << " -warmup=<int> Warmup iterations (default: 5)\n";
std::cout << " -repeat=<int> Benchmark iterations (default: 10)\n";
std::cout << " -log=<0|1> Logging level (default: 1)\n";
std::cout << " -help Show this help\n\n";
}
auto create_args(int argc, char* argv[])
{
// Check for --help flag
for(int i = 1; i < argc; ++i)
{
std::string arg = argv[i];
if(arg == "--help" || arg == "-h" || arg == "-help")
{
print_help(argv[0]);
std::exit(0);
}
}
ck_tile::ArgParser arg_parser;
arg_parser.insert("m_dims", "4,256", "M dimensions separated by comma (e.g., '16,32' for 2D M)")
.insert("n_dims", "16,128", "N dimensions separated by comma (e.g., '32,32' for 2D N)")
.insert("k_dims", "64", "K dimensions separated by comma (e.g., '64,32' for 2D K)")
.insert(
"g_dims", "1,2", "G dimensions separated by comma (e.g., '4,2' for 2D, '2,3,4' for 3D)")
.insert("stride_a", "0", "Custom A tensor leading dimension stride (0 = auto)")
.insert("stride_b", "0", "Custom B tensor leading dimension stride (0 = auto)")
.insert("stride_e", "0", "Custom E tensor leading dimension stride (0 = auto)")
.insert("num_d", "2", "Number of D (auxiliary input) tensors")
.insert("strides_a", "", "A tensor strides (comma-separated, empty = auto/contiguous)")
.insert("strides_b", "", "B tensor strides (comma-separated, empty = auto/contiguous)")
.insert("strides_e", "", "E tensor strides (comma-separated, empty = auto/contiguous)")
.insert("strides_ds",
"",
"D tensors strides (semicolon-separated for multiple, empty = same as E)")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Col by default")
.insert("e_layout", "R", "E tensor data layout - Row by default")

View File

@@ -45,10 +45,10 @@ float invoke_batched_contraction_kernel(
const void* b_full_dims_dev_buf,
const std::array<const void*, DsDataType::size()>& ds_dev_buf,
void* e_full_dims_dev_buf,
const std::vector<ck_tile::index_t>& G_dims,
const std::vector<ck_tile::index_t>& M_dims,
const std::vector<ck_tile::index_t>& N_dims,
const std::vector<ck_tile::index_t>& K_dims,
ck_tile::index_t num_g_dims,
ck_tile::index_t num_m_dims,
ck_tile::index_t num_n_dims,
ck_tile::index_t num_k_dims,
const std::vector<ck_tile::index_t>& A_dims, // [G0,G1,..,M0,M1,..,K0,K1,..]
const std::vector<ck_tile::index_t>& B_dims, // [G0,G1,..,N0,N1,..,K0,K1,..]
const std::array<std::vector<ck_tile::index_t>, DsDataType::size()>&
@@ -79,9 +79,8 @@ float invoke_batched_contraction_kernel(
E_strides // E_strides
);
std::cout << "Calling batched_contraction with dimensions: G=" << G_dims.size()
<< ", M=" << M_dims.size() << ", N=" << N_dims.size() << ", K=" << K_dims.size()
<< std::endl;
std::cout << "Calling batched_contraction with dimensions: G=" << num_g_dims
<< ", M=" << num_m_dims << ", N=" << num_n_dims << ", K=" << num_k_dims << std::endl;
float ave_time = batched_contraction<ADataType,
BDataType,
@@ -95,16 +94,38 @@ float invoke_batched_contraction_kernel(
CDEElementWise>(
args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
G_dims.size(), // num_g_dims
M_dims.size(), // num_m_dims
N_dims.size(), // num_n_dims
K_dims.size() // num_k_dims
);
num_g_dims,
num_m_dims,
num_n_dims,
num_k_dims);
return ave_time;
}
template <typename ALayout, typename BLayout, typename DLayout, typename ELayout>
// C++17-compatible helper function to create array of HostTensors
namespace {
template <typename DDataType, std::size_t NumDTensor, std::size_t... Is>
std::array<ck_tile::HostTensor<DDataType>, NumDTensor>
make_ds_host_tensors_impl(const std::array<ck_tile::HostTensorDescriptor, NumDTensor>& descs,
std::index_sequence<Is...>)
{
return {ck_tile::HostTensor<DDataType>(descs[Is])...};
}
template <typename DDataType, std::size_t NumDTensor>
std::array<ck_tile::HostTensor<DDataType>, NumDTensor>
make_ds_host_tensors(const std::array<ck_tile::HostTensorDescriptor, NumDTensor>& descs)
{
return make_ds_host_tensors_impl<DDataType, NumDTensor>(descs,
std::make_index_sequence<NumDTensor>{});
}
} // anonymous namespace
template <typename ALayout,
typename BLayout,
typename DLayout,
typename ELayout,
ck_tile::index_t NumDTensor>
int run_batched_contraction_example_with_layouts(
int argc,
char* argv[],
@@ -122,8 +143,6 @@ int run_batched_contraction_example_with_layouts(
std::vector<ck_tile::index_t> N_dims = parse_dimensions(arg_parser.get_str("n_dims"));
std::vector<ck_tile::index_t> K_dims = parse_dimensions(arg_parser.get_str("k_dims"));
constexpr ck_tile::index_t NumDTensor = 2;
ck_tile::index_t G_total = calculate_total_elements(G_dims);
ck_tile::index_t M_total = calculate_total_elements(M_dims);
ck_tile::index_t N_total = calculate_total_elements(N_dims);
@@ -148,13 +167,105 @@ int run_batched_contraction_example_with_layouts(
return converted;
};
ck_tile::HostTensorDescriptor a_desc(A_dims);
ck_tile::HostTensorDescriptor b_desc(B_dims);
ck_tile::HostTensorDescriptor e_desc(E_dims);
std::array<ck_tile::HostTensorDescriptor, NumDTensor> ds_descs;
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
// Get custom stride arguments
std::string strides_a_str = arg_parser.get_str("strides_a");
std::string strides_b_str = arg_parser.get_str("strides_b");
std::string strides_e_str = arg_parser.get_str("strides_e");
std::string strides_ds_str = arg_parser.get_str("strides_ds");
// Create A descriptor with custom or default strides
ck_tile::HostTensorDescriptor a_desc;
if(!strides_a_str.empty())
{
ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides());
std::vector<ck_tile::index_t> custom_a_strides = parse_dimensions(strides_a_str);
if(custom_a_strides.size() != A_dims.size())
{
throw std::runtime_error("strides_a size must match A_dims size");
}
std::vector<std::size_t> a_strides_size_t(custom_a_strides.begin(), custom_a_strides.end());
a_desc = ck_tile::HostTensorDescriptor(A_dims, a_strides_size_t);
std::cout << "Using custom strides for A (non-contiguous)" << std::endl;
}
else
{
a_desc = ck_tile::HostTensorDescriptor(A_dims);
}
// Create B descriptor with custom or default strides
ck_tile::HostTensorDescriptor b_desc;
if(!strides_b_str.empty())
{
std::vector<ck_tile::index_t> custom_b_strides = parse_dimensions(strides_b_str);
if(custom_b_strides.size() != B_dims.size())
{
throw std::runtime_error("strides_b size must match B_dims size");
}
std::vector<std::size_t> b_strides_size_t(custom_b_strides.begin(), custom_b_strides.end());
b_desc = ck_tile::HostTensorDescriptor(B_dims, b_strides_size_t);
std::cout << "Using custom strides for B (non-contiguous)" << std::endl;
}
else
{
b_desc = ck_tile::HostTensorDescriptor(B_dims);
}
// Create E descriptor with custom or default strides
ck_tile::HostTensorDescriptor e_desc;
if(!strides_e_str.empty())
{
std::vector<ck_tile::index_t> custom_e_strides = parse_dimensions(strides_e_str);
if(custom_e_strides.size() != E_dims.size())
{
throw std::runtime_error("strides_e size must match E_dims size");
}
std::vector<std::size_t> e_strides_size_t(custom_e_strides.begin(), custom_e_strides.end());
e_desc = ck_tile::HostTensorDescriptor(E_dims, e_strides_size_t);
std::cout << "Using custom strides for E (non-contiguous)" << std::endl;
}
else
{
e_desc = ck_tile::HostTensorDescriptor(E_dims);
}
// Create D descriptors with custom or default strides (default = same as E)
std::array<ck_tile::HostTensorDescriptor, NumDTensor> ds_descs;
if(!strides_ds_str.empty())
{
// Parse semicolon-separated stride vectors for multiple D tensors
std::vector<std::vector<ck_tile::index_t>> all_ds_strides;
std::stringstream ss(strides_ds_str);
std::string d_stride_str;
while(std::getline(ss, d_stride_str, ';'))
{
all_ds_strides.push_back(parse_dimensions(d_stride_str));
}
if(all_ds_strides.size() != NumDTensor)
{
throw std::runtime_error("Number of D stride vectors must match num_d=" +
std::to_string(NumDTensor));
}
std::cout << "Using custom strides for D tensors (non-contiguous)" << std::endl;
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
{
if(all_ds_strides[d].size() != E_dims.size())
{
throw std::runtime_error("D tensor " + std::to_string(d) +
" stride size must match E_dims size");
}
std::vector<std::size_t> d_strides_size_t(all_ds_strides[d].begin(),
all_ds_strides[d].end());
ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], d_strides_size_t);
}
}
else
{
// Default: use same strides as E
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
{
ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides());
}
}
std::vector<ck_tile::index_t> A_strides = convert_strides(a_desc.get_strides());
@@ -201,11 +312,8 @@ int run_batched_contraction_example_with_layouts(
ck_tile::HostTensor<::BDataType> b_full_dims_host(b_desc);
ck_tile::HostTensor<::EDataType> e_full_dims_host(e_desc);
std::vector<ck_tile::HostTensor<::DDataType>> ds_full_dims_host;
for(int d = 0; d < NumDTensor; ++d)
{
ds_full_dims_host.emplace_back(ck_tile::HostTensor<::DDataType>(ds_descs[d]));
}
// Construct array of HostTensors - C++17 compatible
auto ds_full_dims_host = make_ds_host_tensors<::DDataType, NumDTensor>(ds_descs);
ck_tile::FillUniformDistribution<::ADataType>{-5.f, 5.f, std::nullopt}(a_full_dims_host);
ck_tile::FillUniformDistribution<::BDataType>{-5.f, 5.f, std::nullopt}(b_full_dims_host);
@@ -260,10 +368,10 @@ int run_batched_contraction_example_with_layouts(
b_full_dims_dev_buf.GetDeviceBuffer(),
ds_ptr_buf,
e_full_dims_dev_buf.GetDeviceBuffer(),
G_dims,
M_dims,
N_dims,
K_dims,
G_dims.size(),
M_dims.size(),
N_dims.size(),
K_dims.size(),
A_dims,
B_dims,
Ds_dims,
@@ -316,12 +424,13 @@ int run_batched_contraction_example_with_layouts(
auto start_time = std::chrono::high_resolution_clock::now();
ck_tile::calculate_reference_flat_indexing<ADataType,
BDataType,
DDataType,
EDataType,
AccDataType,
CDEElementWise>(a_full_dims_host,
ck_tile::compute_reference_batched_contraction<ADataType,
BDataType,
DDataType,
EDataType,
AccDataType,
CDEElementWise,
NumDTensor>(a_full_dims_host,
b_full_dims_host,
ds_full_dims_host,
e_full_dims_host_ref,
@@ -329,7 +438,11 @@ int run_batched_contraction_example_with_layouts(
M_total,
N_total,
K_total,
CDEElementWise{});
CDEElementWise{},
G_dims,
M_dims,
N_dims,
K_dims);
auto end_time = std::chrono::high_resolution_clock::now();
auto duration =
@@ -387,15 +500,45 @@ int run_batched_contraction_example(int argc, char* argv[])
if(!result)
return -1;
// Get NumDTensor to dispatch at runtime
const int num_d = arg_parser.get_int("num_d");
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
// Runtime dispatch based on num_d value
if(a_layout == "R" && b_layout == "C")
{
return run_batched_contraction_example_with_layouts(argc, argv, Row{}, Col{}, Row{}, Row{});
// Dispatch to appropriate template instantiation based on runtime num_d
switch(num_d)
{
case 0:
std::cout << "Running with 0 D tensors" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 0>(
argc, argv, Row{}, Col{}, Row{}, Row{});
case 1:
std::cout << "Running with 1 D tensor" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 1>(
argc, argv, Row{}, Col{}, Row{}, Row{});
case 2:
std::cout << "Running with 2 D tensors" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 2>(
argc, argv, Row{}, Col{}, Row{}, Row{});
case 3:
std::cout << "Running with 3 D tensors" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 3>(
argc, argv, Row{}, Col{}, Row{}, Row{});
case 4:
std::cout << "Running with 4 D tensors" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 4>(
argc, argv, Row{}, Col{}, Row{}, Row{});
default:
throw std::runtime_error("num_d must be between 0 and 4, got: " +
std::to_string(num_d));
}
}
else
{

View File

@@ -4,8 +4,6 @@
#pragma once
#include <cstdlib>
#include <functional>
#include <numeric>
#include <thread>
#include "ck_tile/core.hpp"
@@ -13,110 +11,259 @@
namespace ck_tile {
// Helper to apply elementwise operation with variable number of D tensors
template <typename EDataType, typename AccDataType, typename CDEElementWise>
struct ApplyCDEElementWise
{
template <typename... DValues>
CK_TILE_HOST_DEVICE static void apply(EDataType& result,
AccDataType sum,
const CDEElementWise& cde_elementwise,
DValues... d_vals)
{
if constexpr(sizeof...(DValues) == 0)
{
result = static_cast<EDataType>(sum);
}
else
{
cde_elementwise(
result, ck_tile::type_convert<float>(sum), ck_tile::type_convert<float>(d_vals)...);
}
}
};
// Helper to extract D values at a given offset using index sequence
template <typename DDataType,
ck_tile::index_t NumDTensor,
typename Indices = std::make_index_sequence<NumDTensor>>
struct ExtractDValues;
template <typename DDataType, ck_tile::index_t NumDTensor, std::size_t... Is>
struct ExtractDValues<DDataType, NumDTensor, std::index_sequence<Is...>>
{
template <typename EDataType, typename AccDataType, typename CDEElementWise>
CK_TILE_HOST static void
apply_at_offsets(EDataType& result,
AccDataType sum,
const CDEElementWise& cde_elementwise,
const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_tensors,
const std::array<std::size_t, NumDTensor>& d_offsets)
{
ApplyCDEElementWise<EDataType, AccDataType, CDEElementWise>::apply(
result, sum, cde_elementwise, ds_tensors[Is].mData[d_offsets[Is]]...);
}
};
template <typename ADataType,
typename BDataType,
typename DDataType,
typename EDataType,
typename AccDataType,
typename CDEElementWise>
typename CDEElementWise,
ck_tile::index_t NumDTensor>
void calculate_reference_flat_indexing(
void compute_reference_batched_contraction(
const ck_tile::HostTensor<ADataType>& a_full_dims,
const ck_tile::HostTensor<BDataType>& b_full_dims,
const std::vector<ck_tile::HostTensor<DDataType>>& ds_full_dims_host,
const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_full_dims_host,
ck_tile::HostTensor<EDataType>& e_full_dims_host_ref,
ck_tile::index_t G_total,
ck_tile::index_t M_total,
ck_tile::index_t N_total,
ck_tile::index_t K_total,
const CDEElementWise& cde_elementwise)
const CDEElementWise& cde_elementwise,
const std::vector<ck_tile::index_t>& G_dims,
const std::vector<ck_tile::index_t>& M_dims,
const std::vector<ck_tile::index_t>& N_dims,
const std::vector<ck_tile::index_t>& K_dims)
{
std::cout << "Calculating reference using optimized flat indexing with parallel processing..."
std::cout << "Calculating reference using stride-aware indexing with parallel processing..."
<< std::endl;
// Parallel computation over G and M dimensions using pattern from reference_batched_gemm.hpp
// Extract stride information from tensor descriptors
const auto a_strides = a_full_dims.get_strides();
const auto b_strides = b_full_dims.get_strides();
const auto e_strides = e_full_dims_host_ref.get_strides();
// Extract D tensor strides
std::array<std::vector<std::size_t>, NumDTensor> ds_strides;
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
{
ds_strides[d] = ds_full_dims_host[d].get_strides();
}
const ck_tile::index_t num_g_dims = G_dims.size();
const ck_tile::index_t num_m_dims = M_dims.size();
const ck_tile::index_t num_n_dims = N_dims.size();
const ck_tile::index_t num_k_dims = K_dims.size();
// Helper lambda to compute linear index from flat indices using strides
auto compute_a_offset = [&](ck_tile::index_t g_flat,
ck_tile::index_t m_flat,
ck_tile::index_t k_flat) -> std::size_t {
std::size_t offset = 0;
// Decode G dimensions
ck_tile::index_t temp = g_flat;
for(int i = num_g_dims - 1; i >= 0; --i)
{
offset += (temp % G_dims[i]) * a_strides[i];
temp /= G_dims[i];
}
// Decode M dimensions
temp = m_flat;
for(int i = num_m_dims - 1; i >= 0; --i)
{
offset += (temp % M_dims[i]) * a_strides[num_g_dims + i];
temp /= M_dims[i];
}
// Decode K dimensions
temp = k_flat;
for(int i = num_k_dims - 1; i >= 0; --i)
{
offset += (temp % K_dims[i]) * a_strides[num_g_dims + num_m_dims + i];
temp /= K_dims[i];
}
return offset;
};
auto compute_b_offset = [&](ck_tile::index_t g_flat,
ck_tile::index_t n_flat,
ck_tile::index_t k_flat) -> std::size_t {
std::size_t offset = 0;
// Decode G dimensions
ck_tile::index_t temp = g_flat;
for(int i = num_g_dims - 1; i >= 0; --i)
{
offset += (temp % G_dims[i]) * b_strides[i];
temp /= G_dims[i];
}
// Decode N dimensions
temp = n_flat;
for(int i = num_n_dims - 1; i >= 0; --i)
{
offset += (temp % N_dims[i]) * b_strides[num_g_dims + i];
temp /= N_dims[i];
}
// Decode K dimensions
temp = k_flat;
for(int i = num_k_dims - 1; i >= 0; --i)
{
offset += (temp % K_dims[i]) * b_strides[num_g_dims + num_n_dims + i];
temp /= K_dims[i];
}
return offset;
};
auto compute_e_offset = [&](ck_tile::index_t g_flat,
ck_tile::index_t m_flat,
ck_tile::index_t n_flat) -> std::size_t {
std::size_t offset = 0;
// Decode G dimensions
ck_tile::index_t temp = g_flat;
for(int i = num_g_dims - 1; i >= 0; --i)
{
offset += (temp % G_dims[i]) * e_strides[i];
temp /= G_dims[i];
}
// Decode M dimensions
temp = m_flat;
for(int i = num_m_dims - 1; i >= 0; --i)
{
offset += (temp % M_dims[i]) * e_strides[num_g_dims + i];
temp /= M_dims[i];
}
// Decode N dimensions
temp = n_flat;
for(int i = num_n_dims - 1; i >= 0; --i)
{
offset += (temp % N_dims[i]) * e_strides[num_g_dims + num_m_dims + i];
temp /= N_dims[i];
}
return offset;
};
// Helper to compute D tensor offset (D tensors have same shape as E: [G, M, N])
auto compute_d_offset = [&](ck_tile::index_t g_flat,
ck_tile::index_t m_flat,
ck_tile::index_t n_flat,
ck_tile::index_t d_idx) -> std::size_t {
std::size_t offset = 0;
const auto& d_strides = ds_strides[d_idx];
// Decode G dimensions
ck_tile::index_t temp = g_flat;
for(int i = num_g_dims - 1; i >= 0; --i)
{
offset += (temp % G_dims[i]) * d_strides[i];
temp /= G_dims[i];
}
// Decode M dimensions
temp = m_flat;
for(int i = num_m_dims - 1; i >= 0; --i)
{
offset += (temp % M_dims[i]) * d_strides[num_g_dims + i];
temp /= M_dims[i];
}
// Decode N dimensions
temp = n_flat;
for(int i = num_n_dims - 1; i >= 0; --i)
{
offset += (temp % N_dims[i]) * d_strides[num_g_dims + num_m_dims + i];
temp /= N_dims[i];
}
return offset;
};
// Parallel computation over G and M dimensions
auto f_gm = [&](auto g_flat, auto m_flat) {
for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
{
AccDataType sum = 0;
// Compute dot product over K dimension
// Compute dot product over K dimension using stride-aware indexing
for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat)
{
auto a_val =
a_full_dims.mData[g_flat * M_total * K_total + m_flat * K_total + k_flat];
auto b_val =
b_full_dims.mData[g_flat * N_total * K_total + n_flat * K_total + k_flat];
const std::size_t a_offset = compute_a_offset(g_flat, m_flat, k_flat);
const std::size_t b_offset = compute_b_offset(g_flat, n_flat, k_flat);
auto a_val = a_full_dims.mData[a_offset];
auto b_val = b_full_dims.mData[b_offset];
sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
}
// Apply elementwise operation with D tensors
EDataType result = static_cast<EDataType>(sum);
if(ds_full_dims_host.size() == 0)
// Compute output offset using strides
const std::size_t e_offset = compute_e_offset(g_flat, m_flat, n_flat);
// Compute individual D tensor offsets using their respective strides
std::array<std::size_t, NumDTensor> d_offsets;
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
{
;
}
else if(ds_full_dims_host.size() == 1)
{
cde_elementwise(result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(
ds_full_dims_host[0].mData[g_flat * M_total * N_total +
m_flat * N_total + n_flat]));
}
else if(ds_full_dims_host.size() == 2)
{
cde_elementwise(
result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(
ds_full_dims_host[0]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
ck_tile::type_convert<float>(
ds_full_dims_host[1]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
}
else if(ds_full_dims_host.size() == 3)
{
cde_elementwise(
result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(
ds_full_dims_host[0]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
ck_tile::type_convert<float>(
ds_full_dims_host[1]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
ck_tile::type_convert<float>(
ds_full_dims_host[2]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
}
else if(ds_full_dims_host.size() == 4)
{
cde_elementwise(
result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(
ds_full_dims_host[0]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
ck_tile::type_convert<float>(
ds_full_dims_host[1]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
ck_tile::type_convert<float>(
ds_full_dims_host[2]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
ck_tile::type_convert<float>(
ds_full_dims_host[3]
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
}
else
{
throw std::runtime_error("Unsupported NumDTensor for reference calculation");
d_offsets[d] = compute_d_offset(g_flat, m_flat, n_flat, d);
}
// Store result
e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
static_cast<EDataType>(result);
// Apply elementwise operation with D tensors using compile-time dispatch
EDataType result = static_cast<EDataType>(sum);
ExtractDValues<DDataType, NumDTensor>::apply_at_offsets(
result, sum, cde_elementwise, ds_full_dims_host, d_offsets);
// Store result using stride-aware indexing
e_full_dims_host_ref.mData[e_offset] = static_cast<EDataType>(result);
}
};
@@ -125,147 +272,4 @@ void calculate_reference_flat_indexing(
make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename BDataType,
typename DDataType,
typename EDataType,
typename AccDataType,
typename CDEElementWise>
void calculate_reference_multi_dimensional(
const HostTensor<ADataType>& a_full_dims,
const HostTensor<BDataType>& b_full_dims,
const std::vector<HostTensor<DDataType>>& ds_full_dims_host,
HostTensor<EDataType>& e_full_dims_host_ref,
const std::vector<index_t>& G_dims,
const std::vector<index_t>& M_dims,
const std::vector<index_t>& N_dims,
const std::vector<index_t>& K_dims,
const std::vector<index_t>& A_dims,
const std::vector<index_t>& B_dims,
const std::vector<index_t>& E_dims,
const CDEElementWise& cde_elementwise)
{
std::cout << "Calculating reference using multi-dimensional indexing..." << std::endl;
std::vector<std::size_t> g_idx(G_dims.size());
std::vector<std::size_t> m_idx(M_dims.size());
std::vector<std::size_t> n_idx(N_dims.size());
std::vector<std::size_t> k_idx(K_dims.size());
std::vector<std::size_t> a_idx, b_idx, e_idx;
a_idx.reserve(A_dims.size());
b_idx.reserve(B_dims.size());
e_idx.reserve(E_dims.size());
auto calculate_total_elements = [](const std::vector<ck_tile::index_t>& dims) {
return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<ck_tile::index_t>());
};
for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
{
ck_tile::index_t temp = g_flat;
for(int i = G_dims.size() - 1; i >= 0; --i)
{
g_idx[i] = temp % G_dims[i];
temp /= G_dims[i];
}
for(ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat)
{
temp = m_flat;
for(int i = M_dims.size() - 1; i >= 0; --i)
{
m_idx[i] = temp % M_dims[i];
temp /= M_dims[i];
}
for(ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat)
{
temp = n_flat;
for(int i = N_dims.size() - 1; i >= 0; --i)
{
n_idx[i] = temp % N_dims[i];
temp /= N_dims[i];
}
AccDataType sum = 0;
for(ck_tile::index_t k_flat = 0; k_flat < calculate_total_elements(K_dims);
++k_flat)
{
temp = k_flat;
for(int i = K_dims.size() - 1; i >= 0; --i)
{
k_idx[i] = temp % K_dims[i];
temp /= K_dims[i];
}
a_idx.clear();
b_idx.clear();
a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end());
a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end());
a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end());
b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end());
b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end());
b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end());
auto a_val = a_full_dims(a_idx);
auto b_val = b_full_dims(b_idx);
sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
}
e_idx.clear();
e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end());
e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end());
e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());
EDataType result = static_cast<EDataType>(sum);
if(ds_full_dims_host.size() == 0)
{
;
}
else if(ds_full_dims_host.size() == 1)
{
cde_elementwise(result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)));
}
else if(ds_full_dims_host.size() == 2)
{
cde_elementwise(result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)));
}
else if(ds_full_dims_host.size() == 3)
{
cde_elementwise(result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)));
}
else if(ds_full_dims_host.size() == 4)
{
cde_elementwise(result,
ck_tile::type_convert<float>(sum),
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)),
ck_tile::type_convert<float>(ds_full_dims_host[3](e_idx)));
}
else
{
throw std::runtime_error("Unsupported NumDTensor for reference calculation");
}
e_full_dims_host_ref(e_idx) = static_cast<EDataType>(result);
}
}
}
}
} // namespace ck_tile

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
/**
@@ -60,23 +61,19 @@
* Rather than implementing tensor contraction from scratch, this kernel leverages the highly
* optimized `UniversalGemmKernel` as its computational backend.
*
* @subsection current_limitations Current Kernel Limitations
* @subsection implementation_features Implementation Features
*
* **Layout Restrictions:**
* - **Row-Major Only**: All tensors must use row-major memory layout
* - **Packed Tensors**: Only contiguous/packed tensor layouts supported
* - **Hardcoded Strides**: stride_A = K_total, stride_B = K_total, stride_E = N_total
* - **D Tensor Layout**: All D tensors must match E tensor layout (stride_Ds = N_total)
* **Stride Support:**
* - Supports arbitrary multi-dimensional stride patterns
* - Handles non-contiguous and padded tensor layouts
* - Independent strides for each auxiliary D tensor
* - Descriptor-based architecture with vectorization
*
* **Implementation Constraints:**
* - **Fixed Stride Calculation**: Strides are automatically calculated and cannot be customized
* - **No Column-Major**: Column-major or custom stride patterns not supported
* - **No Strided Access**: Non-contiguous tensor slicing not supported
*
* **Future Enhancements:**
* - Support for arbitrary stride patterns
* - Column-major and mixed layout support
* - Non-contiguous tensor operation support
* **Architecture:**
* - Uses TensorDescriptorUtils for stride-aware descriptor creation
* - Custom RunGemm implementation with descriptor-based tensor views
* - Reuses GemmPipeline and EpiloguePipeline for computation
* - Split-K support via UniversalGemmKernel utilities
*/
namespace ck_tile {
@@ -184,7 +181,10 @@ template <ck_tile::index_t NumDimG,
ck_tile::index_t NumDimM,
ck_tile::index_t NumDimN,
ck_tile::index_t NumDimK,
ck_tile::index_t NumDTensor = 0>
ck_tile::index_t NumDTensor = 0,
ck_tile::index_t VectorSizeA = 1,
ck_tile::index_t VectorSizeB = 1,
ck_tile::index_t VectorSizeE = 1>
struct BatchedContractionKernelArgs
{
const void* a_ptr; ///< Pointer to input tensor A
@@ -210,11 +210,46 @@ struct BatchedContractionKernelArgs
ck_tile::index_t N_total; ///< Total N dimension: N0 * N1 * ... * N_{NumDimN-1}
ck_tile::index_t K_total; ///< Total K dimension: K0 * K1 * ... * K_{NumDimK-1}
ck_tile::index_t stride_A; ///< Leading dimension stride for tensor A (row-major: K_total)
ck_tile::index_t stride_B; ///< Leading dimension stride for tensor B (row-major: K_total)
ck_tile::index_t
stride_A; ///< Leading dimension stride for tensor A (for backward compatibility)
ck_tile::index_t
stride_B; ///< Leading dimension stride for tensor B (for backward compatibility)
std::array<ck_tile::index_t, NumDTensor>
stride_Ds; ///< Leading dimension strides for D tensors (row-major: N_total)
ck_tile::index_t stride_E; ///< Leading dimension stride for tensor E (row-major: N_total)
stride_Ds; ///< Leading dimension strides for D tensors (for backward compatibility)
ck_tile::index_t
stride_E; ///< Leading dimension stride for tensor E (for backward compatibility)
// Tensor descriptors (encode full multi-dimensional stride information with vectorization)
using AGridDesc_M_K_ =
decltype(TensorDescriptorUtils<NumDimG,
NumDimM,
NumDimN,
NumDimK,
VectorSizeA,
VectorSizeB,
VectorSizeE>::Make_A_GridDescriptor_M_K({}, {}));
using BGridDesc_N_K_ =
decltype(TensorDescriptorUtils<NumDimG,
NumDimM,
NumDimN,
NumDimK,
VectorSizeA,
VectorSizeB,
VectorSizeE>::Make_B_GridDescriptor_N_K({}, {}));
using EGridDesc_M_N_ =
decltype(TensorDescriptorUtils<NumDimG,
NumDimM,
NumDimN,
NumDimK,
VectorSizeA,
VectorSizeB,
VectorSizeE>::Make_E_GridDescriptor_M_N({}, {}));
AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides
BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides
EGridDesc_M_N_ e_grid_desc_m_n; ///< Tensor descriptor for E[M, N] with actual strides
std::array<EGridDesc_M_N_, NumDTensor>
ds_grid_desc_m_n; ///< Descriptors for D tensors (same shape as E, independent strides)
};
/// @brief GPU kernel for batched tensor contraction operations.
@@ -274,10 +309,24 @@ struct BatchedContractionKernel
static constexpr ck_tile::index_t kBlockSize =
UniversalGemmKernel::kBlockSize; ///< GPU block size inherited from GEMM kernel
using KernelArgs =
BatchedContractionKernelArgs<NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor>; ///< Kernel
///< argument
///< structure
// Tensor descriptor utilities with vectorization support
using DescriptorUtils = TensorDescriptorUtils<NumDimG,
NumDimM,
NumDimN,
NumDimK,
GemmPipeline::GetVectorSizeA(),
GemmPipeline::GetVectorSizeB(),
EpiloguePipeline::GetVectorSizeC()>;
// Kernel arguments with vectorization support
using KernelArgs = BatchedContractionKernelArgs<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDTensor,
GemmPipeline::GetVectorSizeA(),
GemmPipeline::GetVectorSizeB(),
EpiloguePipeline::GetVectorSizeC()>;
/// @brief Returns the kernel name for debugging and profiling purposes.
/// @return Constant string identifier for this kernel
@@ -326,6 +375,104 @@ struct BatchedContractionKernel
TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
}
/// @brief Executes GEMM computation with descriptor-based tensor views for arbitrary stride
/// support
///
/// @details This function performs the core GEMM computation using tensor descriptors to handle
/// arbitrary multi-dimensional stride patterns. It creates tensor views from
/// pre-computed descriptors (stored in kargs), applies padding, creates tile windows,
/// and executes the GemmPipeline and EpiloguePipeline.
///
/// @param a_ptr Pointer to input tensor A data (after batch and split-K offsets applied)
/// @param b_ptr Pointer to input tensor B data (after batch and split-K offsets applied)
/// @param ds_ptr Array of pointers to auxiliary D tensor data
/// @param e_ptr Pointer to output tensor E data (after batch offset applied)
/// @param smem_ptr Pointer to shared memory for tile operations
/// @param kargs Kernel arguments containing tensor descriptors and dimension information
/// @param k_size Size of K dimension for this split (for split-K support)
/// @param i_m Starting M index for this block's tile
/// @param i_n Starting N index for this block's tile
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr,
const KernelArgs& kargs,
const index_t k_size,
const index_t i_m,
const index_t i_n)
{
// Create tensor views from descriptors (supports arbitrary stride patterns)
auto a_tensor_view =
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
auto b_tensor_view =
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
auto e_tensor_view =
make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);
// Pad views for boundary handling and optimization (like UniversalGemmKernel)
auto a_pad_view = pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
auto b_pad_view = pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
auto e_pad_view = pad_tensor_view(
e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
// Create tile windows from PADDED views
auto a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
auto b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
auto e_block_window = make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
// Calculate number of K loops
const index_t num_loop =
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(k_size));
// Run GEMM Pipeline (same as UniversalGemmKernel, but with descriptor-based windows)
using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
const auto& c_block_tile = GemmPipeline{}(
a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);
// Create D windows from descriptors (for each D tensor)
auto ds_block_windows = generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
const DDataType* d_ptr = static_cast<const DDataType*>(ds_ptr[i]);
auto d_tensor_view =
make_tensor_view<address_space_enum::global>(d_ptr, kargs.ds_grid_desc_m_n[i]);
return make_tile_window(d_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
},
number<NumDTensor>{});
// Run Epilogue Pipeline with descriptor-based D windows
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
}
CK_TILE_HOST static constexpr KernelArgs
MakeKernelArgs(const BatchedContractionHostArgs<NumDTensor>& host_args)
{
@@ -435,6 +582,22 @@ struct BatchedContractionKernel
kargs.K_total *= kargs.K_dims[i];
}
// Create tensor descriptors on host using actual dims and strides
kargs.a_grid_desc_m_k =
DescriptorUtils::Make_A_GridDescriptor_M_K(host_args.A_dims, host_args.A_strides);
kargs.b_grid_desc_n_k =
DescriptorUtils::Make_B_GridDescriptor_N_K(host_args.B_dims, host_args.B_strides);
kargs.e_grid_desc_m_n =
DescriptorUtils::Make_E_GridDescriptor_M_N(host_args.E_dims, host_args.E_strides);
// Create D descriptors with their own strides (same shape as E, independent strides)
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
{
kargs.ds_grid_desc_m_n[d] = DescriptorUtils::Make_E_GridDescriptor_M_N(
host_args.Ds_dims[d], host_args.Ds_strides[d]);
}
// Keep simple strides for backward compatibility
kargs.stride_A = kargs.K_total;
kargs.stride_B = kargs.K_total;
kargs.stride_E = kargs.N_total;
@@ -468,8 +631,8 @@ struct BatchedContractionKernel
const ck_tile::index_t i_n =
__builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
[[maybe_unused]] const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
// Calculate batch offsets for each tensor
const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A;
@@ -487,6 +650,10 @@ struct BatchedContractionKernel
ds_batch_ptr[i] = static_cast<const DDataType*>(kargs.ds_ptr[i]) + batch_offset_D;
});
// Allocate shared memory
__shared__ char smem_ptr[GetSmemSize()];
// Use UniversalGemmKernel's SplitKBatchOffset for split-K calculation
typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
{b_ptr},
ds_batch_ptr,
@@ -503,19 +670,19 @@ struct BatchedContractionKernel
const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
i_splitk);
const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
__shared__ char smem_ptr[GetSmemSize()];
// Apply K-split offsets and run descriptor-based RunGemm
const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0];
const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
UniversalGemmKernel::RunGemm({a_ptr_final},
{b_ptr_final},
ds_batch_ptr,
e_ptr,
smem_ptr,
gemm_kargs,
splitk_batch_offset,
i_m,
i_n);
RunGemm(a_ptr_split,
b_ptr_split,
ds_batch_ptr,
e_ptr,
smem_ptr,
kargs,
splitk_batch_offset.splitted_k,
i_m,
i_n);
}
};

View File

@@ -13,8 +13,8 @@
* dimensions for GEMM operations. These functions transform multi-dimensional tensors into
* 2D matrix descriptors by removing batch dimensions and flattening the remaining dimensions.
*
* These utilities are currently not used in the main batched contraction kernel but are preserved
* for future implementations that may require explicit tensor descriptor creation.
* These utilities are used by BatchedContractionKernel to create stride-aware descriptors
* that support arbitrary multi-dimensional non-contiguous tensor layouts.
*/
namespace ck_tile {
@@ -30,7 +30,10 @@ namespace ck_tile {
template <ck_tile::index_t NumDimG,
ck_tile::index_t NumDimM,
ck_tile::index_t NumDimN,
ck_tile::index_t NumDimK>
ck_tile::index_t NumDimK,
ck_tile::index_t VectorSizeA,
ck_tile::index_t VectorSizeB,
ck_tile::index_t VectorSizeE>
struct TensorDescriptorUtils
{
/// @brief Creates a tensor descriptor for input tensor A with batch dimensions removed.
@@ -62,9 +65,9 @@ struct TensorDescriptorUtils
const auto dims_M = get_container_subset(A_dims_M_K, A_dims_M_ids);
const auto dims_K = get_container_subset(A_dims_M_K, A_dims_K_ids);
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Discriptor
const auto A_grid_desc_Ms_Ks =
ck_tile::make_naive_tensor_descriptor(A_dims_M_K, A_strides_M_K);
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Descriptor with vector size
const auto A_grid_desc_Ms_Ks = ck_tile::make_naive_tensor_descriptor(
A_dims_M_K, A_strides_M_K, number<VectorSizeA>{}, number<1>{});
// transformed tensor to flatten M and K dimensions [M_total = M0 * M1 * M2 * ... , K_total
// = K0 * K1 * K2 * ...]
@@ -106,9 +109,9 @@ struct TensorDescriptorUtils
const auto dims_N = get_container_subset(B_dims_N_K, B_dims_N_ids);
const auto dims_K = get_container_subset(B_dims_N_K, B_dims_K_ids);
// naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Discriptor
const auto B_grid_desc_Ns_Ks =
ck_tile::make_naive_tensor_descriptor(B_dims_N_K, B_strides_N_K);
// naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Descriptor with vector size
const auto B_grid_desc_Ns_Ks = ck_tile::make_naive_tensor_descriptor(
B_dims_N_K, B_strides_N_K, number<VectorSizeB>{}, number<1>{});
// transformed tensor to flatten N and K dimensions [N_total = N0 * N1 * N2 * ... , K_total
// = K0 * K1 * K2 * ...]
@@ -150,9 +153,9 @@ struct TensorDescriptorUtils
const auto dims_M = get_container_subset(E_dims_M_N, E_dims_M_ids);
const auto dims_N = get_container_subset(E_dims_M_N, E_dims_N_ids);
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Discriptor
const auto E_grid_desc_Ms_Ns =
ck_tile::make_naive_tensor_descriptor(E_dims_M_N, E_strides_M_N);
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Descriptor with vector size
const auto E_grid_desc_Ms_Ns = ck_tile::make_naive_tensor_descriptor(
E_dims_M_N, E_strides_M_N, number<VectorSizeE>{}, number<1>{});
// transformed tensor to flatten M and N dimensions [M_total = M0 * M1 * M2 * ... ,
// N_total = N0 * N1 * N2 * ...]