mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
Refactor the compute reference batched contraction to manage stride-aware calculation and some code cleanup
This commit is contained in:
@@ -48,30 +48,34 @@ void print_help(const char* program_name)
|
||||
std::cout << "Batched Tensor Contraction with element-wise fusion\n";
|
||||
std::cout << "E[G,M,N] = element_wise_op(contraction(A[G,M,K], B[G,N,K]), D0, D1, ...)\n";
|
||||
std::cout << "(Supports multiple D tensors with configurable element-wise operations)\n\n";
|
||||
|
||||
|
||||
std::cout << "Usage: " << program_name << " [OPTIONS]\n\n";
|
||||
|
||||
|
||||
std::cout << "Dimension Arguments (comma-separated, no spaces):\n";
|
||||
std::cout << " -g_dims=<dims> Batch dimensions (default: \"1,2\")\n";
|
||||
std::cout << " -m_dims=<dims> M (row) dimensions (default: \"4,256\")\n";
|
||||
std::cout << " -n_dims=<dims> N (column) dimensions (default: \"16,128\")\n";
|
||||
std::cout << " -k_dims=<dims> K (contract) dims (default: \"64\")\n\n";
|
||||
|
||||
|
||||
std::cout << "Layout Arguments:\n";
|
||||
std::cout << " -a_layout=<R|C> A tensor layout (R=Row-major, C=Column-major, default: \"R\")\n";
|
||||
std::cout
|
||||
<< " -a_layout=<R|C> A tensor layout (R=Row-major, C=Column-major, default: \"R\")\n";
|
||||
std::cout << " -b_layout=<R|C> B tensor layout (default: \"C\")\n";
|
||||
std::cout << " -e_layout=<R|C> E tensor layout (default: \"R\")\n\n";
|
||||
|
||||
|
||||
std::cout << "Examples:\n";
|
||||
std::cout << " Single batch (12 batches of 256×128):\n";
|
||||
std::cout << " " << program_name << " -g_dims=\"12\" -m_dims=\"256\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
|
||||
|
||||
std::cout << " " << program_name
|
||||
<< " -g_dims=\"12\" -m_dims=\"256\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
|
||||
|
||||
std::cout << " 2D batch grid (2×3=6 batches):\n";
|
||||
std::cout << " " << program_name << " -g_dims=\"2,3\" -m_dims=\"128\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
|
||||
|
||||
std::cout << " " << program_name
|
||||
<< " -g_dims=\"2,3\" -m_dims=\"128\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
|
||||
|
||||
std::cout << " Multi-dimensional (flattened to M=128, N=128, K=128):\n";
|
||||
std::cout << " " << program_name << " -g_dims=\"4\" -m_dims=\"8,16\" -n_dims=\"32,4\" -k_dims=\"16,8\"\n\n";
|
||||
|
||||
std::cout << " " << program_name
|
||||
<< " -g_dims=\"4\" -m_dims=\"8,16\" -n_dims=\"32,4\" -k_dims=\"16,8\"\n\n";
|
||||
|
||||
std::cout << "Other Options:\n";
|
||||
std::cout << " -v=<0|1> Validation (0=off, 1=on, default: 1)\n";
|
||||
std::cout << " -split_k=<int> Split-K value (default: 1)\n";
|
||||
@@ -93,7 +97,7 @@ auto create_args(int argc, char* argv[])
|
||||
std::exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m_dims", "4,256", "M dimensions separated by comma (e.g., '16,32' for 2D M)")
|
||||
.insert("n_dims", "16,128", "N dimensions separated by comma (e.g., '32,32' for 2D N)")
|
||||
|
||||
@@ -201,11 +201,14 @@ int run_batched_contraction_example_with_layouts(
|
||||
ck_tile::HostTensor<::BDataType> b_full_dims_host(b_desc);
|
||||
ck_tile::HostTensor<::EDataType> e_full_dims_host(e_desc);
|
||||
|
||||
std::vector<ck_tile::HostTensor<::DDataType>> ds_full_dims_host;
|
||||
for(int d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
ds_full_dims_host.emplace_back(ck_tile::HostTensor<::DDataType>(ds_descs[d]));
|
||||
}
|
||||
// Helper to construct array of HostTensors using index_sequence
|
||||
auto make_ds_host_tensors = []<std::size_t... Is>(const auto& descs,
|
||||
std::index_sequence<Is...>) {
|
||||
return std::array<ck_tile::HostTensor<::DDataType>, sizeof...(Is)>{
|
||||
ck_tile::HostTensor<::DDataType>(descs[Is])...};
|
||||
};
|
||||
|
||||
auto ds_full_dims_host = make_ds_host_tensors(ds_descs, std::make_index_sequence<NumDTensor>{});
|
||||
|
||||
ck_tile::FillUniformDistribution<::ADataType>{-5.f, 5.f, std::nullopt}(a_full_dims_host);
|
||||
ck_tile::FillUniformDistribution<::BDataType>{-5.f, 5.f, std::nullopt}(b_full_dims_host);
|
||||
@@ -316,12 +319,13 @@ int run_batched_contraction_example_with_layouts(
|
||||
|
||||
auto start_time = std::chrono::high_resolution_clock::now();
|
||||
|
||||
calculate_reference_flat_indexing<ADataType,
|
||||
BDataType,
|
||||
DDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CDEElementWise>(a_full_dims_host,
|
||||
compute_reference_batched_contraction<ADataType,
|
||||
BDataType,
|
||||
DDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CDEElementWise,
|
||||
NumDTensor>(a_full_dims_host,
|
||||
b_full_dims_host,
|
||||
ds_full_dims_host,
|
||||
e_full_dims_host_ref,
|
||||
@@ -329,7 +333,11 @@ int run_batched_contraction_example_with_layouts(
|
||||
M_total,
|
||||
N_total,
|
||||
K_total,
|
||||
CDEElementWise{});
|
||||
CDEElementWise{},
|
||||
G_dims,
|
||||
M_dims,
|
||||
N_dims,
|
||||
K_dims);
|
||||
|
||||
auto end_time = std::chrono::high_resolution_clock::now();
|
||||
auto duration =
|
||||
|
||||
Reference in New Issue
Block a user