mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK Tile] batched contraction kernel generalizing (#3126)
* Add help for example
* Refactor the batched contraction reference computation to handle stride-aware calculation, plus some code cleanup
* Add stride-aware reference for batched contraction with independent D tensor layouts
* Add -num_d argument for runtime D tensor count selection in batched contraction
* Add stride vector arguments in example code for testing non-contiguous batched contraction inputs
* Add descriptor-based architecture for batched contraction multi-dimensional stride support
* Add multi-dimensional non-contiguous stride support to batched contraction, num_d = 0
* Add complete multi-dimensional stride support via descriptors
* Enable vectorization in descriptor-based batched contraction. Add pad_tensor_view to local RunGemm
* Clean up batched contraction: remove old UniversalGemmKernel path
* Clean up batched contraction: remove legacy paths and finalize docs
* Optimize batched contraction example: pass dimension sizes not vectors
* Correct the reference calculation: unsigned int to int
* Fix batched_contraction C++17 build errors for gfx90a CI
This commit is contained in:
@@ -219,9 +219,7 @@ float batched_contraction(const ck_tile::BatchedContractionHostArgs<DsDataType::
|
||||
HANDLE_CASE(2, 1, 1, 1);
|
||||
HANDLE_CASE(2, 2, 2, 1);
|
||||
HANDLE_CASE(1, 2, 1, 1);
|
||||
HANDLE_CASE(1, 1, 1, 2);
|
||||
HANDLE_CASE(2, 2, 2, 2);
|
||||
HANDLE_CASE(4, 4, 4, 4);
|
||||
|
||||
throw std::runtime_error(
|
||||
"Unsupported dimension combination: G=" + std::to_string(num_g_dims) +
|
||||
|
||||
@@ -42,17 +42,83 @@ using AccDataType = ContractionTypes::AccDataType;
|
||||
using EDataType = ContractionTypes::EDataType;
|
||||
using DDataType = ContractionTypes::DDataType;
|
||||
|
||||
// Prints the command-line usage text of the batched contraction example to std::cout.
//
// program_name: executable name shown in the "Usage:" line and in the sample
//               invocations. May be null (argv[0] is not guaranteed to be set);
//               a placeholder is printed in that case.
//
// The text covers the dimension, custom-stride, layout and miscellaneous options,
// followed by three sample invocations. The function only writes to std::cout and
// returns nothing.
void print_help(const char* program_name)
{
    // Streaming a null char pointer is undefined behaviour and typically puts the
    // stream into a failed state, which would drop the rest of the help text.
    const char* const exe = (program_name != nullptr) ? program_name : "<program>";

    std::ostream& out = std::cout;

    // Banner and synopsis.
    out << "\n"
        << "Batched Tensor Contraction with element-wise fusion\n"
        << "E[G,M,N] = element_wise_op(contraction(A[G,M,K], B[G,N,K]), D0, D1, ...)\n"
        << "(Supports multiple D tensors with configurable element-wise operations)\n\n";

    out << "Usage: " << exe << " [OPTIONS]\n\n";

    // Problem dimensions.
    out << "Dimension Arguments (comma-separated, no spaces):\n"
        << " -g_dims=<dims> Batch dimensions (default: \"1,2\")\n"
        << " -m_dims=<dims> M (row) dimensions (default: \"4,256\")\n"
        << " -n_dims=<dims> N (column) dimensions (default: \"16,128\")\n"
        << " -k_dims=<dims> K (contract) dims (default: \"64\")\n"
        << " -num_d=<int> Number of D tensors (default: 2, range: 0-4)\n\n";

    // Explicit strides for non-contiguous inputs/outputs.
    out << "Custom Stride Arguments (for testing non-contiguous tensors):\n"
        << " -strides_a=<s> A tensor strides (comma-separated, empty = auto)\n"
        << " -strides_b=<s> B tensor strides (comma-separated, empty = auto)\n"
        << " -strides_e=<s> E tensor strides (comma-separated, empty = auto)\n"
        << " -strides_ds=<s> D tensors strides (semicolon-separated, empty = same as E)\n"
        << " Example: -strides_a=\"32768,128,1\" -strides_ds=\"512,2,1;1024,4,1\"\n\n";

    // Memory layouts.
    out << "Layout Arguments:\n"
        << " -a_layout=<R|C> A tensor layout (R=Row-major, C=Column-major, default: \"R\")\n"
        << " -b_layout=<R|C> B tensor layout (default: \"C\")\n"
        << " -e_layout=<R|C> E tensor layout (default: \"R\")\n\n";

    // Sample invocations.
    out << "Examples:\n"
        << " Single batch (12 batches of 256×128):\n"
        << " " << exe << " -g_dims=\"12\" -m_dims=\"256\" -n_dims=\"128\" -k_dims=\"64\"\n\n";

    out << " 2D batch grid (2×3=6 batches):\n"
        << " " << exe << " -g_dims=\"2,3\" -m_dims=\"128\" -n_dims=\"128\" -k_dims=\"64\"\n\n";

    out << " Multi-dimensional (flattened to M=128, N=128, K=128):\n"
        << " " << exe << " -g_dims=\"4\" -m_dims=\"8,16\" -n_dims=\"32,4\" -k_dims=\"16,8\"\n\n";

    // Remaining switches. The help entry lists every spelling that the argument
    // pre-scan in create_args compares against.
    out << "Other Options:\n"
        << " -v=<0|1> Validation (0=off, 1=on, default: 1)\n"
        << " -split_k=<int> Split-K value (default: 1)\n"
        << " -warmup=<int> Warmup iterations (default: 5)\n"
        << " -repeat=<int> Benchmark iterations (default: 10)\n"
        << " -log=<0|1> Logging level (default: 1)\n"
        << " -help, -h, --help Show this help\n\n";
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
// Check for --help flag
|
||||
for(int i = 1; i < argc; ++i)
|
||||
{
|
||||
std::string arg = argv[i];
|
||||
if(arg == "--help" || arg == "-h" || arg == "-help")
|
||||
{
|
||||
print_help(argv[0]);
|
||||
std::exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m_dims", "4,256", "M dimensions separated by comma (e.g., '16,32' for 2D M)")
|
||||
.insert("n_dims", "16,128", "N dimensions separated by comma (e.g., '32,32' for 2D N)")
|
||||
.insert("k_dims", "64", "K dimensions separated by comma (e.g., '64,32' for 2D K)")
|
||||
.insert(
|
||||
"g_dims", "1,2", "G dimensions separated by comma (e.g., '4,2' for 2D, '2,3,4' for 3D)")
|
||||
.insert("stride_a", "0", "Custom A tensor leading dimension stride (0 = auto)")
|
||||
.insert("stride_b", "0", "Custom B tensor leading dimension stride (0 = auto)")
|
||||
.insert("stride_e", "0", "Custom E tensor leading dimension stride (0 = auto)")
|
||||
.insert("num_d", "2", "Number of D (auxiliary input) tensors")
|
||||
.insert("strides_a", "", "A tensor strides (comma-separated, empty = auto/contiguous)")
|
||||
.insert("strides_b", "", "B tensor strides (comma-separated, empty = auto/contiguous)")
|
||||
.insert("strides_e", "", "E tensor strides (comma-separated, empty = auto/contiguous)")
|
||||
.insert("strides_ds",
|
||||
"",
|
||||
"D tensors strides (semicolon-separated for multiple, empty = same as E)")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Col by default")
|
||||
.insert("e_layout", "R", "E tensor data layout - Row by default")
|
||||
|
||||
@@ -45,10 +45,10 @@ float invoke_batched_contraction_kernel(
|
||||
const void* b_full_dims_dev_buf,
|
||||
const std::array<const void*, DsDataType::size()>& ds_dev_buf,
|
||||
void* e_full_dims_dev_buf,
|
||||
const std::vector<ck_tile::index_t>& G_dims,
|
||||
const std::vector<ck_tile::index_t>& M_dims,
|
||||
const std::vector<ck_tile::index_t>& N_dims,
|
||||
const std::vector<ck_tile::index_t>& K_dims,
|
||||
ck_tile::index_t num_g_dims,
|
||||
ck_tile::index_t num_m_dims,
|
||||
ck_tile::index_t num_n_dims,
|
||||
ck_tile::index_t num_k_dims,
|
||||
const std::vector<ck_tile::index_t>& A_dims, // [G0,G1,..,M0,M1,..,K0,K1,..]
|
||||
const std::vector<ck_tile::index_t>& B_dims, // [G0,G1,..,N0,N1,..,K0,K1,..]
|
||||
const std::array<std::vector<ck_tile::index_t>, DsDataType::size()>&
|
||||
@@ -79,9 +79,8 @@ float invoke_batched_contraction_kernel(
|
||||
E_strides // E_strides
|
||||
);
|
||||
|
||||
std::cout << "Calling batched_contraction with dimensions: G=" << G_dims.size()
|
||||
<< ", M=" << M_dims.size() << ", N=" << N_dims.size() << ", K=" << K_dims.size()
|
||||
<< std::endl;
|
||||
std::cout << "Calling batched_contraction with dimensions: G=" << num_g_dims
|
||||
<< ", M=" << num_m_dims << ", N=" << num_n_dims << ", K=" << num_k_dims << std::endl;
|
||||
|
||||
float ave_time = batched_contraction<ADataType,
|
||||
BDataType,
|
||||
@@ -95,16 +94,38 @@ float invoke_batched_contraction_kernel(
|
||||
CDEElementWise>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
|
||||
G_dims.size(), // num_g_dims
|
||||
M_dims.size(), // num_m_dims
|
||||
N_dims.size(), // num_n_dims
|
||||
K_dims.size() // num_k_dims
|
||||
);
|
||||
num_g_dims,
|
||||
num_m_dims,
|
||||
num_n_dims,
|
||||
num_k_dims);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename DLayout, typename ELayout>
|
||||
// C++17-compatible helper function to create array of HostTensors
|
||||
namespace {
|
||||
template <typename DDataType, std::size_t NumDTensor, std::size_t... Is>
|
||||
std::array<ck_tile::HostTensor<DDataType>, NumDTensor>
|
||||
make_ds_host_tensors_impl(const std::array<ck_tile::HostTensorDescriptor, NumDTensor>& descs,
|
||||
std::index_sequence<Is...>)
|
||||
{
|
||||
return {ck_tile::HostTensor<DDataType>(descs[Is])...};
|
||||
}
|
||||
|
||||
template <typename DDataType, std::size_t NumDTensor>
|
||||
std::array<ck_tile::HostTensor<DDataType>, NumDTensor>
|
||||
make_ds_host_tensors(const std::array<ck_tile::HostTensorDescriptor, NumDTensor>& descs)
|
||||
{
|
||||
return make_ds_host_tensors_impl<DDataType, NumDTensor>(descs,
|
||||
std::make_index_sequence<NumDTensor>{});
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DLayout,
|
||||
typename ELayout,
|
||||
ck_tile::index_t NumDTensor>
|
||||
int run_batched_contraction_example_with_layouts(
|
||||
int argc,
|
||||
char* argv[],
|
||||
@@ -122,8 +143,6 @@ int run_batched_contraction_example_with_layouts(
|
||||
std::vector<ck_tile::index_t> N_dims = parse_dimensions(arg_parser.get_str("n_dims"));
|
||||
std::vector<ck_tile::index_t> K_dims = parse_dimensions(arg_parser.get_str("k_dims"));
|
||||
|
||||
constexpr ck_tile::index_t NumDTensor = 2;
|
||||
|
||||
ck_tile::index_t G_total = calculate_total_elements(G_dims);
|
||||
ck_tile::index_t M_total = calculate_total_elements(M_dims);
|
||||
ck_tile::index_t N_total = calculate_total_elements(N_dims);
|
||||
@@ -148,13 +167,105 @@ int run_batched_contraction_example_with_layouts(
|
||||
return converted;
|
||||
};
|
||||
|
||||
ck_tile::HostTensorDescriptor a_desc(A_dims);
|
||||
ck_tile::HostTensorDescriptor b_desc(B_dims);
|
||||
ck_tile::HostTensorDescriptor e_desc(E_dims);
|
||||
std::array<ck_tile::HostTensorDescriptor, NumDTensor> ds_descs;
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
// Get custom stride arguments
|
||||
std::string strides_a_str = arg_parser.get_str("strides_a");
|
||||
std::string strides_b_str = arg_parser.get_str("strides_b");
|
||||
std::string strides_e_str = arg_parser.get_str("strides_e");
|
||||
std::string strides_ds_str = arg_parser.get_str("strides_ds");
|
||||
|
||||
// Create A descriptor with custom or default strides
|
||||
ck_tile::HostTensorDescriptor a_desc;
|
||||
if(!strides_a_str.empty())
|
||||
{
|
||||
ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides());
|
||||
std::vector<ck_tile::index_t> custom_a_strides = parse_dimensions(strides_a_str);
|
||||
if(custom_a_strides.size() != A_dims.size())
|
||||
{
|
||||
throw std::runtime_error("strides_a size must match A_dims size");
|
||||
}
|
||||
std::vector<std::size_t> a_strides_size_t(custom_a_strides.begin(), custom_a_strides.end());
|
||||
a_desc = ck_tile::HostTensorDescriptor(A_dims, a_strides_size_t);
|
||||
std::cout << "Using custom strides for A (non-contiguous)" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
a_desc = ck_tile::HostTensorDescriptor(A_dims);
|
||||
}
|
||||
|
||||
// Create B descriptor with custom or default strides
|
||||
ck_tile::HostTensorDescriptor b_desc;
|
||||
if(!strides_b_str.empty())
|
||||
{
|
||||
std::vector<ck_tile::index_t> custom_b_strides = parse_dimensions(strides_b_str);
|
||||
if(custom_b_strides.size() != B_dims.size())
|
||||
{
|
||||
throw std::runtime_error("strides_b size must match B_dims size");
|
||||
}
|
||||
std::vector<std::size_t> b_strides_size_t(custom_b_strides.begin(), custom_b_strides.end());
|
||||
b_desc = ck_tile::HostTensorDescriptor(B_dims, b_strides_size_t);
|
||||
std::cout << "Using custom strides for B (non-contiguous)" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
b_desc = ck_tile::HostTensorDescriptor(B_dims);
|
||||
}
|
||||
|
||||
// Create E descriptor with custom or default strides
|
||||
ck_tile::HostTensorDescriptor e_desc;
|
||||
if(!strides_e_str.empty())
|
||||
{
|
||||
std::vector<ck_tile::index_t> custom_e_strides = parse_dimensions(strides_e_str);
|
||||
if(custom_e_strides.size() != E_dims.size())
|
||||
{
|
||||
throw std::runtime_error("strides_e size must match E_dims size");
|
||||
}
|
||||
std::vector<std::size_t> e_strides_size_t(custom_e_strides.begin(), custom_e_strides.end());
|
||||
e_desc = ck_tile::HostTensorDescriptor(E_dims, e_strides_size_t);
|
||||
std::cout << "Using custom strides for E (non-contiguous)" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
e_desc = ck_tile::HostTensorDescriptor(E_dims);
|
||||
}
|
||||
// Create D descriptors with custom or default strides (default = same as E)
|
||||
std::array<ck_tile::HostTensorDescriptor, NumDTensor> ds_descs;
|
||||
if(!strides_ds_str.empty())
|
||||
{
|
||||
// Parse semicolon-separated stride vectors for multiple D tensors
|
||||
std::vector<std::vector<ck_tile::index_t>> all_ds_strides;
|
||||
std::stringstream ss(strides_ds_str);
|
||||
std::string d_stride_str;
|
||||
|
||||
while(std::getline(ss, d_stride_str, ';'))
|
||||
{
|
||||
all_ds_strides.push_back(parse_dimensions(d_stride_str));
|
||||
}
|
||||
|
||||
if(all_ds_strides.size() != NumDTensor)
|
||||
{
|
||||
throw std::runtime_error("Number of D stride vectors must match num_d=" +
|
||||
std::to_string(NumDTensor));
|
||||
}
|
||||
|
||||
std::cout << "Using custom strides for D tensors (non-contiguous)" << std::endl;
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
if(all_ds_strides[d].size() != E_dims.size())
|
||||
{
|
||||
throw std::runtime_error("D tensor " + std::to_string(d) +
|
||||
" stride size must match E_dims size");
|
||||
}
|
||||
std::vector<std::size_t> d_strides_size_t(all_ds_strides[d].begin(),
|
||||
all_ds_strides[d].end());
|
||||
ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], d_strides_size_t);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Default: use same strides as E
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> A_strides = convert_strides(a_desc.get_strides());
|
||||
@@ -201,11 +312,8 @@ int run_batched_contraction_example_with_layouts(
|
||||
ck_tile::HostTensor<::BDataType> b_full_dims_host(b_desc);
|
||||
ck_tile::HostTensor<::EDataType> e_full_dims_host(e_desc);
|
||||
|
||||
std::vector<ck_tile::HostTensor<::DDataType>> ds_full_dims_host;
|
||||
for(int d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
ds_full_dims_host.emplace_back(ck_tile::HostTensor<::DDataType>(ds_descs[d]));
|
||||
}
|
||||
// Construct array of HostTensors - C++17 compatible
|
||||
auto ds_full_dims_host = make_ds_host_tensors<::DDataType, NumDTensor>(ds_descs);
|
||||
|
||||
ck_tile::FillUniformDistribution<::ADataType>{-5.f, 5.f, std::nullopt}(a_full_dims_host);
|
||||
ck_tile::FillUniformDistribution<::BDataType>{-5.f, 5.f, std::nullopt}(b_full_dims_host);
|
||||
@@ -260,10 +368,10 @@ int run_batched_contraction_example_with_layouts(
|
||||
b_full_dims_dev_buf.GetDeviceBuffer(),
|
||||
ds_ptr_buf,
|
||||
e_full_dims_dev_buf.GetDeviceBuffer(),
|
||||
G_dims,
|
||||
M_dims,
|
||||
N_dims,
|
||||
K_dims,
|
||||
G_dims.size(),
|
||||
M_dims.size(),
|
||||
N_dims.size(),
|
||||
K_dims.size(),
|
||||
A_dims,
|
||||
B_dims,
|
||||
Ds_dims,
|
||||
@@ -316,12 +424,13 @@ int run_batched_contraction_example_with_layouts(
|
||||
|
||||
auto start_time = std::chrono::high_resolution_clock::now();
|
||||
|
||||
ck_tile::calculate_reference_flat_indexing<ADataType,
|
||||
BDataType,
|
||||
DDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CDEElementWise>(a_full_dims_host,
|
||||
ck_tile::compute_reference_batched_contraction<ADataType,
|
||||
BDataType,
|
||||
DDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CDEElementWise,
|
||||
NumDTensor>(a_full_dims_host,
|
||||
b_full_dims_host,
|
||||
ds_full_dims_host,
|
||||
e_full_dims_host_ref,
|
||||
@@ -329,7 +438,11 @@ int run_batched_contraction_example_with_layouts(
|
||||
M_total,
|
||||
N_total,
|
||||
K_total,
|
||||
CDEElementWise{});
|
||||
CDEElementWise{},
|
||||
G_dims,
|
||||
M_dims,
|
||||
N_dims,
|
||||
K_dims);
|
||||
|
||||
auto end_time = std::chrono::high_resolution_clock::now();
|
||||
auto duration =
|
||||
@@ -387,15 +500,45 @@ int run_batched_contraction_example(int argc, char* argv[])
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
// Get NumDTensor to dispatch at runtime
|
||||
const int num_d = arg_parser.get_int("num_d");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
// Runtime dispatch based on num_d value
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_batched_contraction_example_with_layouts(argc, argv, Row{}, Col{}, Row{}, Row{});
|
||||
// Dispatch to appropriate template instantiation based on runtime num_d
|
||||
switch(num_d)
|
||||
{
|
||||
case 0:
|
||||
std::cout << "Running with 0 D tensors" << std::endl;
|
||||
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 0>(
|
||||
argc, argv, Row{}, Col{}, Row{}, Row{});
|
||||
case 1:
|
||||
std::cout << "Running with 1 D tensor" << std::endl;
|
||||
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 1>(
|
||||
argc, argv, Row{}, Col{}, Row{}, Row{});
|
||||
case 2:
|
||||
std::cout << "Running with 2 D tensors" << std::endl;
|
||||
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 2>(
|
||||
argc, argv, Row{}, Col{}, Row{}, Row{});
|
||||
case 3:
|
||||
std::cout << "Running with 3 D tensors" << std::endl;
|
||||
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 3>(
|
||||
argc, argv, Row{}, Col{}, Row{}, Row{});
|
||||
case 4:
|
||||
std::cout << "Running with 4 D tensors" << std::endl;
|
||||
return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 4>(
|
||||
argc, argv, Row{}, Col{}, Row{}, Row{});
|
||||
default:
|
||||
throw std::runtime_error("num_d must be between 0 and 4, got: " +
|
||||
std::to_string(num_d));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user