[CK Tile] batched contraction kernel generalizing (#3126)

* Add help for example

* Refactore the compute reference batched contraction to manage stride-aware calculation and some code cleanings

* Add stride-aware reference for batched contraction with independent D tensor layouts

* Add -num_d argument for runtime D tensor count selection in batched contraction

* Add stride vector arguments in example code for testing non-contiguous batched contraction inputs

* Add descriptor-based architecture for batched contraction multi-dimensional stride support

* Add multi-dimensional non-contiguous stride support to batched contraction, num_d = 0

* Add complete multi-dimensional stride support via descriptors

* Enable vectorization in descriptor-based batched contraction. Add pad_tensor_view to local RunGemm

* Clean up batched contraction: remove old UniversalGemmKernel path

* Clean up batched contraction: remove legacy paths and finalize docs

* Optimize batched contraction example: pass dimension sizes not vectors

* correct the reference calculation, unsigned int to int

* Fix batched_contraction C++17 build errors for gfx90a CI
This commit is contained in:
msaffari-amd
2025-12-02 13:30:27 +01:00
committed by GitHub
parent d3f37ebf6c
commit 2d3020e5b0
6 changed files with 694 additions and 313 deletions

View File

@@ -219,9 +219,7 @@ float batched_contraction(const ck_tile::BatchedContractionHostArgs<DsDataType::
HANDLE_CASE(2, 1, 1, 1);
HANDLE_CASE(2, 2, 2, 1);
HANDLE_CASE(1, 2, 1, 1);
HANDLE_CASE(1, 1, 1, 2);
HANDLE_CASE(2, 2, 2, 2);
HANDLE_CASE(4, 4, 4, 4);
throw std::runtime_error(
"Unsupported dimension combination: G=" + std::to_string(num_g_dims) +

View File

@@ -42,17 +42,83 @@ using AccDataType = ContractionTypes::AccDataType;
using EDataType = ContractionTypes::EDataType;
using DDataType = ContractionTypes::DDataType;
void print_help(const char* program_name)
{
std::cout << "\n";
std::cout << "Batched Tensor Contraction with element-wise fusion\n";
std::cout << "E[G,M,N] = element_wise_op(contraction(A[G,M,K], B[G,N,K]), D0, D1, ...)\n";
std::cout << "(Supports multiple D tensors with configurable element-wise operations)\n\n";
std::cout << "Usage: " << program_name << " [OPTIONS]\n\n";
std::cout << "Dimension Arguments (comma-separated, no spaces):\n";
std::cout << " -g_dims=<dims> Batch dimensions (default: \"1,2\")\n";
std::cout << " -m_dims=<dims> M (row) dimensions (default: \"4,256\")\n";
std::cout << " -n_dims=<dims> N (column) dimensions (default: \"16,128\")\n";
std::cout << " -k_dims=<dims> K (contract) dims (default: \"64\")\n";
std::cout << " -num_d=<int> Number of D tensors (default: 2, range: 0-4)\n\n";
std::cout << "Custom Stride Arguments (for testing non-contiguous tensors):\n";
std::cout << " -strides_a=<s> A tensor strides (comma-separated, empty = auto)\n";
std::cout << " -strides_b=<s> B tensor strides (comma-separated, empty = auto)\n";
std::cout << " -strides_e=<s> E tensor strides (comma-separated, empty = auto)\n";
std::cout << " -strides_ds=<s> D tensors strides (semicolon-separated, empty = same as E)\n";
std::cout << " Example: -strides_a=\"32768,128,1\" -strides_ds=\"512,2,1;1024,4,1\"\n\n";
std::cout << "Layout Arguments:\n";
std::cout
<< " -a_layout=<R|C> A tensor layout (R=Row-major, C=Column-major, default: \"R\")\n";
std::cout << " -b_layout=<R|C> B tensor layout (default: \"C\")\n";
std::cout << " -e_layout=<R|C> E tensor layout (default: \"R\")\n\n";
std::cout << "Examples:\n";
std::cout << " Single batch (12 batches of 256×128):\n";
std::cout << " " << program_name
<< " -g_dims=\"12\" -m_dims=\"256\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
std::cout << " 2D batch grid (2×3=6 batches):\n";
std::cout << " " << program_name
<< " -g_dims=\"2,3\" -m_dims=\"128\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
std::cout << " Multi-dimensional (flattened to M=128, N=128, K=128):\n";
std::cout << " " << program_name
<< " -g_dims=\"4\" -m_dims=\"8,16\" -n_dims=\"32,4\" -k_dims=\"16,8\"\n\n";
std::cout << "Other Options:\n";
std::cout << " -v=<0|1> Validation (0=off, 1=on, default: 1)\n";
std::cout << " -split_k=<int> Split-K value (default: 1)\n";
std::cout << " -warmup=<int> Warmup iterations (default: 5)\n";
std::cout << " -repeat=<int> Benchmark iterations (default: 10)\n";
std::cout << " -log=<0|1> Logging level (default: 1)\n";
std::cout << " -help Show this help\n\n";
}
auto create_args(int argc, char* argv[])
{
// Check for --help flag
for(int i = 1; i < argc; ++i)
{
std::string arg = argv[i];
if(arg == "--help" || arg == "-h" || arg == "-help")
{
print_help(argv[0]);
std::exit(0);
}
}
ck_tile::ArgParser arg_parser;
arg_parser.insert("m_dims", "4,256", "M dimensions separated by comma (e.g., '16,32' for 2D M)")
.insert("n_dims", "16,128", "N dimensions separated by comma (e.g., '32,32' for 2D N)")
.insert("k_dims", "64", "K dimensions separated by comma (e.g., '64,32' for 2D K)")
.insert(
"g_dims", "1,2", "G dimensions separated by comma (e.g., '4,2' for 2D, '2,3,4' for 3D)")
.insert("stride_a", "0", "Custom A tensor leading dimension stride (0 = auto)")
.insert("stride_b", "0", "Custom B tensor leading dimension stride (0 = auto)")
.insert("stride_e", "0", "Custom E tensor leading dimension stride (0 = auto)")
.insert("num_d", "2", "Number of D (auxiliary input) tensors")
.insert("strides_a", "", "A tensor strides (comma-separated, empty = auto/contiguous)")
.insert("strides_b", "", "B tensor strides (comma-separated, empty = auto/contiguous)")
.insert("strides_e", "", "E tensor strides (comma-separated, empty = auto/contiguous)")
.insert("strides_ds",
"",
"D tensors strides (semicolon-separated for multiple, empty = same as E)")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Col by default")
.insert("e_layout", "R", "E tensor data layout - Row by default")

View File

@@ -45,10 +45,10 @@ float invoke_batched_contraction_kernel(
const void* b_full_dims_dev_buf,
const std::array<const void*, DsDataType::size()>& ds_dev_buf,
void* e_full_dims_dev_buf,
const std::vector<ck_tile::index_t>& G_dims,
const std::vector<ck_tile::index_t>& M_dims,
const std::vector<ck_tile::index_t>& N_dims,
const std::vector<ck_tile::index_t>& K_dims,
ck_tile::index_t num_g_dims,
ck_tile::index_t num_m_dims,
ck_tile::index_t num_n_dims,
ck_tile::index_t num_k_dims,
const std::vector<ck_tile::index_t>& A_dims, // [G0,G1,..,M0,M1,..,K0,K1,..]
const std::vector<ck_tile::index_t>& B_dims, // [G0,G1,..,N0,N1,..,K0,K1,..]
const std::array<std::vector<ck_tile::index_t>, DsDataType::size()>&
@@ -79,9 +79,8 @@ float invoke_batched_contraction_kernel(
E_strides // E_strides
);
std::cout << "Calling batched_contraction with dimensions: G=" << G_dims.size()
<< ", M=" << M_dims.size() << ", N=" << N_dims.size() << ", K=" << K_dims.size()
<< std::endl;
std::cout << "Calling batched_contraction with dimensions: G=" << num_g_dims
<< ", M=" << num_m_dims << ", N=" << num_n_dims << ", K=" << num_k_dims << std::endl;
float ave_time = batched_contraction<ADataType,
BDataType,
@@ -95,16 +94,38 @@ float invoke_batched_contraction_kernel(
CDEElementWise>(
args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
G_dims.size(), // num_g_dims
M_dims.size(), // num_m_dims
N_dims.size(), // num_n_dims
K_dims.size() // num_k_dims
);
num_g_dims,
num_m_dims,
num_n_dims,
num_k_dims);
return ave_time;
}
template <typename ALayout, typename BLayout, typename DLayout, typename ELayout>
// C++17-compatible helper function to create array of HostTensors
namespace {
template <typename DDataType, std::size_t NumDTensor, std::size_t... Is>
std::array<ck_tile::HostTensor<DDataType>, NumDTensor>
make_ds_host_tensors_impl(const std::array<ck_tile::HostTensorDescriptor, NumDTensor>& descs,
std::index_sequence<Is...>)
{
return {ck_tile::HostTensor<DDataType>(descs[Is])...};
}
template <typename DDataType, std::size_t NumDTensor>
std::array<ck_tile::HostTensor<DDataType>, NumDTensor>
make_ds_host_tensors(const std::array<ck_tile::HostTensorDescriptor, NumDTensor>& descs)
{
return make_ds_host_tensors_impl<DDataType, NumDTensor>(descs,
std::make_index_sequence<NumDTensor>{});
}
} // anonymous namespace
template <typename ALayout,
typename BLayout,
typename DLayout,
typename ELayout,
ck_tile::index_t NumDTensor>
int run_batched_contraction_example_with_layouts(
int argc,
char* argv[],
@@ -122,8 +143,6 @@ int run_batched_contraction_example_with_layouts(
std::vector<ck_tile::index_t> N_dims = parse_dimensions(arg_parser.get_str("n_dims"));
std::vector<ck_tile::index_t> K_dims = parse_dimensions(arg_parser.get_str("k_dims"));
constexpr ck_tile::index_t NumDTensor = 2;
ck_tile::index_t G_total = calculate_total_elements(G_dims);
ck_tile::index_t M_total = calculate_total_elements(M_dims);
ck_tile::index_t N_total = calculate_total_elements(N_dims);
@@ -148,13 +167,105 @@ int run_batched_contraction_example_with_layouts(
return converted;
};
ck_tile::HostTensorDescriptor a_desc(A_dims);
ck_tile::HostTensorDescriptor b_desc(B_dims);
ck_tile::HostTensorDescriptor e_desc(E_dims);
std::array<ck_tile::HostTensorDescriptor, NumDTensor> ds_descs;
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
// Get custom stride arguments
std::string strides_a_str = arg_parser.get_str("strides_a");
std::string strides_b_str = arg_parser.get_str("strides_b");
std::string strides_e_str = arg_parser.get_str("strides_e");
std::string strides_ds_str = arg_parser.get_str("strides_ds");
// Create A descriptor with custom or default strides
ck_tile::HostTensorDescriptor a_desc;
if(!strides_a_str.empty())
{
ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides());
std::vector<ck_tile::index_t> custom_a_strides = parse_dimensions(strides_a_str);
if(custom_a_strides.size() != A_dims.size())
{
throw std::runtime_error("strides_a size must match A_dims size");
}
std::vector<std::size_t> a_strides_size_t(custom_a_strides.begin(), custom_a_strides.end());
a_desc = ck_tile::HostTensorDescriptor(A_dims, a_strides_size_t);
std::cout << "Using custom strides for A (non-contiguous)" << std::endl;
}
else
{
a_desc = ck_tile::HostTensorDescriptor(A_dims);
}
// Create B descriptor with custom or default strides
ck_tile::HostTensorDescriptor b_desc;
if(!strides_b_str.empty())
{
std::vector<ck_tile::index_t> custom_b_strides = parse_dimensions(strides_b_str);
if(custom_b_strides.size() != B_dims.size())
{
throw std::runtime_error("strides_b size must match B_dims size");
}
std::vector<std::size_t> b_strides_size_t(custom_b_strides.begin(), custom_b_strides.end());
b_desc = ck_tile::HostTensorDescriptor(B_dims, b_strides_size_t);
std::cout << "Using custom strides for B (non-contiguous)" << std::endl;
}
else
{
b_desc = ck_tile::HostTensorDescriptor(B_dims);
}
// Create E descriptor with custom or default strides
ck_tile::HostTensorDescriptor e_desc;
if(!strides_e_str.empty())
{
std::vector<ck_tile::index_t> custom_e_strides = parse_dimensions(strides_e_str);
if(custom_e_strides.size() != E_dims.size())
{
throw std::runtime_error("strides_e size must match E_dims size");
}
std::vector<std::size_t> e_strides_size_t(custom_e_strides.begin(), custom_e_strides.end());
e_desc = ck_tile::HostTensorDescriptor(E_dims, e_strides_size_t);
std::cout << "Using custom strides for E (non-contiguous)" << std::endl;
}
else
{
e_desc = ck_tile::HostTensorDescriptor(E_dims);
}
// Create D descriptors with custom or default strides (default = same as E)
std::array<ck_tile::HostTensorDescriptor, NumDTensor> ds_descs;
if(!strides_ds_str.empty())
{
// Parse semicolon-separated stride vectors for multiple D tensors
std::vector<std::vector<ck_tile::index_t>> all_ds_strides;
std::stringstream ss(strides_ds_str);
std::string d_stride_str;
while(std::getline(ss, d_stride_str, ';'))
{
all_ds_strides.push_back(parse_dimensions(d_stride_str));
}
if(all_ds_strides.size() != NumDTensor)
{
throw std::runtime_error("Number of D stride vectors must match num_d=" +
std::to_string(NumDTensor));
}
std::cout << "Using custom strides for D tensors (non-contiguous)" << std::endl;
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
{
if(all_ds_strides[d].size() != E_dims.size())
{
throw std::runtime_error("D tensor " + std::to_string(d) +
" stride size must match E_dims size");
}
std::vector<std::size_t> d_strides_size_t(all_ds_strides[d].begin(),
all_ds_strides[d].end());
ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], d_strides_size_t);
}
}
else
{
// Default: use same strides as E
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
{
ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides());
}
}
std::vector<ck_tile::index_t> A_strides = convert_strides(a_desc.get_strides());
@@ -201,11 +312,8 @@ int run_batched_contraction_example_with_layouts(
ck_tile::HostTensor<::BDataType> b_full_dims_host(b_desc);
ck_tile::HostTensor<::EDataType> e_full_dims_host(e_desc);
std::vector<ck_tile::HostTensor<::DDataType>> ds_full_dims_host;
for(int d = 0; d < NumDTensor; ++d)
{
ds_full_dims_host.emplace_back(ck_tile::HostTensor<::DDataType>(ds_descs[d]));
}
// Construct array of HostTensors - C++17 compatible
auto ds_full_dims_host = make_ds_host_tensors<::DDataType, NumDTensor>(ds_descs);
ck_tile::FillUniformDistribution<::ADataType>{-5.f, 5.f, std::nullopt}(a_full_dims_host);
ck_tile::FillUniformDistribution<::BDataType>{-5.f, 5.f, std::nullopt}(b_full_dims_host);
@@ -260,10 +368,10 @@ int run_batched_contraction_example_with_layouts(
b_full_dims_dev_buf.GetDeviceBuffer(),
ds_ptr_buf,
e_full_dims_dev_buf.GetDeviceBuffer(),
G_dims,
M_dims,
N_dims,
K_dims,
G_dims.size(),
M_dims.size(),
N_dims.size(),
K_dims.size(),
A_dims,
B_dims,
Ds_dims,
@@ -316,12 +424,13 @@ int run_batched_contraction_example_with_layouts(
auto start_time = std::chrono::high_resolution_clock::now();
ck_tile::calculate_reference_flat_indexing<ADataType,
BDataType,
DDataType,
EDataType,
AccDataType,
CDEElementWise>(a_full_dims_host,
ck_tile::compute_reference_batched_contraction<ADataType,
BDataType,
DDataType,
EDataType,
AccDataType,
CDEElementWise,
NumDTensor>(a_full_dims_host,
b_full_dims_host,
ds_full_dims_host,
e_full_dims_host_ref,
@@ -329,7 +438,11 @@ int run_batched_contraction_example_with_layouts(
M_total,
N_total,
K_total,
CDEElementWise{});
CDEElementWise{},
G_dims,
M_dims,
N_dims,
K_dims);
auto end_time = std::chrono::high_resolution_clock::now();
auto duration =
@@ -387,15 +500,45 @@ int run_batched_contraction_example(int argc, char* argv[])
if(!result)
return -1;
// Get NumDTensor to dispatch at runtime
const int num_d = arg_parser.get_int("num_d");
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
// Runtime dispatch based on num_d value
if(a_layout == "R" && b_layout == "C")
{
return run_batched_contraction_example_with_layouts(argc, argv, Row{}, Col{}, Row{}, Row{});
// Dispatch to appropriate template instantiation based on runtime num_d
switch(num_d)
{
case 0:
std::cout << "Running with 0 D tensors" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 0>(
argc, argv, Row{}, Col{}, Row{}, Row{});
case 1:
std::cout << "Running with 1 D tensor" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 1>(
argc, argv, Row{}, Col{}, Row{}, Row{});
case 2:
std::cout << "Running with 2 D tensors" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 2>(
argc, argv, Row{}, Col{}, Row{}, Row{});
case 3:
std::cout << "Running with 3 D tensors" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 3>(
argc, argv, Row{}, Col{}, Row{}, Row{});
case 4:
std::cout << "Running with 4 D tensors" << std::endl;
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 4>(
argc, argv, Row{}, Col{}, Row{}, Row{});
default:
throw std::runtime_error("num_d must be between 0 and 4, got: " +
std::to_string(num_d));
}
}
else
{