diff --git a/example/ck_tile/41_batched_contraction/batched_contraction.cpp b/example/ck_tile/41_batched_contraction/batched_contraction.cpp
index d83cc70c62..6536894394 100644
--- a/example/ck_tile/41_batched_contraction/batched_contraction.cpp
+++ b/example/ck_tile/41_batched_contraction/batched_contraction.cpp
@@ -219,9 +219,7 @@ float batched_contraction(const ck_tile::BatchedContractionHostArgs
+    std::cout << "  -g_dims=       Batch dimensions (default: \"1,2\")\n";
+    std::cout << "  -m_dims=       M (row) dimensions (default: \"4,256\")\n";
+    std::cout << "  -n_dims=       N (column) dimensions (default: \"16,128\")\n";
+    std::cout << "  -k_dims=       K (contraction) dimensions (default: \"64\")\n";
+    std::cout << "  -num_d=        Number of D tensors (default: 2, range: 0-4)\n\n";
+
+    std::cout << "Custom Stride Arguments (for testing non-contiguous tensors):\n";
+    std::cout << "  -strides_a=    A tensor strides (comma-separated, empty = auto)\n";
+    std::cout << "  -strides_b=    B tensor strides (comma-separated, empty = auto)\n";
+    std::cout << "  -strides_e=    E tensor strides (comma-separated, empty = auto)\n";
+    std::cout << "  -strides_ds=   D tensors strides (semicolon-separated, empty = same as E)\n";
+    std::cout << "  Example: -strides_a=\"32768,128,1\" -strides_ds=\"512,2,1;1024,4,1\"\n\n";
+
+    std::cout << "Layout Arguments:\n";
+    std::cout
+        << "  -a_layout=     A tensor layout (R=Row-major, C=Column-major, default: \"R\")\n";
+    std::cout << "  -b_layout=     B tensor layout (default: \"C\")\n";
+    std::cout << "  -e_layout=     E tensor layout (default: \"R\")\n\n";
+
+    std::cout << "Examples:\n";
+    std::cout << "  1D batch (12 batches of 256×128):\n";
+    std::cout << "    " << program_name
+              << " -g_dims=\"12\" -m_dims=\"256\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
+
+    std::cout << "  2D batch grid (2×3=6 batches):\n";
+    std::cout << "    " << program_name
+              << " -g_dims=\"2,3\" -m_dims=\"128\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
+
+    std::cout << "  Multi-dimensional (flattened to M=128, N=128, K=128):\n";
+    std::cout << "    " << program_name
+              << " -g_dims=\"4\" -m_dims=\"8,16\" -n_dims=\"32,4\" -k_dims=\"16,8\"\n\n";
+
+    std::cout << "Other Options:\n";
+    std::cout << "  -v=<0|1>       Validation (0=off, 1=on, default: 1)\n";
+    std::cout << "  -split_k=      Split-K value (default: 1)\n";
+    std::cout << "  -warmup=       Warmup iterations (default: 5)\n";
+    std::cout << "  -repeat=       Benchmark iterations (default: 10)\n";
+    std::cout << "  -log=<0|1>     Logging level (default: 1)\n";
+    std::cout << "  -help          Show this help\n\n";
+}
+
 auto create_args(int argc, char* argv[])
 {
+    // Check for --help flag
+    for(int i = 1; i < argc; ++i)
+    {
+        std::string arg = argv[i];
+        if(arg == "--help" || arg == "-h" || arg == "-help")
+        {
+            print_help(argv[0]);
+            std::exit(0);
+        }
+    }
+
     ck_tile::ArgParser arg_parser;
     arg_parser.insert("m_dims", "4,256", "M dimensions separated by comma (e.g., '16,32' for 2D M)")
         .insert("n_dims", "16,128", "N dimensions separated by comma (e.g., '32,32' for 2D N)")
         .insert("k_dims", "64", "K dimensions separated by comma (e.g., '64,32' for 2D K)")
         .insert(
             "g_dims", "1,2", "G dimensions separated by comma (e.g., '4,2' for 2D, '2,3,4' for 3D)")
-        .insert("stride_a", "0", "Custom A tensor leading dimension stride (0 = auto)")
-        .insert("stride_b", "0", "Custom B tensor leading dimension stride (0 = auto)")
-        .insert("stride_e", "0", "Custom E tensor leading dimension stride (0 = auto)")
+        .insert("num_d", "2", "Number of D (auxiliary input) tensors")
+        .insert("strides_a", "", "A tensor strides (comma-separated, empty = auto/contiguous)")
+        .insert("strides_b", "", "B tensor strides (comma-separated, empty = auto/contiguous)")
+        .insert("strides_e", "", "E tensor strides (comma-separated, empty = auto/contiguous)")
+        .insert("strides_ds",
+                "",
+                "D tensors strides (semicolon-separated for multiple, empty = same as E)")
         .insert("a_layout", "R", "A tensor data layout - Row by default")
         .insert("b_layout", "C", "B tensor data layout - Col by default")
         .insert("e_layout", "R", "E tensor data layout - Row by default")
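The help text and `create_args` changes above rely on a `parse_dimensions` helper that this diff never shows. A minimal, self-contained sketch of what such a comma-separated parser could look like — the name comes from the call sites below; the exact signature and the `index_t` alias are assumptions:

```cpp
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for ck_tile::index_t in this sketch.
using index_t = std::int32_t;

// Splits "8,16,32" into {8, 16, 32}. An empty string yields an empty vector,
// which the example treats as "derive contiguous strides automatically".
std::vector<index_t> parse_dimensions(const std::string& s)
{
    std::vector<index_t> dims;
    std::stringstream ss(s);
    std::string token;
    while(std::getline(ss, token, ','))
    {
        if(!token.empty())
            dims.push_back(static_cast<index_t>(std::stol(token)));
    }
    return dims;
}
```

The `-strides_ds` argument applies the same parser once per semicolon-separated group, one group per D tensor.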
diff --git a/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc b/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc
index f1a5f8e9ae..214b14633d 100644
--- a/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc
+++ b/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc
@@ -45,10 +45,10 @@ float invoke_batched_contraction_kernel(
     const void* b_full_dims_dev_buf,
     const std::array<const void*, NumDTensor>& ds_dev_buf,
     void* e_full_dims_dev_buf,
-    const std::vector<ck_tile::index_t>& G_dims,
-    const std::vector<ck_tile::index_t>& M_dims,
-    const std::vector<ck_tile::index_t>& N_dims,
-    const std::vector<ck_tile::index_t>& K_dims,
+    ck_tile::index_t num_g_dims,
+    ck_tile::index_t num_m_dims,
+    ck_tile::index_t num_n_dims,
+    ck_tile::index_t num_k_dims,
     const std::vector<ck_tile::index_t>& A_dims, // [G0,G1,..,M0,M1,..,K0,K1,..]
     const std::vector<ck_tile::index_t>& B_dims, // [G0,G1,..,N0,N1,..,K0,K1,..]
     const std::array<std::vector<ck_tile::index_t>, DsDataType::size()>&
@@ -79,9 +79,8 @@ float invoke_batched_contraction_kernel(
         E_strides // E_strides
     );
 
-    std::cout << "Calling batched_contraction with dimensions: G=" << G_dims.size()
-              << ", M=" << M_dims.size() << ", N=" << N_dims.size() << ", K=" << K_dims.size()
-              << std::endl;
+    std::cout << "Calling batched_contraction with dimensions: G=" << num_g_dims
+              << ", M=" << num_m_dims << ", N=" << num_n_dims << ", K=" << num_k_dims << std::endl;
 
     float ave_time = batched_contraction(
         args,
         ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
-        G_dims.size(), // num_g_dims
-        M_dims.size(), // num_m_dims
-        N_dims.size(), // num_n_dims
-        K_dims.size()  // num_k_dims
-    );
+        num_g_dims,
+        num_m_dims,
+        num_n_dims,
+        num_k_dims);
 
     return ave_time;
 }
 
-template <typename ALayout, typename BLayout, typename DsLayout, typename ELayout>
+// C++17-compatible helper function to create array of HostTensors
+namespace {
+template <typename DDataType, ck_tile::index_t NumDTensor, std::size_t... Is>
+std::array<ck_tile::HostTensor<DDataType>, NumDTensor>
+make_ds_host_tensors_impl(const std::array<ck_tile::HostTensorDescriptor, NumDTensor>& descs,
+                          std::index_sequence<Is...>)
+{
+    return {ck_tile::HostTensor<DDataType>(descs[Is])...};
+}
+
+template <typename DDataType, ck_tile::index_t NumDTensor>
+std::array<ck_tile::HostTensor<DDataType>, NumDTensor>
+make_ds_host_tensors(const std::array<ck_tile::HostTensorDescriptor, NumDTensor>& descs)
+{
+    return make_ds_host_tensors_impl<DDataType, NumDTensor>(
+        descs, std::make_index_sequence<NumDTensor>{});
+}
+} // anonymous namespace
+
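`make_ds_host_tensors` exists because a `std::array` of `ck_tile::HostTensor` (which presumably requires a descriptor at construction and has no default constructor) cannot be declared first and filled in a loop: the pack expansion over `std::index_sequence` constructs every element inside a single braced initializer. The same C++17 technique in standalone form, with placeholder types instead of the ck_tile ones:

```cpp
#include <array>
#include <cstddef>
#include <utility>

struct Descriptor { int size; };

// A type without a default constructor, standing in for ck_tile::HostTensor.
struct Tensor
{
    explicit Tensor(const Descriptor& d) : size(d.size) {}
    int size;
};

template <std::size_t N, std::size_t... Is>
std::array<Tensor, N> make_tensors_impl(const std::array<Descriptor, N>& descs,
                                        std::index_sequence<Is...>)
{
    // Expands to { Tensor(descs[0]), Tensor(descs[1]), ... }
    return {Tensor(descs[Is])...};
}

template <std::size_t N>
std::array<Tensor, N> make_tensors(const std::array<Descriptor, N>& descs)
{
    return make_tensors_impl(descs, std::make_index_sequence<N>{});
}

int main()
{
    std::array<Descriptor, 3> descs{{{4}, {8}, {16}}};
    auto tensors = make_tensors(descs);
    return tensors[2].size == 16 ? 0 : 1;
}
```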
+template <typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename ELayout,
+          ck_tile::index_t NumDTensor>
 int run_batched_contraction_example_with_layouts(
     int argc,
     char* argv[],
@@ -122,8 +143,6 @@ int run_batched_contraction_example_with_layouts(
     std::vector<ck_tile::index_t> N_dims = parse_dimensions(arg_parser.get_str("n_dims"));
     std::vector<ck_tile::index_t> K_dims = parse_dimensions(arg_parser.get_str("k_dims"));
 
-    constexpr ck_tile::index_t NumDTensor = 2;
-
     ck_tile::index_t G_total = calculate_total_elements(G_dims);
     ck_tile::index_t M_total = calculate_total_elements(M_dims);
     ck_tile::index_t N_total = calculate_total_elements(N_dims);
@@ -148,13 +167,105 @@ int run_batched_contraction_example_with_layouts(
         return converted;
     };
 
-    ck_tile::HostTensorDescriptor a_desc(A_dims);
-    ck_tile::HostTensorDescriptor b_desc(B_dims);
-    ck_tile::HostTensorDescriptor e_desc(E_dims);
-    std::array<ck_tile::HostTensorDescriptor, NumDTensor> ds_descs;
-    for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
+    // Get custom stride arguments
+    std::string strides_a_str  = arg_parser.get_str("strides_a");
+    std::string strides_b_str  = arg_parser.get_str("strides_b");
+    std::string strides_e_str  = arg_parser.get_str("strides_e");
+    std::string strides_ds_str = arg_parser.get_str("strides_ds");
+
+    // Create A descriptor with custom or default strides
+    ck_tile::HostTensorDescriptor a_desc;
+    if(!strides_a_str.empty())
     {
-        ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides());
+        std::vector<ck_tile::index_t> custom_a_strides = parse_dimensions(strides_a_str);
+        if(custom_a_strides.size() != A_dims.size())
+        {
+            throw std::runtime_error("strides_a size must match A_dims size");
+        }
+        std::vector<std::size_t> a_strides_size_t(custom_a_strides.begin(),
+                                                  custom_a_strides.end());
+        a_desc = ck_tile::HostTensorDescriptor(A_dims, a_strides_size_t);
+        std::cout << "Using custom strides for A (non-contiguous)" << std::endl;
+    }
+    else
+    {
+        a_desc = ck_tile::HostTensorDescriptor(A_dims);
+    }
+
+    // Create B descriptor with custom or default strides
+    ck_tile::HostTensorDescriptor b_desc;
+    if(!strides_b_str.empty())
+    {
+        std::vector<ck_tile::index_t> custom_b_strides = parse_dimensions(strides_b_str);
+        if(custom_b_strides.size() != B_dims.size())
+        {
+            throw std::runtime_error("strides_b size must match B_dims size");
+        }
+        std::vector<std::size_t> b_strides_size_t(custom_b_strides.begin(),
+                                                  custom_b_strides.end());
+        b_desc = ck_tile::HostTensorDescriptor(B_dims, b_strides_size_t);
+        std::cout << "Using custom strides for B (non-contiguous)" << std::endl;
+    }
+    else
+    {
+        b_desc = ck_tile::HostTensorDescriptor(B_dims);
+    }
+
+    // Create E descriptor with custom or default strides
+    ck_tile::HostTensorDescriptor e_desc;
+    if(!strides_e_str.empty())
+    {
+        std::vector<ck_tile::index_t> custom_e_strides = parse_dimensions(strides_e_str);
+        if(custom_e_strides.size() != E_dims.size())
+        {
+            throw std::runtime_error("strides_e size must match E_dims size");
+        }
+        std::vector<std::size_t> e_strides_size_t(custom_e_strides.begin(),
+                                                  custom_e_strides.end());
+        e_desc = ck_tile::HostTensorDescriptor(E_dims, e_strides_size_t);
+        std::cout << "Using custom strides for E (non-contiguous)" << std::endl;
+    }
+    else
+    {
+        e_desc = ck_tile::HostTensorDescriptor(E_dims);
+    }
+
+    // Create D descriptors with custom or default strides (default = same as E)
+    std::array<ck_tile::HostTensorDescriptor, NumDTensor> ds_descs;
+    if(!strides_ds_str.empty())
+    {
+        // Parse semicolon-separated stride vectors for multiple D tensors
+        std::vector<std::vector<ck_tile::index_t>> all_ds_strides;
+        std::stringstream ss(strides_ds_str);
+        std::string d_stride_str;
+
+        while(std::getline(ss, d_stride_str, ';'))
+        {
+            all_ds_strides.push_back(parse_dimensions(d_stride_str));
+        }
+
+        if(all_ds_strides.size() != NumDTensor)
+        {
+            throw std::runtime_error("Number of D stride vectors must match num_d=" +
+                                     std::to_string(NumDTensor));
+        }
+
+        std::cout << "Using custom strides for D tensors (non-contiguous)" << std::endl;
+        for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
+        {
+            if(all_ds_strides[d].size() != E_dims.size())
+            {
+                throw std::runtime_error("D tensor " + std::to_string(d) +
+                                         " stride size must match E_dims size");
+            }
+            std::vector<std::size_t> d_strides_size_t(all_ds_strides[d].begin(),
+                                                      all_ds_strides[d].end());
+            ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], d_strides_size_t);
+        }
+    }
+    else
+    {
+        // Default: use same strides as E
+        for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
+        {
+            ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides());
+        }
     }
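"Auto/contiguous" strides are the packed row-major strides implied by the dimensions: the innermost dimension has stride 1 and each outer stride is the product of the extents to its right, which is presumably what the dims-only `HostTensorDescriptor` constructor computes. A small sketch of that derivation, with the padded `-strides_a="32768,128,1"` example from the help text worked out:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Packed row-major strides: stride[i] = product of dims[i+1..].
std::vector<std::size_t> contiguous_strides(const std::vector<std::size_t>& dims)
{
    std::vector<std::size_t> strides(dims.size(), 1);
    for(int i = static_cast<int>(dims.size()) - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * dims[i + 1];
    return strides;
}

int main()
{
    // A[G=12][M=256][K=64] contiguous: strides {16384, 64, 1}.
    auto s = contiguous_strides({12, 256, 64});
    assert(s[0] == 16384 && s[1] == 64 && s[2] == 1);
    // A custom -strides_a="32768,128,1" pads each (G, M) row out to 128
    // elements and each batch to 32768, leaving gaps the kernel must skip.
    return 0;
}
```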
 
     std::vector<ck_tile::index_t> A_strides = convert_strides(a_desc.get_strides());
@@ -201,11 +312,8 @@ int run_batched_contraction_example_with_layouts(
     ck_tile::HostTensor<::BDataType> b_full_dims_host(b_desc);
     ck_tile::HostTensor<::EDataType> e_full_dims_host(e_desc);
 
-    std::vector<ck_tile::HostTensor<::DDataType>> ds_full_dims_host;
-    for(int d = 0; d < NumDTensor; ++d)
-    {
-        ds_full_dims_host.emplace_back(ck_tile::HostTensor<::DDataType>(ds_descs[d]));
-    }
+    // Construct array of HostTensors - C++17 compatible
+    auto ds_full_dims_host = make_ds_host_tensors<::DDataType, NumDTensor>(ds_descs);
 
     ck_tile::FillUniformDistribution<::ADataType>{-5.f, 5.f, std::nullopt}(a_full_dims_host);
     ck_tile::FillUniformDistribution<::BDataType>{-5.f, 5.f, std::nullopt}(b_full_dims_host);
@@ -260,10 +368,10 @@ int run_batched_contraction_example_with_layouts(
         b_full_dims_dev_buf.GetDeviceBuffer(),
         ds_ptr_buf,
         e_full_dims_dev_buf.GetDeviceBuffer(),
-        G_dims,
-        M_dims,
-        N_dims,
-        K_dims,
+        G_dims.size(),
+        M_dims.size(),
+        N_dims.size(),
+        K_dims.size(),
         A_dims,
         B_dims,
         Ds_dims,
@@ -316,12 +424,13 @@ int run_batched_contraction_example_with_layouts(
 
         auto start_time = std::chrono::high_resolution_clock::now();
 
-        ck_tile::calculate_reference_flat_indexing(a_full_dims_host,
+        ck_tile::compute_reference_batched_contraction(a_full_dims_host,
                                                    b_full_dims_host,
                                                    ds_full_dims_host,
                                                    e_full_dims_host_ref,
@@ -329,7 +438,11 @@ int run_batched_contraction_example_with_layouts(
                                                    M_total,
                                                    N_total,
                                                    K_total,
-                                                   CDEElementWise{});
+                                                   CDEElementWise{},
+                                                   G_dims,
+                                                   M_dims,
+                                                   N_dims,
+                                                   K_dims);
 
         auto end_time = std::chrono::high_resolution_clock::now();
         auto duration =
@@ -387,15 +500,45 @@ int run_batched_contraction_example(int argc, char* argv[])
     if(!result)
         return -1;
 
+    // Get NumDTensor to dispatch at runtime
+    const int num_d = arg_parser.get_int("num_d");
+
     using Row = ck_tile::tensor_layout::gemm::RowMajor;
     using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
 
     std::string a_layout = arg_parser.get_str("a_layout");
     std::string b_layout = arg_parser.get_str("b_layout");
 
+    // Runtime dispatch based on num_d value
    if(a_layout == "R" && b_layout == "C")
    {
-        return run_batched_contraction_example_with_layouts(argc, argv, Row{}, Col{}, Row{}, Row{});
+        // Dispatch to appropriate template instantiation based on runtime num_d
+        switch(num_d)
+        {
+        case 0:
+            std::cout << "Running with 0 D tensors" << std::endl;
+            return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 0>(
+                argc, argv, Row{}, Col{}, Row{}, Row{});
+        case 1:
+            std::cout << "Running with 1 D tensor" << std::endl;
+            return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 1>(
+                argc, argv, Row{}, Col{}, Row{}, Row{});
+        case 2:
+            std::cout << "Running with 2 D tensors" << std::endl;
+            return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 2>(
+                argc, argv, Row{}, Col{}, Row{}, Row{});
+        case 3:
+            std::cout << "Running with 3 D tensors" << std::endl;
+            return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 3>(
+                argc, argv, Row{}, Col{}, Row{}, Row{});
+        case 4:
+            std::cout << "Running with 4 D tensors" << std::endl;
+            return run_batched_contraction_example_with_layouts<Row, Col, Row, Row, 4>(
+                argc, argv, Row{}, Col{}, Row{}, Row{});
+        default:
+            throw std::runtime_error("num_d must be between 0 and 4, got: " +
+                                     std::to_string(num_d));
+        }
    }
    else
    {
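The `switch` exists to turn the runtime `num_d` value into a compile-time constant, since the D-tensor count fixes `std::array` sizes and pack expansions at compile time. The five cases could also be generated from an integer sequence with a C++17 fold expression; a hedged sketch of that alternative (generic C++, not a ck_tile utility):

```cpp
#include <cstdexcept>
#include <string>
#include <utility>

// Calls f(std::integral_constant<int, n>{}) for a runtime n in [0, Max].
template <typename F, int... Is>
int dispatch_impl(int n, F&& f, std::integer_sequence<int, Is...>)
{
    int result   = -1;
    bool matched = ((n == Is ? (result = f(std::integral_constant<int, Is>{}), true) : false) || ...);
    if(!matched)
        throw std::runtime_error("unsupported value: " + std::to_string(n));
    return result;
}

template <int Max, typename F>
int dispatch(int n, F&& f)
{
    return dispatch_impl(n, std::forward<F>(f), std::make_integer_sequence<int, Max + 1>{});
}

int main()
{
    // Each match instantiates the lambda with a distinct compile-time constant,
    // the same effect as the explicit case 0..4 above.
    return dispatch<4>(2, [](auto num_d) { return int(decltype(num_d)::value); }) == 2 ? 0 : 1;
}
```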
diff --git a/include/ck_tile/host/reference/reference_batched_contraction.hpp b/include/ck_tile/host/reference/reference_batched_contraction.hpp
index d0ff358e89..cc42d77d43 100644
--- a/include/ck_tile/host/reference/reference_batched_contraction.hpp
+++ b/include/ck_tile/host/reference/reference_batched_contraction.hpp
@@ -4,8 +4,6 @@
 #pragma once
 
 #include <array>
-#include <functional>
-#include <numeric>
 #include <thread>
 
 #include "ck_tile/core.hpp"
@@ -13,110 +11,259 @@
 
 namespace ck_tile {
 
+// Helper to apply elementwise operation with variable number of D tensors
+template <typename EDataType, typename AccDataType>
+struct ApplyCDEElementWise
+{
+    template <typename CDEElementWise, typename... DValues>
+    CK_TILE_HOST_DEVICE static void apply(EDataType& result,
+                                          AccDataType sum,
+                                          const CDEElementWise& cde_elementwise,
+                                          DValues... d_vals)
+    {
+        if constexpr(sizeof...(DValues) == 0)
+        {
+            result = static_cast<EDataType>(sum);
+        }
+        else
+        {
+            cde_elementwise(
+                result, ck_tile::type_convert<EDataType>(sum), ck_tile::type_convert<EDataType>(d_vals)...);
+        }
+    }
+};
+
+// Helper to extract D values at a given offset using index sequence
+template <typename EDataType,
+          typename AccDataType,
+          typename DDataType,
+          ck_tile::index_t NumDTensor,
+          typename Seq = std::make_index_sequence<NumDTensor>>
+struct ExtractDValues;
+
+template <typename EDataType,
+          typename AccDataType,
+          typename DDataType,
+          ck_tile::index_t NumDTensor,
+          std::size_t... Is>
+struct ExtractDValues<EDataType, AccDataType, DDataType, NumDTensor, std::index_sequence<Is...>>
+{
+    template <typename CDEElementWise>
+    CK_TILE_HOST static void
+    apply_at_offsets(EDataType& result,
+                     AccDataType sum,
+                     const CDEElementWise& cde_elementwise,
+                     const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_tensors,
+                     const std::array<std::size_t, NumDTensor>& d_offsets)
+    {
+        ApplyCDEElementWise<EDataType, AccDataType>::apply(
+            result, sum, cde_elementwise, ds_tensors[Is].mData[d_offsets[Is]]...);
+    }
+};
+
 template <typename ADataType,
           typename BDataType,
           typename AccDataType,
           typename EDataType,
           typename DDataType,
-          typename CDEElementWise>
+          typename CDEElementWise,
+          ck_tile::index_t NumDTensor>
-void calculate_reference_flat_indexing(
+void compute_reference_batched_contraction(
     const ck_tile::HostTensor<ADataType>& a_full_dims,
     const ck_tile::HostTensor<BDataType>& b_full_dims,
-    const std::vector<ck_tile::HostTensor<DDataType>>& ds_full_dims_host,
+    const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_full_dims_host,
     ck_tile::HostTensor<EDataType>& e_full_dims_host_ref,
     ck_tile::index_t G_total,
     ck_tile::index_t M_total,
     ck_tile::index_t N_total,
     ck_tile::index_t K_total,
-    const CDEElementWise& cde_elementwise)
+    const CDEElementWise& cde_elementwise,
+    const std::vector<ck_tile::index_t>& G_dims,
+    const std::vector<ck_tile::index_t>& M_dims,
+    const std::vector<ck_tile::index_t>& N_dims,
+    const std::vector<ck_tile::index_t>& K_dims)
 {
-    std::cout << "Calculating reference using optimized flat indexing with parallel processing..."
+    std::cout << "Calculating reference using stride-aware indexing with parallel processing..."
               << std::endl;
 
-    // Parallel computation over G and M dimensions using pattern from reference_batched_gemm.hpp
+    // Extract stride information from tensor descriptors
+    const auto a_strides = a_full_dims.get_strides();
+    const auto b_strides = b_full_dims.get_strides();
+    const auto e_strides = e_full_dims_host_ref.get_strides();
+
+    // Extract D tensor strides
+    std::array<std::vector<std::size_t>, NumDTensor> ds_strides;
+    for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
+    {
+        ds_strides[d] = ds_full_dims_host[d].get_strides();
+    }
+
+    const ck_tile::index_t num_g_dims = G_dims.size();
+    const ck_tile::index_t num_m_dims = M_dims.size();
+    const ck_tile::index_t num_n_dims = N_dims.size();
+    const ck_tile::index_t num_k_dims = K_dims.size();
+
+    // Helper lambda to compute linear index from flat indices using strides
+    auto compute_a_offset = [&](ck_tile::index_t g_flat,
+                                ck_tile::index_t m_flat,
+                                ck_tile::index_t k_flat) -> std::size_t {
+        std::size_t offset = 0;
+
+        // Decode G dimensions
+        ck_tile::index_t temp = g_flat;
+        for(int i = num_g_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % G_dims[i]) * a_strides[i];
+            temp /= G_dims[i];
+        }
+
+        // Decode M dimensions
+        temp = m_flat;
+        for(int i = num_m_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % M_dims[i]) * a_strides[num_g_dims + i];
+            temp /= M_dims[i];
+        }
+
+        // Decode K dimensions
+        temp = k_flat;
+        for(int i = num_k_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % K_dims[i]) * a_strides[num_g_dims + num_m_dims + i];
+            temp /= K_dims[i];
+        }
+
+        return offset;
+    };
+
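Each `compute_*_offset` lambda is a mixed-radix decode: the flat index is peeled apart from the least-significant dimension upward, and each digit is multiplied by that dimension's stride, so padded or otherwise non-contiguous layouts are honored. The same arithmetic isolated with a worked check:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Decode a flat index over dims (row-major digit order) and dot the digits
// with strides, mirroring the compute_*_offset lambdas above.
std::size_t decode_offset(std::size_t flat,
                          const std::vector<std::size_t>& dims,
                          const std::vector<std::size_t>& strides)
{
    std::size_t offset = 0;
    for(int i = static_cast<int>(dims.size()) - 1; i >= 0; --i)
    {
        offset += (flat % dims[i]) * strides[i];
        flat /= dims[i];
    }
    return offset;
}

int main()
{
    // dims {2,3,4}: flat 17 -> indices {1, 1, 1}, since 17 = 1*12 + 1*4 + 1.
    // With padded strides {16, 5, 1}: offset = 16 + 5 + 1 = 22.
    assert(decode_offset(17, {2, 3, 4}, {16, 5, 1}) == 22);
    return 0;
}
```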
+    auto compute_b_offset = [&](ck_tile::index_t g_flat,
+                                ck_tile::index_t n_flat,
+                                ck_tile::index_t k_flat) -> std::size_t {
+        std::size_t offset = 0;
+
+        // Decode G dimensions
+        ck_tile::index_t temp = g_flat;
+        for(int i = num_g_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % G_dims[i]) * b_strides[i];
+            temp /= G_dims[i];
+        }
+
+        // Decode N dimensions
+        temp = n_flat;
+        for(int i = num_n_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % N_dims[i]) * b_strides[num_g_dims + i];
+            temp /= N_dims[i];
+        }
+
+        // Decode K dimensions
+        temp = k_flat;
+        for(int i = num_k_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % K_dims[i]) * b_strides[num_g_dims + num_n_dims + i];
+            temp /= K_dims[i];
+        }
+
+        return offset;
+    };
+
+    auto compute_e_offset = [&](ck_tile::index_t g_flat,
+                                ck_tile::index_t m_flat,
+                                ck_tile::index_t n_flat) -> std::size_t {
+        std::size_t offset = 0;
+
+        // Decode G dimensions
+        ck_tile::index_t temp = g_flat;
+        for(int i = num_g_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % G_dims[i]) * e_strides[i];
+            temp /= G_dims[i];
+        }
+
+        // Decode M dimensions
+        temp = m_flat;
+        for(int i = num_m_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % M_dims[i]) * e_strides[num_g_dims + i];
+            temp /= M_dims[i];
+        }
+
+        // Decode N dimensions
+        temp = n_flat;
+        for(int i = num_n_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % N_dims[i]) * e_strides[num_g_dims + num_m_dims + i];
+            temp /= N_dims[i];
+        }
+
+        return offset;
+    };
+
+    // Helper to compute D tensor offset (D tensors have same shape as E: [G, M, N])
+    auto compute_d_offset = [&](ck_tile::index_t g_flat,
+                                ck_tile::index_t m_flat,
+                                ck_tile::index_t n_flat,
+                                ck_tile::index_t d_idx) -> std::size_t {
+        std::size_t offset = 0;
+        const auto& d_strides = ds_strides[d_idx];
+
+        // Decode G dimensions
+        ck_tile::index_t temp = g_flat;
+        for(int i = num_g_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % G_dims[i]) * d_strides[i];
+            temp /= G_dims[i];
+        }
+
+        // Decode M dimensions
+        temp = m_flat;
+        for(int i = num_m_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % M_dims[i]) * d_strides[num_g_dims + i];
+            temp /= M_dims[i];
+        }
+
+        // Decode N dimensions
+        temp = n_flat;
+        for(int i = num_n_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % N_dims[i]) * d_strides[num_g_dims + num_m_dims + i];
+            temp /= N_dims[i];
+        }
+
+        return offset;
+    };
+
+    // Parallel computation over G and M dimensions
     auto f_gm = [&](auto g_flat, auto m_flat) {
         for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
         {
             AccDataType sum = 0;
 
-            // Compute dot product over K dimension
+            // Compute dot product over K dimension using stride-aware indexing
             for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat)
             {
-                auto a_val =
-                    a_full_dims.mData[g_flat * M_total * K_total + m_flat * K_total + k_flat];
-                auto b_val =
-                    b_full_dims.mData[g_flat * N_total * K_total + n_flat * K_total + k_flat];
+                const std::size_t a_offset = compute_a_offset(g_flat, m_flat, k_flat);
+                const std::size_t b_offset = compute_b_offset(g_flat, n_flat, k_flat);
+
+                auto a_val = a_full_dims.mData[a_offset];
+                auto b_val = b_full_dims.mData[b_offset];
 
                 sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
             }
 
-            // Apply elementwise operation with D tensors
-            EDataType result = static_cast<EDataType>(sum);
-            if(ds_full_dims_host.size() == 0)
+            // Compute output offset using strides
+            const std::size_t e_offset = compute_e_offset(g_flat, m_flat, n_flat);
+
+            // Compute individual D tensor offsets using their respective strides
+            std::array<std::size_t, NumDTensor> d_offsets;
+            for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
             {
-                ;
-            }
-            else if(ds_full_dims_host.size() == 1)
-            {
-                cde_elementwise(result,
-                                ck_tile::type_convert<EDataType>(sum),
-                                ck_tile::type_convert<EDataType>(
-                                    ds_full_dims_host[0].mData[g_flat * M_total * N_total +
-                                                               m_flat * N_total + n_flat]));
-            }
-            else if(ds_full_dims_host.size() == 2)
-            {
-                cde_elementwise(
-                    result,
-                    ck_tile::type_convert<EDataType>(sum),
-                    ck_tile::type_convert<EDataType>(
-                        ds_full_dims_host[0]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
-                    ck_tile::type_convert<EDataType>(
-                        ds_full_dims_host[1]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
-            }
-            else if(ds_full_dims_host.size() == 3)
-            {
-                cde_elementwise(
-                    result,
-                    ck_tile::type_convert<EDataType>(sum),
-                    ck_tile::type_convert<EDataType>(
-                        ds_full_dims_host[0]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
-                    ck_tile::type_convert<EDataType>(
-                        ds_full_dims_host[1]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
-                    ck_tile::type_convert<EDataType>(
-                        ds_full_dims_host[2]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
-            }
-            else if(ds_full_dims_host.size() == 4)
-            {
-                cde_elementwise(
-                    result,
-                    ck_tile::type_convert<EDataType>(sum),
-                    ck_tile::type_convert<EDataType>(
-                        ds_full_dims_host[0]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
-                    ck_tile::type_convert<EDataType>(
-                        ds_full_dims_host[1]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
-                    ck_tile::type_convert<EDataType>(
-                        ds_full_dims_host[2]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
-                    ck_tile::type_convert<EDataType>(
-                        ds_full_dims_host[3]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
-            }
-            else
-            {
-                throw std::runtime_error("Unsupported NumDTensor for reference calculation");
+                d_offsets[d] = compute_d_offset(g_flat, m_flat, n_flat, d);
             }
 
-            // Store result
-            e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
-                static_cast<EDataType>(result);
+            // Apply elementwise operation with D tensors using compile-time dispatch
+            EDataType result = static_cast<EDataType>(sum);
+            ExtractDValues<EDataType, AccDataType, DDataType, NumDTensor>::apply_at_offsets(
+                result, sum, cde_elementwise, ds_full_dims_host, d_offsets);
+
+            // Store result using stride-aware indexing
+            e_full_dims_host_ref.mData[e_offset] = static_cast<EDataType>(result);
         }
     };
 
@@ -125,147 +272,4 @@ void calculate_reference_flat_indexing(
     make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency());
 }
 
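`make_ParallelTensorFunctor` is the existing ck_tile host-side helper used here to spread the (G, M) index space over `std::thread::hardware_concurrency()` workers, while the N and K loops stay serial inside each task. A rough standalone sketch of that parallel-for shape — the actual ck_tile partitioning may differ:

```cpp
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Run f(g, m) for all g in [0, G), m in [0, M), striped across num_threads.
void parallel_for_2d(const std::function<void(std::size_t, std::size_t)>& f,
                     std::size_t G, std::size_t M, unsigned num_threads)
{
    std::vector<std::thread> workers;
    for(unsigned t = 0; t < num_threads; ++t)
    {
        workers.emplace_back([=] {
            // Flatten (g, m) and stripe the flat indices over the threads.
            for(std::size_t flat = t; flat < G * M; flat += num_threads)
                f(flat / M, flat % M);
        });
    }
    for(auto& w : workers)
        w.join();
}

int main()
{
    std::vector<int> hits(6, 0);
    parallel_for_2d([&](std::size_t g, std::size_t m) { hits[g * 3 + m] = 1; }, 2, 3, 4);
    for(int h : hits)
        if(h != 1) return 1;
    return 0;
}
```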
-template <typename ADataType,
-          typename BDataType,
-          typename AccDataType,
-          typename EDataType,
-          typename DDataType,
-          typename CDEElementWise>
-void calculate_reference_multi_dimensional(
-    const HostTensor<ADataType>& a_full_dims,
-    const HostTensor<BDataType>& b_full_dims,
-    const std::vector<HostTensor<DDataType>>& ds_full_dims_host,
-    HostTensor<EDataType>& e_full_dims_host_ref,
-    const std::vector<ck_tile::index_t>& G_dims,
-    const std::vector<ck_tile::index_t>& M_dims,
-    const std::vector<ck_tile::index_t>& N_dims,
-    const std::vector<ck_tile::index_t>& K_dims,
-    const std::vector<ck_tile::index_t>& A_dims,
-    const std::vector<ck_tile::index_t>& B_dims,
-    const std::vector<ck_tile::index_t>& E_dims,
-    const CDEElementWise& cde_elementwise)
-{
-    std::cout << "Calculating reference using multi-dimensional indexing..." << std::endl;
-
-    std::vector<ck_tile::index_t> g_idx(G_dims.size());
-    std::vector<ck_tile::index_t> m_idx(M_dims.size());
-    std::vector<ck_tile::index_t> n_idx(N_dims.size());
-    std::vector<ck_tile::index_t> k_idx(K_dims.size());
-    std::vector<ck_tile::index_t> a_idx, b_idx, e_idx;
-
-    a_idx.reserve(A_dims.size());
-    b_idx.reserve(B_dims.size());
-    e_idx.reserve(E_dims.size());
-
-    auto calculate_total_elements = [](const std::vector<ck_tile::index_t>& dims) {
-        return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<ck_tile::index_t>());
-    };
-
-    for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
-    {
-        ck_tile::index_t temp = g_flat;
-        for(int i = G_dims.size() - 1; i >= 0; --i)
-        {
-            g_idx[i] = temp % G_dims[i];
-            temp /= G_dims[i];
-        }
-
-        for(ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat)
-        {
-            temp = m_flat;
-            for(int i = M_dims.size() - 1; i >= 0; --i)
-            {
-                m_idx[i] = temp % M_dims[i];
-                temp /= M_dims[i];
-            }
-
-            for(ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat)
-            {
-                temp = n_flat;
-                for(int i = N_dims.size() - 1; i >= 0; --i)
-                {
-                    n_idx[i] = temp % N_dims[i];
-                    temp /= N_dims[i];
-                }
-
-                AccDataType sum = 0;
-
-                for(ck_tile::index_t k_flat = 0; k_flat < calculate_total_elements(K_dims);
-                    ++k_flat)
-                {
-                    temp = k_flat;
-                    for(int i = K_dims.size() - 1; i >= 0; --i)
-                    {
-                        k_idx[i] = temp % K_dims[i];
-                        temp /= K_dims[i];
-                    }
-
-                    a_idx.clear();
-                    b_idx.clear();
-
-                    a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end());
-                    a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end());
-                    a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end());
-
-                    b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end());
-                    b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end());
-                    b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end());
-
-                    auto a_val = a_full_dims(a_idx);
-                    auto b_val = b_full_dims(b_idx);
-
-                    sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
-                }
-
-                e_idx.clear();
-                e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end());
-                e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end());
-                e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());
-
-                EDataType result = static_cast<EDataType>(sum);
-                if(ds_full_dims_host.size() == 0)
-                {
-                    ;
-                }
-                else if(ds_full_dims_host.size() == 1)
-                {
-                    cde_elementwise(result,
-                                    ck_tile::type_convert<EDataType>(sum),
-                                    ck_tile::type_convert<EDataType>(ds_full_dims_host[0](e_idx)));
-                }
-                else if(ds_full_dims_host.size() == 2)
-                {
-                    cde_elementwise(result,
-                                    ck_tile::type_convert<EDataType>(sum),
-                                    ck_tile::type_convert<EDataType>(ds_full_dims_host[0](e_idx)),
-                                    ck_tile::type_convert<EDataType>(ds_full_dims_host[1](e_idx)));
-                }
-                else if(ds_full_dims_host.size() == 3)
-                {
-                    cde_elementwise(result,
-                                    ck_tile::type_convert<EDataType>(sum),
-                                    ck_tile::type_convert<EDataType>(ds_full_dims_host[0](e_idx)),
-                                    ck_tile::type_convert<EDataType>(ds_full_dims_host[1](e_idx)),
-                                    ck_tile::type_convert<EDataType>(ds_full_dims_host[2](e_idx)));
-                }
-                else if(ds_full_dims_host.size() == 4)
-                {
-                    cde_elementwise(result,
-                                    ck_tile::type_convert<EDataType>(sum),
-                                    ck_tile::type_convert<EDataType>(ds_full_dims_host[0](e_idx)),
-                                    ck_tile::type_convert<EDataType>(ds_full_dims_host[1](e_idx)),
-                                    ck_tile::type_convert<EDataType>(ds_full_dims_host[2](e_idx)),
-                                    ck_tile::type_convert<EDataType>(ds_full_dims_host[3](e_idx)));
-                }
-                else
-                {
-                    throw std::runtime_error("Unsupported NumDTensor for reference calculation");
-                }
-
-                e_full_dims_host_ref(e_idx) = static_cast<EDataType>(result);
-            }
-        }
-    }
-}
-
 } // namespace ck_tile
diff --git a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
index 50d9af113f..968d5d6ac2 100644
--- a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
+++ b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp
@@ -5,6 +5,7 @@
 
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
+#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp"
 #include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
 
 /**
@@ -60,23 +61,19 @@
 * Rather than implementing tensor contraction from scratch, this kernel leverages the highly
 * optimized `UniversalGemmKernel` as its computational backend.
 *
- * @subsection current_limitations Current Kernel Limitations
+ * @subsection implementation_features Implementation Features
 *
- * **Layout Restrictions:**
- * - **Row-Major Only**: All tensors must use row-major memory layout
- * - **Packed Tensors**: Only contiguous/packed tensor layouts supported
- * - **Hardcoded Strides**: stride_A = K_total, stride_B = K_total, stride_E = N_total
- * - **D Tensor Layout**: All D tensors must match E tensor layout (stride_Ds = N_total)
+ * **Stride Support:**
+ * - Supports arbitrary multi-dimensional stride patterns
+ * - Handles non-contiguous and padded tensor layouts
+ * - Independent strides for each auxiliary D tensor
+ * - Descriptor-based architecture with vectorization
 *
- * **Implementation Constraints:**
- * - **Fixed Stride Calculation**: Strides are automatically calculated and cannot be customized
- * - **No Column-Major**: Column-major or custom stride patterns not supported
- * - **No Strided Access**: Non-contiguous tensor slicing not supported
- *
- * **Future Enhancements:**
- * - Support for arbitrary stride patterns
- * - Column-major and mixed layout support
- * - Non-contiguous tensor operation support
+ * **Architecture:**
+ * - Uses TensorDescriptorUtils for stride-aware descriptor creation
+ * - Custom RunGemm implementation with descriptor-based tensor views
+ * - Reuses GemmPipeline and EpiloguePipeline for computation
+ * - Split-K support via UniversalGemmKernel utilities
 */
 
 namespace ck_tile {
 
@@ -184,7 +181,10 @@ template <ck_tile::index_t NumDimG,
           ck_tile::index_t NumDimM,
           ck_tile::index_t NumDimN,
           ck_tile::index_t NumDimK,
-          ck_tile::index_t NumDTensor = 0>
+          ck_tile::index_t NumDTensor = 0,
+          ck_tile::index_t VectorSizeA = 1,
+          ck_tile::index_t VectorSizeB = 1,
+          ck_tile::index_t VectorSizeE = 1>
 struct BatchedContractionKernelArgs
 {
     const void* a_ptr; ///< Pointer to input tensor A
@@ -210,11 +210,46 @@ struct BatchedContractionKernelArgs
     ck_tile::index_t N_total; ///< Total N dimension: N0 * N1 * ... * N_{NumDimN-1}
     ck_tile::index_t K_total; ///< Total K dimension: K0 * K1 * ... * K_{NumDimK-1}
 
-    ck_tile::index_t stride_A; ///< Leading dimension stride for tensor A (row-major: K_total)
-    ck_tile::index_t stride_B; ///< Leading dimension stride for tensor B (row-major: K_total)
+    ck_tile::index_t
+        stride_A; ///< Leading dimension stride for tensor A (for backward compatibility)
+    ck_tile::index_t
+        stride_B; ///< Leading dimension stride for tensor B (for backward compatibility)
     std::array<ck_tile::index_t, NumDTensor>
-        stride_Ds; ///< Leading dimension strides for D tensors (row-major: N_total)
-    ck_tile::index_t stride_E; ///< Leading dimension stride for tensor E (row-major: N_total)
+        stride_Ds; ///< Leading dimension strides for D tensors (for backward compatibility)
+    ck_tile::index_t
+        stride_E; ///< Leading dimension stride for tensor E (for backward compatibility)
+
+    // Tensor descriptors (encode full multi-dimensional stride information with vectorization)
+    using AGridDesc_M_K_ = decltype(
+        TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK, VectorSizeA, VectorSizeB, VectorSizeE>::
+            Make_A_GridDescriptor_M_K({}, {}));
+    using BGridDesc_N_K_ = decltype(
+        TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK, VectorSizeA, VectorSizeB, VectorSizeE>::
+            Make_B_GridDescriptor_N_K({}, {}));
+    using EGridDesc_M_N_ = decltype(
+        TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK, VectorSizeA, VectorSizeB, VectorSizeE>::
+            Make_E_GridDescriptor_M_N({}, {}));
+
+    AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides
+    BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides
+    EGridDesc_M_N_ e_grid_desc_m_n; ///< Tensor descriptor for E[M, N] with actual strides
+    std::array<EGridDesc_M_N_, NumDTensor>
+        ds_grid_desc_m_n; ///< Descriptors for D tensors (same shape as E, independent strides)
 };
 
 /// @brief GPU kernel for batched tensor contraction operations.
@@ -274,10 +309,24 @@ struct BatchedContractionKernel
     static constexpr ck_tile::index_t kBlockSize =
         UniversalGemmKernel::kBlockSize; ///< GPU block size inherited from GEMM kernel
 
-    using KernelArgs =
-        BatchedContractionKernelArgs<NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor>; ///< Kernel
-                                                                                      ///< argument
-                                                                                      ///< structure
+    // Tensor descriptor utilities with vectorization support
+    using DescriptorUtils = TensorDescriptorUtils<NumDimG,
+                                                  NumDimM,
+                                                  NumDimN,
+                                                  NumDimK,
+                                                  VectorSizeA,
+                                                  VectorSizeB,
+                                                  VectorSizeE>;
+
+    // Kernel arguments with vectorization support
+    using KernelArgs = BatchedContractionKernelArgs<NumDimG,
+                                                    NumDimM,
+                                                    NumDimN,
+                                                    NumDimK,
+                                                    NumDTensor,
+                                                    VectorSizeA,
+                                                    VectorSizeB,
+                                                    VectorSizeE>;
 
     /// @brief Returns the kernel name for debugging and profiling purposes.
     /// @return Constant string identifier for this kernel
@@ -326,6 +375,104 @@ struct BatchedContractionKernel
             TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
     }
 
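The `decltype(...Make_A_GridDescriptor_M_K({}, {}))` aliases name the factory's heavily templated return type without spelling it out; the descriptor objects themselves are built on the host in `MakeKernelArgs` and travel to the GPU by value inside the kernel-argument struct. The pattern in miniature, with plain types standing in for the ck_tile descriptors:

```cpp
#include <array>
#include <cstddef>

template <std::size_t N>
struct Descriptor
{
    std::array<std::size_t, N> lengths{};
    std::array<std::size_t, N> strides{};
};

// Factory whose return type we do not want to spell out at the use site.
template <std::size_t N>
constexpr Descriptor<N> make_descriptor(std::array<std::size_t, N> lengths,
                                        std::array<std::size_t, N> strides)
{
    return {lengths, strides};
}

struct KernelArgs
{
    // Name the factory's result type via decltype with value-initialized ({})
    // arguments, exactly as the KernelArgs above does for its grid descriptors.
    using ADesc = decltype(make_descriptor<2>({}, {}));
    ADesc a_desc; // trivially copyable, so it can be passed to the GPU by value
};

int main()
{
    KernelArgs kargs{};
    kargs.a_desc = make_descriptor<2>({{4, 8}}, {{8, 1}});
    return kargs.a_desc.strides[0] == 8 ? 0 : 1;
}
```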
+    /// @brief Executes GEMM computation with descriptor-based tensor views for arbitrary stride
+    ///        support
+    ///
+    /// @details This function performs the core GEMM computation using tensor descriptors to
+    ///          handle arbitrary multi-dimensional stride patterns. It creates tensor views from
+    ///          pre-computed descriptors (stored in kargs), applies padding, creates tile windows,
+    ///          and executes the GemmPipeline and EpiloguePipeline.
+    ///
+    /// @param a_ptr    Pointer to input tensor A data (after batch and split-K offsets applied)
+    /// @param b_ptr    Pointer to input tensor B data (after batch and split-K offsets applied)
+    /// @param ds_ptr   Array of pointers to auxiliary D tensor data
+    /// @param e_ptr    Pointer to output tensor E data (after batch offset applied)
+    /// @param smem_ptr Pointer to shared memory for tile operations
+    /// @param kargs    Kernel arguments containing tensor descriptors and dimension information
+    /// @param k_size   Size of K dimension for this split (for split-K support)
+    /// @param i_m      Starting M index for this block's tile
+    /// @param i_n      Starting N index for this block's tile
+    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
+                                       const BDataType* b_ptr,
+                                       const std::array<const void*, NumDTensor>& ds_ptr,
+                                       EDataType* e_ptr,
+                                       void* smem_ptr,
+                                       const KernelArgs& kargs,
+                                       const index_t k_size,
+                                       const index_t i_m,
+                                       const index_t i_n)
+    {
+        // Create tensor views from descriptors (supports arbitrary stride patterns)
+        auto a_tensor_view =
+            make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
+        auto b_tensor_view =
+            make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
+        auto e_tensor_view =
+            make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);
+
+        // Pad views for boundary handling and optimization (like UniversalGemmKernel)
+        auto a_pad_view = pad_tensor_view(
+            a_tensor_view,
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
+            sequence<GemmPipeline::kPadM, GemmPipeline::kPadK>{});
+
+        auto b_pad_view = pad_tensor_view(
+            b_tensor_view,
+            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
+            sequence<GemmPipeline::kPadN, GemmPipeline::kPadK>{});
+
+        auto e_pad_view = pad_tensor_view(
+            e_tensor_view,
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+            sequence<GemmPipeline::kPadM, GemmPipeline::kPadN>{});
+
+        // Create tile windows from PADDED views
+        auto a_block_window = make_tile_window(
+            a_pad_view,
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
+            {i_m, 0});
+
+        auto b_block_window = make_tile_window(
+            b_pad_view,
+            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
+            {i_n, 0});
+
+        auto e_block_window = make_tile_window(
+            e_pad_view,
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+            {i_m, i_n});
+
+        // Calculate number of K loops
+        const index_t num_loop =
+            __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(k_size));
+
+        // Run GEMM Pipeline (same as UniversalGemmKernel, but with descriptor-based windows)
+        using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
+        using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
+
+        const auto& c_block_tile = GemmPipeline{}(
+            a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);
+
+        // Create D windows from descriptors (for each D tensor)
+        auto ds_block_windows = generate_tuple(
+            [&](auto i) {
+                using DDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
+                const DDataType* d_ptr = static_cast<const DDataType*>(ds_ptr[i]);
+
+                auto d_tensor_view =
+                    make_tensor_view<address_space_enum::global>(d_ptr, kargs.ds_grid_desc_m_n[i]);
+
+                return make_tile_window(d_tensor_view,
+                                        make_tuple(number<TilePartitioner::MPerBlock>{},
+                                                   number<TilePartitioner::NPerBlock>{}),
+                                        {i_m, i_n});
+            },
+            number<NumDTensor>{});
+
+        // Run Epilogue Pipeline with descriptor-based D windows
+        EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
+    }
 
     CK_TILE_HOST static constexpr KernelArgs
     MakeKernelArgs(const BatchedContractionHostArgs& host_args)
     {
@@ -435,6 +582,22 @@ struct BatchedContractionKernel
             kargs.K_total *= kargs.K_dims[i];
         }
 
+        // Create tensor descriptors on host using actual dims and strides
+        kargs.a_grid_desc_m_k =
+            DescriptorUtils::Make_A_GridDescriptor_M_K(host_args.A_dims, host_args.A_strides);
+        kargs.b_grid_desc_n_k =
+            DescriptorUtils::Make_B_GridDescriptor_N_K(host_args.B_dims, host_args.B_strides);
+        kargs.e_grid_desc_m_n =
+            DescriptorUtils::Make_E_GridDescriptor_M_N(host_args.E_dims, host_args.E_strides);
+
+        // Create D descriptors with their own strides (same shape as E, independent strides)
+        for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
+        {
+            kargs.ds_grid_desc_m_n[d] = DescriptorUtils::Make_E_GridDescriptor_M_N(
+                host_args.Ds_dims[d], host_args.Ds_strides[d]);
+        }
+
+        // Keep simple strides for backward compatibility
         kargs.stride_A = kargs.K_total;
         kargs.stride_B = kargs.K_total;
         kargs.stride_E = kargs.N_total;
@@ -468,8 +631,8 @@ struct BatchedContractionKernel
         const ck_tile::index_t i_n =
             __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
 
-        const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
-        const auto i_splitk     = __builtin_amdgcn_readfirstlane(blockIdx.z);
+        const auto i_batch_flat              = __builtin_amdgcn_readfirstlane(blockIdx.y);
+        [[maybe_unused]] const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
 
         // Calculate batch offsets for each tensor
         const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A;
@@ -487,6 +650,10 @@ struct BatchedContractionKernel
             ds_batch_ptr[i] = static_cast<const DDataType*>(kargs.ds_ptr[i]) + batch_offset_D;
         });
 
+        // Allocate shared memory
+        __shared__ char smem_ptr[GetSmemSize()];
+
+        // Use UniversalGemmKernel's SplitKBatchOffset for split-K calculation
         typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
                                                             {b_ptr},
                                                             ds_batch_ptr,
                                                             e_ptr,
@@ -503,19 +670,19 @@ struct BatchedContractionKernel
         const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
                                                                                   i_splitk);
 
-        const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
-        const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
-        __shared__ char smem_ptr[GetSmemSize()];
+        // Apply K-split offsets and run descriptor-based RunGemm
+        const ADataType* a_ptr_split = a_ptr + splitk_batch_offset.as_k_split_offset[0];
+        const BDataType* b_ptr_split = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
 
-        UniversalGemmKernel::RunGemm({a_ptr_final},
-                                     {b_ptr_final},
-                                     ds_batch_ptr,
-                                     e_ptr,
-                                     smem_ptr,
-                                     gemm_kargs,
-                                     splitk_batch_offset,
-                                     i_m,
-                                     i_n);
+        RunGemm(a_ptr_split,
+                b_ptr_split,
+                ds_batch_ptr,
+                e_ptr,
+                smem_ptr,
+                kargs,
+                splitk_batch_offset.splitted_k,
+                i_m,
+                i_n);
     }
 };
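For orientation: split-K divides the K_total reduction across `blockIdx.z`, so each split consumes a K-slice (`splitted_k`) starting at a per-split offset that is folded into the A and B base pointers; `SplitKBatchOffset` presumably encapsulates exactly this arithmetic. A back-of-the-envelope version under that assumption — the real ck_tile computation also accounts for layout and vectorization:

```cpp
#include <cassert>
#include <cstdint>

struct SplitKSlice
{
    std::int64_t k_begin;    // first K element owned by this split
    std::int64_t splitted_k; // number of K elements in this split
};

// Divide K_total into k_batch nearly equal chunks; the last chunk absorbs the remainder.
SplitKSlice make_splitk_slice(std::int64_t K_total, std::int64_t k_batch, std::int64_t i_splitk)
{
    const std::int64_t k_per_split = (K_total + k_batch - 1) / k_batch;
    const std::int64_t k_begin     = i_splitk * k_per_split;
    const std::int64_t k_end = (k_begin + k_per_split < K_total) ? k_begin + k_per_split : K_total;
    return {k_begin, k_end - k_begin};
}

int main()
{
    // K_total = 100 over 3 splits: 34 + 34 + 32.
    assert(make_splitk_slice(100, 3, 0).splitted_k == 34);
    assert(make_splitk_slice(100, 3, 2).splitted_k == 32);
    // With A[M, K] contiguous in K, split 2 starts at element offset k_begin
    // within each row; the kernel folds this into the A/B base pointers.
    assert(make_splitk_slice(100, 3, 2).k_begin == 68);
    return 0;
}
```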
diff --git a/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp b/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp
index e712bf03e4..4767a430ac 100644
--- a/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp
+++ b/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp
@@ -13,8 +13,8 @@
 * dimensions for GEMM operations. These functions transform multi-dimensional tensors into
 * 2D matrix descriptors by removing batch dimensions and flattening the remaining dimensions.
 *
- * These utilities are currently not used in the main batched contraction kernel but are preserved
- * for future implementations that may require explicit tensor descriptor creation.
+ * These utilities are used by BatchedContractionKernel to create stride-aware descriptors
+ * that support arbitrary multi-dimensional non-contiguous tensor layouts.
 */
 
 namespace ck_tile {
@@ -30,7 +30,10 @@ namespace ck_tile {
 template <ck_tile::index_t NumDimG,
           ck_tile::index_t NumDimM,
           ck_tile::index_t NumDimN,
-          ck_tile::index_t NumDimK>
+          ck_tile::index_t NumDimK,
+          ck_tile::index_t VectorSizeA,
+          ck_tile::index_t VectorSizeB,
+          ck_tile::index_t VectorSizeE>
 struct TensorDescriptorUtils
 {
     /// @brief Creates a tensor descriptor for input tensor A with batch dimensions removed.
@@ -62,9 +65,9 @@ struct TensorDescriptorUtils
         const auto dims_M = get_container_subset(A_dims_M_K, A_dims_M_ids);
         const auto dims_K = get_container_subset(A_dims_M_K, A_dims_K_ids);
 
-        // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Discriptor
-        const auto A_grid_desc_Ms_Ks =
-            ck_tile::make_naive_tensor_descriptor(A_dims_M_K, A_strides_M_K);
+        // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Descriptor with vector size
+        const auto A_grid_desc_Ms_Ks = ck_tile::make_naive_tensor_descriptor(
+            A_dims_M_K, A_strides_M_K, number<VectorSizeA>{}, number<1>{});
 
         // transformed tensor to flatten M and K dimensions [M_total = M0 * M1 * M2 * ... , K_total
         // = K0 * K1 * K2 * ...]
@@ -106,9 +109,9 @@ struct TensorDescriptorUtils
         const auto dims_N = get_container_subset(B_dims_N_K, B_dims_N_ids);
         const auto dims_K = get_container_subset(B_dims_N_K, B_dims_K_ids);
 
-        // naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Discriptor
-        const auto B_grid_desc_Ns_Ks =
-            ck_tile::make_naive_tensor_descriptor(B_dims_N_K, B_strides_N_K);
+        // naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Descriptor with vector size
+        const auto B_grid_desc_Ns_Ks = ck_tile::make_naive_tensor_descriptor(
+            B_dims_N_K, B_strides_N_K, number<VectorSizeB>{}, number<1>{});
 
         // transformed tensor to flatten N and K dimensions [N_total = N0 * N1 * N2 * ... , K_total
         // = K0 * K1 * K2 * ...]
@@ -150,9 +153,9 @@ struct TensorDescriptorUtils
         const auto dims_M = get_container_subset(E_dims_M_N, E_dims_M_ids);
         const auto dims_N = get_container_subset(E_dims_M_N, E_dims_N_ids);
 
-        // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Discriptor
-        const auto E_grid_desc_Ms_Ns =
-            ck_tile::make_naive_tensor_descriptor(E_dims_M_N, E_strides_M_N);
+        // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Descriptor with vector size
+        const auto E_grid_desc_Ms_Ns = ck_tile::make_naive_tensor_descriptor(
+            E_dims_M_N, E_strides_M_N, number<VectorSizeE>{}, number<1>{});
 
         // transformed tensor to flatten M and N dimensions [M_total = M0 * M1 * M2 * ... ,
         // N_total = N0 * N1 * N2 * ...]
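End to end, the utilities receive dims and strides ordered [G..., M..., K...] (or [G..., N..., K...] / [G..., M..., N...]), drop the batch dimensions — batching is handled separately through `batch_stride` offsets — and merge the remaining index groups into a 2D view. The host-side slicing arithmetic in plain containers, assuming each dimension group is internally contiguous (the descriptor transforms above do not need that assumption):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct MatrixView2D
{
    std::size_t rows, cols; // flattened M_total (or N_total) and K_total
    std::size_t row_stride; // stride of the slowest non-batch dimension
};

// Given A dims/strides ordered [G..., M..., K...], drop the G dims and flatten,
// assuming the M and K groups are each contiguous among themselves.
MatrixView2D flatten_a_view(const std::vector<std::size_t>& dims,
                            const std::vector<std::size_t>& strides,
                            std::size_t num_g, std::size_t num_m, std::size_t num_k)
{
    std::size_t m_total = 1, k_total = 1;
    for(std::size_t i = 0; i < num_m; ++i)
        m_total *= dims[num_g + i];
    for(std::size_t i = 0; i < num_k; ++i)
        k_total *= dims[num_g + num_m + i];
    // Leading stride of the flattened row = stride of the last M dimension.
    return {m_total, k_total, strides[num_g + num_m - 1]};
}

int main()
{
    // A[G=12][M0=8][M1=32][K=64], contiguous: strides {16384, 2048, 64, 1}.
    auto v = flatten_a_view({12, 8, 32, 64}, {16384, 2048, 64, 1}, 1, 2, 1);
    assert(v.rows == 256 && v.cols == 64 && v.row_stride == 64);
    return 0;
}
```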