mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK Tile] contraction multi d - kernel & example (#2901)
* Initial commit. create batched_contraction_kernel file
* initial problem definition
* implement initial example to launch kernel
* add universal gemm to contraction. initial phase
* complete implementation for special case all Dims are 1 and no Ds
* clean code
* initial changes to support multi dimensional G
* more progress in implementing multiple G
* tmp commit
* manage dynamic NumDimG in kernel
* improving example for multi M,N,K,G handling. start generalizing kernel. it is a temporary commit
* implement the example for general Multi dimension G M N K and test different reference calculation algorithms
* 2 functions for reference using multi dimensional and flat indexing
* clean the code for muti dimentional G, M, N, K contraction and add some logs
* Add Make descriptor function in kernel for merging Ms, Ns, Ks for A, B, E
* some cleaning on kernel
* clean the code for calculating the offsets from flatten batch number
* Start adding MultiD support to kernel and example
* more changes to manage multi D in kernel and example
* manage passing multi d to kernel and testing.
* complete multi D support in kernel. modify example code to support it
* Correct algorithm to calc the correct offset values for D tensor batches and some code cleaning
* Minor fix
* Generalize example code for variable NumD tensors and apply cleanup based on review feedback
* Refactored code and addressed review feedback
* refactoring, cleaning, add documents, in kernel side and example codes
* Optimize batch offset calculation in kernel
* Inline CalculateBatchOffset in batched contraction kernel, update CHANGELOG.md
---------
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: e9f0cc83a8]
This commit is contained in:
@@ -36,6 +36,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM.
|
||||
* Added support for f32 to FMHA (fwd/bwd).
|
||||
* Added tensor-wise quantization for CK_TILE GEMM.
|
||||
* Added support for batched contraction kernel.
|
||||
* Added pooling kernel in CK_TILE
|
||||
|
||||
### Optimized
|
||||
|
||||
7
example/ck_tile/41_batched_contraction/CMakeLists.txt
Normal file
7
example/ck_tile/41_batched_contraction/CMakeLists.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
add_executable(tile_example_batched_contraction EXCLUDE_FROM_ALL batched_contraction.cpp)
|
||||
set(EXAMPLE_CONTRACTION_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_CONTRACTION_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
target_compile_options(tile_example_batched_contraction PRIVATE ${EXAMPLE_CONTRACTION_COMPILE_OPTIONS})
|
||||
245
example/ck_tile/41_batched_contraction/batched_contraction.cpp
Normal file
245
example/ck_tile/41_batched_contraction/batched_contraction.cpp
Normal file
@@ -0,0 +1,245 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
#include "ck_tile/ops/batched_contraction.hpp"
|
||||
#include "contraction_utils.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ck_tile::index_t NumDimG,
|
||||
ck_tile::index_t NumDimM,
|
||||
ck_tile::index_t NumDimN,
|
||||
ck_tile::index_t NumDimK,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
|
||||
float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs<DsDataType::size()>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, ELayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
TransposeC>;
|
||||
|
||||
using Problem = ck_tile::BatchedContractionProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
NumDimG, // NumDimG
|
||||
NumDimM, // NumDimM
|
||||
NumDimN, // NumDimN
|
||||
NumDimK, // NumDimK
|
||||
DsDataType::size() // NumDTensor
|
||||
>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
|
||||
|
||||
ck_tile::index_t K_total = 1;
|
||||
for(ck_tile::index_t i = NumDimG + NumDimM; i < NumDimG + NumDimM + NumDimK; ++i)
|
||||
{
|
||||
K_total *= args.A_dims[i];
|
||||
}
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_total);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation =
|
||||
ck_tile::memory_operation_enum::set; // Always set (no atomic_add)
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel =
|
||||
ck_tile::BatchedContractionKernel<Problem, TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::GetBlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArguments(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
auto kernel = ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs);
|
||||
|
||||
ave_time = ck_tile::launch_kernel(s, kernel);
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#define HANDLE_CASE(G, M, N, K) \
|
||||
if(num_g_dims == G && num_m_dims == M && num_n_dims == N && num_k_dims == K) \
|
||||
{ \
|
||||
return batched_contraction_impl<ADataType, \
|
||||
BDataType, \
|
||||
DsDataType, \
|
||||
AccDataType, \
|
||||
EDataType, \
|
||||
ALayout, \
|
||||
BLayout, \
|
||||
DsLayout, \
|
||||
ELayout, \
|
||||
G, \
|
||||
M, \
|
||||
N, \
|
||||
K, \
|
||||
CDEElementWise>(args, s); \
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float batched_contraction(const ck_tile::BatchedContractionHostArgs<DsDataType::size()>& args,
|
||||
const ck_tile::stream_config& s,
|
||||
ck_tile::index_t num_g_dims,
|
||||
ck_tile::index_t num_m_dims,
|
||||
ck_tile::index_t num_n_dims,
|
||||
ck_tile::index_t num_k_dims)
|
||||
{
|
||||
std::cout << "Dimensions: G=" << num_g_dims << ", M=" << num_m_dims << ", N=" << num_n_dims
|
||||
<< ", K=" << num_k_dims << std::endl;
|
||||
|
||||
HANDLE_CASE(1, 1, 1, 1);
|
||||
HANDLE_CASE(2, 1, 1, 1);
|
||||
HANDLE_CASE(2, 2, 2, 1);
|
||||
HANDLE_CASE(1, 2, 1, 1);
|
||||
HANDLE_CASE(1, 1, 1, 2);
|
||||
HANDLE_CASE(2, 2, 2, 2);
|
||||
HANDLE_CASE(4, 4, 4, 4);
|
||||
|
||||
throw std::runtime_error(
|
||||
"Unsupported dimension combination: G=" + std::to_string(num_g_dims) +
|
||||
", M=" + std::to_string(num_m_dims) + ", N=" + std::to_string(num_n_dims) +
|
||||
", K=" + std::to_string(num_k_dims) + ". Please add this combination to the kernel.");
|
||||
}
|
||||
|
||||
#include "run_batched_contraction_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_batched_contraction_example(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
146
example/ck_tile/41_batched_contraction/contraction_utils.hpp
Normal file
146
example/ck_tile/41_batched_contraction/contraction_utils.hpp
Normal file
@@ -0,0 +1,146 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
|
||||
struct AddDs
|
||||
{
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
const float x0_f =
|
||||
ck_tile::type_convert<float>(c) + (ck_tile::type_convert<float>(ds) + ...);
|
||||
|
||||
e = ck_tile::type_convert<E>(x0_f);
|
||||
}
|
||||
};
|
||||
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
|
||||
template <typename DataType>
|
||||
struct BatchedContractionTypeConfig
|
||||
{
|
||||
using ADataType = DataType;
|
||||
using BDataType = DataType;
|
||||
using AccDataType = float;
|
||||
using EDataType = DataType;
|
||||
using DDataType = DataType;
|
||||
};
|
||||
|
||||
using ContractionTypes = BatchedContractionTypeConfig<ck_tile::half_t>;
|
||||
|
||||
using ADataType = ContractionTypes::ADataType;
|
||||
using BDataType = ContractionTypes::BDataType;
|
||||
using AccDataType = ContractionTypes::AccDataType;
|
||||
using EDataType = ContractionTypes::EDataType;
|
||||
using DDataType = ContractionTypes::DDataType;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m_dims", "4,256", "M dimensions separated by comma (e.g., '16,32' for 2D M)")
|
||||
.insert("n_dims", "16,128", "N dimensions separated by comma (e.g., '32,32' for 2D N)")
|
||||
.insert("k_dims", "64", "K dimensions separated by comma (e.g., '64,32' for 2D K)")
|
||||
.insert(
|
||||
"g_dims", "1,2", "G dimensions separated by comma (e.g., '4,2' for 2D, '2,3,4' for 3D)")
|
||||
.insert("stride_a", "0", "Custom A tensor leading dimension stride (0 = auto)")
|
||||
.insert("stride_b", "0", "Custom B tensor leading dimension stride (0 = auto)")
|
||||
.insert("stride_e", "0", "Custom E tensor leading dimension stride (0 = auto)")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Col by default")
|
||||
.insert("e_layout", "R", "E tensor data layout - Row by default")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU")
|
||||
.insert("prec", "fp16", "data type. fp32/fp16/bf16")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("log", "1", "log level for debugging");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// Helper function to parse G, M, N, K dimensions from string
|
||||
std::vector<ck_tile::index_t> parse_dimensions(const std::string& dims_str)
|
||||
{
|
||||
std::vector<ck_tile::index_t> dims;
|
||||
std::stringstream ss(dims_str);
|
||||
std::string token;
|
||||
|
||||
while(std::getline(ss, token, ','))
|
||||
{
|
||||
dims.push_back(std::stoi(token));
|
||||
}
|
||||
|
||||
if(dims.empty())
|
||||
{
|
||||
throw std::invalid_argument("Dimensions cannot be empty");
|
||||
}
|
||||
|
||||
return dims;
|
||||
}
|
||||
|
||||
// Helper function to Calculate total elements from multi-dimensional vector
|
||||
ck_tile::index_t calculate_total_elements(const std::vector<ck_tile::index_t>& dims)
|
||||
{
|
||||
ck_tile::index_t total = 1;
|
||||
for(auto dim : dims)
|
||||
{
|
||||
total *= dim;
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Flattens a list of tensor dimension components into a single dimension vector.
|
||||
*
|
||||
* This function takes a list of dimension vectors (e.g., representing different components
|
||||
* such as G, M, N, or K dimensions) and concatenates them into a single vector.
|
||||
*
|
||||
* Example:
|
||||
* Input: {{G0, G1}, {M0, M1}, {K0}}
|
||||
* Output: {G0, G1, M0, M1, K0}
|
||||
*
|
||||
* @param dim_components A vector of vectors, where each inner vector represents a set of tensor
|
||||
* dimensions.
|
||||
* @return A single vector containing all dimensions concatenated in order.
|
||||
*/
|
||||
std::vector<ck_tile::index_t>
|
||||
concatenate_dim_components(const std::vector<std::vector<ck_tile::index_t>>& dim_components)
|
||||
{
|
||||
std::vector<ck_tile::index_t> result;
|
||||
|
||||
// Concatenate all dimension components into a single vector
|
||||
for(const auto& component : dim_components)
|
||||
{
|
||||
result.insert(result.end(), component.begin(), component.end());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Helper function for printing dimensions
|
||||
void print_dims(const std::string& name,
|
||||
const std::vector<ck_tile::index_t>& dims,
|
||||
ck_tile::index_t total)
|
||||
{
|
||||
std::cout << name << ": [";
|
||||
for(size_t i = 0; i < dims.size(); ++i)
|
||||
{
|
||||
std::cout << dims[i];
|
||||
if(i < dims.size() - 1)
|
||||
std::cout << ",";
|
||||
}
|
||||
std::cout << "] ";
|
||||
if(total != 0)
|
||||
std::cout << "(total=" << total << ")";
|
||||
std::cout << std::endl;
|
||||
}
|
||||
@@ -0,0 +1,405 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include "contraction_utils.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_contraction.hpp"
|
||||
|
||||
template <typename ADataType, typename BDataType, typename EDataType, typename AccDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_batched_contraction_kernel(
|
||||
const void* a_full_dims_dev_buf,
|
||||
const void* b_full_dims_dev_buf,
|
||||
const std::array<const void*, DsDataType::size()>& ds_dev_buf,
|
||||
void* e_full_dims_dev_buf,
|
||||
const std::vector<ck_tile::index_t>& G_dims,
|
||||
const std::vector<ck_tile::index_t>& M_dims,
|
||||
const std::vector<ck_tile::index_t>& N_dims,
|
||||
const std::vector<ck_tile::index_t>& K_dims,
|
||||
const std::vector<ck_tile::index_t>& A_dims, // [G0,G1,..,M0,M1,..,K0,K1,..]
|
||||
const std::vector<ck_tile::index_t>& B_dims, // [G0,G1,..,N0,N1,..,K0,K1,..]
|
||||
const std::array<std::vector<ck_tile::index_t>, DsDataType::size()>&
|
||||
Ds_dims, // [G0, G1, ..., M0, M1, ... , N0, N1, ...][NumDTensor]
|
||||
const std::vector<ck_tile::index_t>& E_dims, // [G0,G1,..,M0,M1,..,N0,N1,..]
|
||||
const std::vector<ck_tile::index_t>& A_strides, // [G0,G1,..,M0,M1,..,K0,K1,..]
|
||||
const std::vector<ck_tile::index_t>& B_strides, // [G0,G1,..,N0,N1,..,K0,K1,..]
|
||||
const std::array<std::vector<ck_tile::index_t>, DsDataType::size()>& Ds_strides,
|
||||
const std::vector<ck_tile::index_t>& E_strides, // [G0,G1,..,M0,M1,..,N0,N1,..]
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
std::cout << "Creating BatchedContractionHostArgs..." << std::endl;
|
||||
|
||||
ck_tile::BatchedContractionHostArgs<DsDataType::size()> args(a_full_dims_dev_buf, // a_ptr
|
||||
b_full_dims_dev_buf, // b_ptr
|
||||
ds_dev_buf, // ds_ptr
|
||||
e_full_dims_dev_buf, // e_ptr
|
||||
kbatch, // k_batch
|
||||
A_dims, // A_dims
|
||||
B_dims, // B_dims
|
||||
Ds_dims, // Ds_dims
|
||||
E_dims, // E_dims
|
||||
A_strides, // A_strides
|
||||
B_strides, // B_strides
|
||||
Ds_strides, // Ds_strides
|
||||
E_strides // E_strides
|
||||
);
|
||||
|
||||
std::cout << "Calling batched_contraction with dimensions: G=" << G_dims.size()
|
||||
<< ", M=" << M_dims.size() << ", N=" << N_dims.size() << ", K=" << K_dims.size()
|
||||
<< std::endl;
|
||||
|
||||
float ave_time = batched_contraction<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
|
||||
G_dims.size(), // num_g_dims
|
||||
M_dims.size(), // num_m_dims
|
||||
N_dims.size(), // num_n_dims
|
||||
K_dims.size() // num_k_dims
|
||||
);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename DLayout, typename ELayout>
|
||||
int run_batched_contraction_example_with_layouts(
|
||||
int argc,
|
||||
char* argv[],
|
||||
[[maybe_unused]] const ALayout a_layout = ALayout{},
|
||||
[[maybe_unused]] const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const DLayout d_layout = DLayout{},
|
||||
[[maybe_unused]] const ELayout e_layout = ELayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::vector<ck_tile::index_t> G_dims = parse_dimensions(arg_parser.get_str("g_dims"));
|
||||
std::vector<ck_tile::index_t> M_dims = parse_dimensions(arg_parser.get_str("m_dims"));
|
||||
std::vector<ck_tile::index_t> N_dims = parse_dimensions(arg_parser.get_str("n_dims"));
|
||||
std::vector<ck_tile::index_t> K_dims = parse_dimensions(arg_parser.get_str("k_dims"));
|
||||
|
||||
constexpr ck_tile::index_t NumDTensor = 2;
|
||||
|
||||
ck_tile::index_t G_total = calculate_total_elements(G_dims);
|
||||
ck_tile::index_t M_total = calculate_total_elements(M_dims);
|
||||
ck_tile::index_t N_total = calculate_total_elements(N_dims);
|
||||
ck_tile::index_t K_total = calculate_total_elements(K_dims);
|
||||
|
||||
std::vector<ck_tile::index_t> A_dims =
|
||||
concatenate_dim_components({G_dims, M_dims, K_dims}); // [G0,G1,..,M0,M1,..,K0,K1,..]
|
||||
std::vector<ck_tile::index_t> B_dims =
|
||||
concatenate_dim_components({G_dims, N_dims, K_dims}); // [G0,G1,..,N0,N1,..,K0,K1,..]
|
||||
std::vector<ck_tile::index_t> E_dims =
|
||||
concatenate_dim_components({G_dims, M_dims, N_dims}); // [G0,G1,..,M0,M1,..,N0,N1,..]
|
||||
|
||||
std::array<std::vector<ck_tile::index_t>, NumDTensor> Ds_dims;
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
Ds_dims[d] = E_dims;
|
||||
}
|
||||
|
||||
auto convert_strides = [](const std::vector<std::size_t>& strides) {
|
||||
std::vector<ck_tile::index_t> converted(strides.size());
|
||||
std::copy(strides.begin(), strides.end(), converted.begin());
|
||||
return converted;
|
||||
};
|
||||
|
||||
ck_tile::HostTensorDescriptor a_desc(A_dims);
|
||||
ck_tile::HostTensorDescriptor b_desc(B_dims);
|
||||
ck_tile::HostTensorDescriptor e_desc(E_dims);
|
||||
std::array<ck_tile::HostTensorDescriptor, NumDTensor> ds_descs;
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides());
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> A_strides = convert_strides(a_desc.get_strides());
|
||||
std::vector<ck_tile::index_t> B_strides = convert_strides(b_desc.get_strides());
|
||||
std::vector<ck_tile::index_t> E_strides = convert_strides(e_desc.get_strides());
|
||||
|
||||
std::array<std::vector<ck_tile::index_t>, NumDTensor> Ds_strides;
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
Ds_strides[d] = convert_strides(ds_descs[d].get_strides());
|
||||
}
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
|
||||
print_dims("G_dims", G_dims, G_total);
|
||||
print_dims("M_dims", M_dims, M_total);
|
||||
print_dims("N_dims", N_dims, N_total);
|
||||
print_dims("K_dims", K_dims, K_total);
|
||||
|
||||
std::cout << "NumDTensor: " << NumDTensor << std::endl;
|
||||
std::cout << "\n=== Tensor Shapes for Kernel ===" << std::endl;
|
||||
print_dims("A_dims", A_dims, 0);
|
||||
print_dims("B_dims", B_dims, 0);
|
||||
print_dims("E_dims", E_dims, 0);
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
print_dims("Ds[" + std::to_string(d) + "]_dims", Ds_dims[d], 0);
|
||||
}
|
||||
|
||||
std::cout << "\n=== Tensor Strides ===" << std::endl;
|
||||
print_dims("A_strides", A_strides, 0);
|
||||
print_dims("B_strides", B_strides, 0);
|
||||
print_dims("E_strides", E_strides, 0);
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
print_dims("Ds[" + std::to_string(d) + "]_strides", Ds_strides[d], 0);
|
||||
}
|
||||
|
||||
std::cout << "===============================================\n" << std::endl;
|
||||
|
||||
ck_tile::HostTensor<::ADataType> a_full_dims_host(a_desc);
|
||||
ck_tile::HostTensor<::BDataType> b_full_dims_host(b_desc);
|
||||
ck_tile::HostTensor<::EDataType> e_full_dims_host(e_desc);
|
||||
|
||||
std::vector<ck_tile::HostTensor<::DDataType>> ds_full_dims_host;
|
||||
for(int d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
ds_full_dims_host.emplace_back(ck_tile::HostTensor<::DDataType>(ds_descs[d]));
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<::ADataType>{-5.f, 5.f, std::nullopt}(a_full_dims_host);
|
||||
ck_tile::FillUniformDistribution<::BDataType>{-5.f, 5.f, std::nullopt}(b_full_dims_host);
|
||||
|
||||
ck_tile::DeviceMem a_full_dims_dev_buf(a_full_dims_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_full_dims_dev_buf(b_full_dims_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem e_full_dims_dev_buf(e_full_dims_host.get_element_space_size_in_bytes());
|
||||
|
||||
a_full_dims_dev_buf.ToDevice(a_full_dims_host.data());
|
||||
b_full_dims_dev_buf.ToDevice(b_full_dims_host.data());
|
||||
|
||||
for(int d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<::DDataType>{-2.f, 2.f, std::nullopt}(
|
||||
ds_full_dims_host[d]);
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> ds_full_dims_dev_buf;
|
||||
for(int d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
ds_full_dims_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
ds_full_dims_host[d].get_element_space_size_in_bytes()));
|
||||
ds_full_dims_dev_buf[d]->ToDevice(ds_full_dims_host[d].data());
|
||||
}
|
||||
std::array<const void*, NumDTensor> ds_ptr_buf;
|
||||
for(int d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
ds_ptr_buf[d] = ds_full_dims_dev_buf[d]->GetDeviceBuffer();
|
||||
}
|
||||
|
||||
e_full_dims_dev_buf.SetZero();
|
||||
e_full_dims_host.SetZero();
|
||||
|
||||
std::cout << "\n=== Running GPU Kernel ===" << std::endl;
|
||||
|
||||
using DsDataType = ck_tile::tuple_array<::DDataType, NumDTensor>;
|
||||
using DsLayout = ck_tile::tuple_array<DLayout, NumDTensor>;
|
||||
using CDEElementWise =
|
||||
std::conditional_t<NumDTensor == 0, ck_tile::element_wise::PassThrough, AddDs>;
|
||||
|
||||
float ave_time =
|
||||
invoke_batched_contraction_kernel<::ADataType,
|
||||
::BDataType,
|
||||
DsDataType,
|
||||
::AccDataType,
|
||||
::EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise>(a_full_dims_dev_buf.GetDeviceBuffer(),
|
||||
b_full_dims_dev_buf.GetDeviceBuffer(),
|
||||
ds_ptr_buf,
|
||||
e_full_dims_dev_buf.GetDeviceBuffer(),
|
||||
G_dims,
|
||||
M_dims,
|
||||
N_dims,
|
||||
K_dims,
|
||||
A_dims,
|
||||
B_dims,
|
||||
Ds_dims,
|
||||
E_dims,
|
||||
A_strides,
|
||||
B_strides,
|
||||
Ds_strides,
|
||||
E_strides,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
std::string op_name{
|
||||
"Multi-Dimensional Batched Contraction : G: " + std::to_string(G_dims.size()) +
|
||||
"D, M: " + std::to_string(M_dims.size()) + "D, N: " + std::to_string(N_dims.size()) +
|
||||
"D, K: " + std::to_string(K_dims.size()) + "D"};
|
||||
|
||||
std::size_t flop = std::size_t(2) * G_total * M_total * N_total * K_total +
|
||||
NumDTensor * K_total * M_total * N_total; // Number of operations
|
||||
std::size_t num_byte =
|
||||
sizeof(::ADataType) * G_total * M_total * K_total + // A tensor size
|
||||
sizeof(::BDataType) * G_total * N_total * K_total + // B tensor size
|
||||
sizeof(::DDataType) * NumDTensor * G_total * M_total * N_total + // D tensors
|
||||
sizeof(::EDataType) * G_total * M_total * N_total; // E tensor size
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; // TFlops calculation
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time; // GB/s calculation
|
||||
print_dims("G_dims", G_dims, G_total);
|
||||
print_dims("M_dims", M_dims, M_total);
|
||||
print_dims("N_dims", N_dims, N_total);
|
||||
print_dims("K_dims", K_dims, K_total);
|
||||
|
||||
std::cout << " Performance: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s" << std::endl;
|
||||
|
||||
std::cout << "===============================================" << std::endl;
|
||||
|
||||
e_full_dims_dev_buf.FromDevice(e_full_dims_host.data());
|
||||
std::cout << "GPU results retrieved from device." << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
|
||||
std::cout << "Computing CPU reference..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<::EDataType> e_full_dims_host_ref(
|
||||
ck_tile::HostTensorDescriptor(E_dims, E_strides));
|
||||
e_full_dims_host_ref.SetZero();
|
||||
|
||||
auto start_time = std::chrono::high_resolution_clock::now();
|
||||
|
||||
calculate_reference_flat_indexing<ADataType,
|
||||
BDataType,
|
||||
DDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CDEElementWise>(a_full_dims_host,
|
||||
b_full_dims_host,
|
||||
ds_full_dims_host,
|
||||
e_full_dims_host_ref,
|
||||
G_total,
|
||||
M_total,
|
||||
N_total,
|
||||
K_total,
|
||||
CDEElementWise{});
|
||||
|
||||
auto end_time = std::chrono::high_resolution_clock::now();
|
||||
auto duration =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time);
|
||||
|
||||
std::cout << "CPU reference completed in " << duration.count() << "ms" << std::endl;
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(e_full_dims_host_ref.mData.begin(), e_full_dims_host_ref.mData.end());
|
||||
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<::ADataType, ::BDataType, ::EDataType, ::AccDataType>(
|
||||
K_total, kbatch, max_accumulated_value);
|
||||
|
||||
pass = ck_tile::check_err(e_full_dims_host,
|
||||
e_full_dims_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "The CPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
std::cout << "===============================================" << std::endl;
|
||||
|
||||
std::cout << "\n=== Random Samples of Reference and Result ===" << std::endl;
|
||||
|
||||
// Generate 10 random indices
|
||||
std::vector<std::size_t> random_indices;
|
||||
std::size_t total_elements = e_full_dims_host_ref.mData.size();
|
||||
std::mt19937 rng(std::random_device{}());
|
||||
std::uniform_int_distribution<std::size_t> dist(0, total_elements - 1);
|
||||
|
||||
for(int i = 0; i < 10; ++i)
|
||||
{
|
||||
random_indices.push_back(dist(rng));
|
||||
}
|
||||
|
||||
// Print the values at the random indices
|
||||
for(std::size_t idx : random_indices)
|
||||
{
|
||||
std::cout << "Index " << idx << ": "
|
||||
<< "ref=" << static_cast<float>(e_full_dims_host_ref.mData[idx]) << ", "
|
||||
<< "GPU=" << static_cast<float>(e_full_dims_host.mData[idx]) << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "===============================================" << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int run_batched_contraction_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_batched_contraction_example_with_layouts(argc, argv, Row{}, Col{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and E tensors! "
|
||||
"Only R-C-R supported for now.");
|
||||
}
|
||||
}
|
||||
@@ -27,3 +27,4 @@ add_subdirectory(36_pooling)
|
||||
add_subdirectory(38_block_scale_gemm)
|
||||
add_subdirectory(39_copy)
|
||||
add_subdirectory(40_streamk_gemm)
|
||||
add_subdirectory(41_batched_contraction)
|
||||
|
||||
265
include/ck_tile/host/reference/reference_batched_contraction.hpp
Normal file
265
include/ck_tile/host/reference/reference_batched_contraction.hpp
Normal file
@@ -0,0 +1,265 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CDEElementWise>
|
||||
|
||||
void calculate_reference_flat_indexing(
|
||||
const ck_tile::HostTensor<ADataType>& a_full_dims,
|
||||
const ck_tile::HostTensor<BDataType>& b_full_dims,
|
||||
const std::vector<ck_tile::HostTensor<DDataType>>& ds_full_dims_host,
|
||||
ck_tile::HostTensor<EDataType>& e_full_dims_host_ref,
|
||||
ck_tile::index_t G_total,
|
||||
ck_tile::index_t M_total,
|
||||
ck_tile::index_t N_total,
|
||||
ck_tile::index_t K_total,
|
||||
const CDEElementWise& cde_elementwise)
|
||||
{
|
||||
std::cout << "Calculating reference using optimized flat indexing with parallel processing..."
|
||||
<< std::endl;
|
||||
|
||||
// Parallel computation over G and M dimensions using pattern from reference_batched_gemm.hpp
|
||||
auto f_gm = [&](auto g_flat, auto m_flat) {
|
||||
for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
|
||||
{
|
||||
AccDataType sum = 0;
|
||||
|
||||
// Compute dot product over K dimension
|
||||
for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat)
|
||||
{
|
||||
auto a_val =
|
||||
a_full_dims.mData[g_flat * M_total * K_total + m_flat * K_total + k_flat];
|
||||
auto b_val =
|
||||
b_full_dims.mData[g_flat * N_total * K_total + n_flat * K_total + k_flat];
|
||||
sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
|
||||
}
|
||||
|
||||
// Apply elementwise operation with D tensors
|
||||
EDataType result = static_cast<EDataType>(sum);
|
||||
if(ds_full_dims_host.size() == 0)
|
||||
{
|
||||
;
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 1)
|
||||
{
|
||||
cde_elementwise(result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[0].mData[g_flat * M_total * N_total +
|
||||
m_flat * N_total + n_flat]));
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 2)
|
||||
{
|
||||
cde_elementwise(
|
||||
result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[0]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[1]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 3)
|
||||
{
|
||||
cde_elementwise(
|
||||
result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[0]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[1]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[2]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 4)
|
||||
{
|
||||
cde_elementwise(
|
||||
result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[0]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[1]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[2]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
||||
ck_tile::type_convert<float>(
|
||||
ds_full_dims_host[3]
|
||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported NumDTensor for reference calculation");
|
||||
}
|
||||
|
||||
// Store result
|
||||
e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
|
||||
static_cast<EDataType>(result);
|
||||
}
|
||||
};
|
||||
|
||||
// Execute parallel computation using hardware concurrency
|
||||
// Parallelize over G_total and M_total dimensions for optimal CPU utilization
|
||||
make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename CDEElementWise>
|
||||
void calculate_reference_multi_dimensional(
|
||||
const HostTensor<ADataType>& a_full_dims,
|
||||
const HostTensor<BDataType>& b_full_dims,
|
||||
const std::vector<HostTensor<DDataType>>& ds_full_dims_host,
|
||||
HostTensor<EDataType>& e_full_dims_host_ref,
|
||||
const std::vector<index_t>& G_dims,
|
||||
const std::vector<index_t>& M_dims,
|
||||
const std::vector<index_t>& N_dims,
|
||||
const std::vector<index_t>& K_dims,
|
||||
const std::vector<index_t>& A_dims,
|
||||
const std::vector<index_t>& B_dims,
|
||||
const std::vector<index_t>& E_dims,
|
||||
const CDEElementWise& cde_elementwise)
|
||||
{
|
||||
std::cout << "Calculating reference using multi-dimensional indexing..." << std::endl;
|
||||
|
||||
std::vector<std::size_t> g_idx(G_dims.size());
|
||||
std::vector<std::size_t> m_idx(M_dims.size());
|
||||
std::vector<std::size_t> n_idx(N_dims.size());
|
||||
std::vector<std::size_t> k_idx(K_dims.size());
|
||||
std::vector<std::size_t> a_idx, b_idx, e_idx;
|
||||
|
||||
a_idx.reserve(A_dims.size());
|
||||
b_idx.reserve(B_dims.size());
|
||||
e_idx.reserve(E_dims.size());
|
||||
|
||||
for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
|
||||
{
|
||||
ck_tile::index_t temp = g_flat;
|
||||
for(int i = G_dims.size() - 1; i >= 0; --i)
|
||||
{
|
||||
g_idx[i] = temp % G_dims[i];
|
||||
temp /= G_dims[i];
|
||||
}
|
||||
|
||||
for(ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat)
|
||||
{
|
||||
temp = m_flat;
|
||||
for(int i = M_dims.size() - 1; i >= 0; --i)
|
||||
{
|
||||
m_idx[i] = temp % M_dims[i];
|
||||
temp /= M_dims[i];
|
||||
}
|
||||
|
||||
for(ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat)
|
||||
{
|
||||
temp = n_flat;
|
||||
for(int i = N_dims.size() - 1; i >= 0; --i)
|
||||
{
|
||||
n_idx[i] = temp % N_dims[i];
|
||||
temp /= N_dims[i];
|
||||
}
|
||||
|
||||
AccDataType sum = 0;
|
||||
|
||||
for(ck_tile::index_t k_flat = 0; k_flat < calculate_total_elements(K_dims);
|
||||
++k_flat)
|
||||
{
|
||||
temp = k_flat;
|
||||
for(int i = K_dims.size() - 1; i >= 0; --i)
|
||||
{
|
||||
k_idx[i] = temp % K_dims[i];
|
||||
temp /= K_dims[i];
|
||||
}
|
||||
|
||||
a_idx.clear();
|
||||
b_idx.clear();
|
||||
|
||||
a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end());
|
||||
a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end());
|
||||
a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end());
|
||||
|
||||
b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end());
|
||||
b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end());
|
||||
b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end());
|
||||
|
||||
auto a_val = a_full_dims(a_idx);
|
||||
auto b_val = b_full_dims(b_idx);
|
||||
|
||||
sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
|
||||
}
|
||||
|
||||
e_idx.clear();
|
||||
e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end());
|
||||
e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end());
|
||||
e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());
|
||||
|
||||
EDataType result = static_cast<EDataType>(sum);
|
||||
if(ds_full_dims_host.size() == 0)
|
||||
{
|
||||
;
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 1)
|
||||
{
|
||||
cde_elementwise(result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)));
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 2)
|
||||
{
|
||||
cde_elementwise(result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)));
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 3)
|
||||
{
|
||||
cde_elementwise(result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)));
|
||||
}
|
||||
else if(ds_full_dims_host.size() == 4)
|
||||
{
|
||||
cde_elementwise(result,
|
||||
ck_tile::type_convert<float>(sum),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)),
|
||||
ck_tile::type_convert<float>(ds_full_dims_host[3](e_idx)));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported NumDTensor for reference calculation");
|
||||
}
|
||||
|
||||
e_full_dims_host_ref(e_idx) = static_cast<EDataType>(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
9
include/ck_tile/ops/batched_contraction.hpp
Normal file
9
include/ck_tile/ops/batched_contraction.hpp
Normal file
@@ -0,0 +1,9 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp"
|
||||
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
@@ -0,0 +1,522 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
|
||||
/**
|
||||
* @file batched_contraction_kernel.hpp
|
||||
* @brief Batched Tensor Contraction Operations
|
||||
*
|
||||
* @section batched_contraction_overview What is Batched Tensor Contraction with Multiple D?
|
||||
*
|
||||
* Tensor contraction is a fundamental operation that generalizes matrix multiplication to
|
||||
* multi-dimensional tensors. It performs element-wise multiplication and summation over
|
||||
* shared dimensions
|
||||
*
|
||||
* **Beyond pure contraction, this kernel supports multiple auxiliary input tensors (D tensors)**
|
||||
* that are fused with the contraction result through configurable epilogue operations, enabling
|
||||
* efficient computation of complex tensor expressions in a single kernel launch.
|
||||
*
|
||||
* @subsection mathematical_formulation Mathematical Formulation
|
||||
*
|
||||
* For tensors A and B with arbitrary dimensionalities, the complete operation computes:
|
||||
*
|
||||
* **E[G₀,G₁,...,M₀,M₁,...,N₀,N₁,...] = epilogue_op(C, D₀, D₁, D₂, ...)**
|
||||
*
|
||||
* Where:
|
||||
* **C[G₀,G₁,...,M₀,M₁,...,N₀,N₁,...] = Σ_{K₀,K₁,...} A[G₀,G₁,...,M₀,M₁,...,K₀,K₁,...] ×
|
||||
* B[G₀,G₁,...,N₀,N₁,...,K₀,K₁,...]**
|
||||
*
|
||||
* Where:
|
||||
* - **G dimensions**: Batch dimensions (shared across A, B, and output E)
|
||||
* - **M dimensions**: Row dimensions of the output matrix (from tensor A)
|
||||
* - **N dimensions**: Column dimensions of the output matrix (from tensor B)
|
||||
* - **K dimensions**: Contraction dimensions (summed over, present in both A and B)
|
||||
*
|
||||
* @subsection why_gemm_implementation Why Tensor Contraction Can Be Implemented Using GEMM
|
||||
*
|
||||
* **Mathematical Equivalence**: Tensor contraction is fundamentally equivalent to matrix
|
||||
* multiplication when dimensions are appropriately flattened. The key insight is that the summation
|
||||
* operation over shared dimensions (K dimensions) in tensor contraction is mathematically identical
|
||||
* to the dot product computation in matrix multiplication.
|
||||
*
|
||||
* **Dimension Flattening Strategy**:
|
||||
* - **M dimensions** (from tensor A) → Flattened into matrix rows (M_total)
|
||||
* - **N dimensions** (from tensor B) → Flattened into matrix columns (N_total)
|
||||
* - **K dimensions** (contraction dims) → Flattened into inner dimension (K_total)
|
||||
* - **G dimensions** (batch dims) → Handled through batch processing
|
||||
*
|
||||
* **Mathematical Transformation**:
|
||||
* ```
|
||||
* Original: E[g,m₀,m₁,n₀,n₁] = Σ_{k₀,k₁} A[g,m₀,m₁,k₀,k₁] × B[g,n₀,n₁,k₀,k₁]
|
||||
* Flattened: E[g,M,N] = Σ_K A[g,M,K] × B[g,N,K] (where M=m₀×m₁, N=n₀×n₁, K=k₀×k₁)
|
||||
* GEMM Form: E = A × Bᵀ
|
||||
*
|
||||
* **Why This Approach Is Optimal**:
|
||||
* Rather than implementing tensor contraction from scratch, this kernel leverages the highly
|
||||
* optimized `UniversalGemmKernel` as its computational backend.
|
||||
*
|
||||
* @subsection current_limitations Current Kernel Limitations
|
||||
*
|
||||
* **Layout Restrictions:**
|
||||
* - **Row-Major Only**: All tensors must use row-major memory layout
|
||||
* - **Packed Tensors**: Only contiguous/packed tensor layouts supported
|
||||
* - **Hardcoded Strides**: stride_A = K_total, stride_B = K_total, stride_E = N_total
|
||||
* - **D Tensor Layout**: All D tensors must match E tensor layout (stride_Ds = N_total)
|
||||
*
|
||||
* **Implementation Constraints:**
|
||||
* - **Fixed Stride Calculation**: Strides are automatically calculated and cannot be customized
|
||||
* - **No Column-Major**: Column-major or custom stride patterns not supported
|
||||
* - **No Strided Access**: Non-contiguous tensor slicing not supported
|
||||
*
|
||||
* **Future Enhancements:**
|
||||
* - Support for arbitrary stride patterns
|
||||
* - Column-major and mixed layout support
|
||||
* - Non-contiguous tensor operation support
|
||||
*/
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Host arguments for batched tensor contraction operations.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure encapsulates all host-side arguments required for batched tensor contraction.
|
||||
/// It supports arbitrary number of batch dimensions (G), M dimensions, N dimensions, and K
|
||||
/// dimensions.
|
||||
///
|
||||
/// @par Tensor Layout Assumptions
|
||||
/// - A tensor: [G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
/// - B tensor: [G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
/// - D tensors: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] (auxiliary input tensors)
|
||||
/// - E tensor: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] (output tensor)
|
||||
///
|
||||
/// @tparam NumDTensor Number of D (auxiliary input) tensors. Default is 0.
|
||||
template <ck_tile::index_t NumDTensor = 0>
|
||||
struct BatchedContractionHostArgs
|
||||
{
|
||||
/// @brief Constructor for batched contraction host arguments.
|
||||
///
|
||||
/// @param a_ptr_ Pointer to input tensor A
|
||||
/// @param b_ptr_ Pointer to input tensor B
|
||||
/// @param ds_ptr_ Array of pointers to auxiliary input tensors D
|
||||
/// @param e_ptr_ Pointer to output tensor E
|
||||
/// @param k_batch_ Number of k-splits for split-K batching
|
||||
/// @param A_dims_ Dimension vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...]
|
||||
/// @param B_dims_ Dimension vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...]
|
||||
/// @param Ds_dims_ Dimension vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...]
|
||||
/// @param E_dims_ Dimension vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...]
|
||||
/// @param A_strides_ Stride vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...]
|
||||
/// @param B_strides_ Stride vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...]
|
||||
/// @param Ds_strides_ Stride vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...]
|
||||
/// @param E_strides_ Stride vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...]
|
||||
CK_TILE_HOST
|
||||
BatchedContractionHostArgs(
|
||||
const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
ck_tile::index_t k_batch_,
|
||||
const std::vector<ck_tile::index_t>& A_dims_, // [G0, G1, ..., M0, M1, ... , K0, K1, ...]
|
||||
const std::vector<ck_tile::index_t>& B_dims_, // [G0, G1, ..., N0, N1, ... , K0, K1, ...]
|
||||
const std::array<std::vector<ck_tile::index_t>, NumDTensor>&
|
||||
Ds_dims_, // [G0, G1, ..., M0, M1, ... , N0, N1, ...][NumDTensor]
|
||||
const std::vector<ck_tile::index_t>& E_dims_, // [G0, G1, ..., M0, M1, ... , N0, N1, ...]
|
||||
|
||||
const std::vector<ck_tile::index_t>& A_strides_, // [G0, G1, ..., M0, M1, ...,K0, K1, ...]
|
||||
const std::vector<ck_tile::index_t>& B_strides_, // [G0, G1, ..., N0, N1, ...,K0, K1, ...]
|
||||
const std::array<std::vector<ck_tile::index_t>, NumDTensor>&
|
||||
Ds_strides_, // [G0, G1, ..., M0, M1, ...,N0, N1, ...]
|
||||
const std::vector<ck_tile::index_t>&
|
||||
E_strides_) // [G0, G1, ..., M0, M1, ...,N0, N1, ...][NumDTensor]
|
||||
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
e_ptr(e_ptr_),
|
||||
k_batch(k_batch_),
|
||||
A_dims(A_dims_),
|
||||
B_dims(B_dims_),
|
||||
Ds_dims(Ds_dims_),
|
||||
E_dims(E_dims_),
|
||||
A_strides(A_strides_),
|
||||
B_strides(B_strides_),
|
||||
Ds_strides(Ds_strides_),
|
||||
E_strides(E_strides_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* a_ptr; ///< Pointer to input tensor A
|
||||
const void* b_ptr; ///< Pointer to input tensor B
|
||||
std::array<const void*, NumDTensor> ds_ptr; ///< Array of pointers to auxiliary input tensors D
|
||||
void* e_ptr; ///< Pointer to output tensor E
|
||||
ck_tile::index_t k_batch; ///< Number of k-splits for split-K batching
|
||||
const std::vector<ck_tile::index_t>
|
||||
A_dims; ///< Dimension vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...]
|
||||
const std::vector<ck_tile::index_t>
|
||||
B_dims; ///< Dimension vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...]
|
||||
const std::array<std::vector<ck_tile::index_t>, NumDTensor>
|
||||
Ds_dims; ///< Dimension vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...]
|
||||
const std::vector<ck_tile::index_t>
|
||||
E_dims; ///< Dimension vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...]
|
||||
const std::vector<ck_tile::index_t>
|
||||
A_strides; ///< Stride vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...]
|
||||
const std::vector<ck_tile::index_t>
|
||||
B_strides; ///< Stride vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...]
|
||||
const std::array<std::vector<ck_tile::index_t>, NumDTensor>
|
||||
Ds_strides; ///< Stride vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...]
|
||||
const std::vector<ck_tile::index_t>
|
||||
E_strides; ///< Stride vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...]
|
||||
};
|
||||
|
||||
/// @brief Kernel arguments for batched tensor contraction operations.
|
||||
///
|
||||
/// @tparam NumDimG Number of batch dimensions
|
||||
/// @tparam NumDimM Number of M (output row) dimensions
|
||||
/// @tparam NumDimN Number of N (output column) dimensions
|
||||
/// @tparam NumDimK Number of K (contraction) dimensions
|
||||
/// @tparam NumDTensor Number of auxiliary input D tensors. Default is 0.
|
||||
|
||||
template <ck_tile::index_t NumDimG,
|
||||
ck_tile::index_t NumDimM,
|
||||
ck_tile::index_t NumDimN,
|
||||
ck_tile::index_t NumDimK,
|
||||
ck_tile::index_t NumDTensor = 0>
|
||||
struct BatchedContractionKernelArgs
|
||||
{
|
||||
const void* a_ptr; ///< Pointer to input tensor A
|
||||
const void* b_ptr; ///< Pointer to input tensor B
|
||||
std::array<const void*, NumDTensor> ds_ptr; ///< Array of pointers to auxiliary input tensors D
|
||||
void* e_ptr; ///< Pointer to output tensor E
|
||||
ck_tile::index_t k_batch; ///< Number of k-splits for split-K batching
|
||||
|
||||
ck_tile::index_t M_dims[NumDimM]; ///< M dimension sizes: [M0, M1, M2, ..., M_{NumDimM-1}]
|
||||
ck_tile::index_t N_dims[NumDimN]; ///< N dimension sizes: [N0, N1, N2, ..., N_{NumDimN-1}]
|
||||
ck_tile::index_t K_dims[NumDimK]; ///< K dimension sizes: [K0, K1, K2, ..., K_{NumDimK-1}]
|
||||
ck_tile::index_t
|
||||
G_dims[NumDimG]; ///< G (batch) dimension sizes: [G0, G1, G2, ..., G_{NumDimG-1}]
|
||||
|
||||
// Batch strides for efficient offset calculation
|
||||
ck_tile::index_t batch_stride_A; ///< Batch stride for tensor A
|
||||
ck_tile::index_t batch_stride_B; ///< Batch stride for tensor B
|
||||
ck_tile::index_t batch_stride_E; ///< Batch stride for tensor E
|
||||
std::array<ck_tile::index_t, NumDTensor> batch_stride_Ds; ///< Batch strides for D tensors
|
||||
|
||||
ck_tile::index_t G_total; ///< Total batch size: G0 * G1 * ... * G_{NumDimG-1}
|
||||
ck_tile::index_t M_total; ///< Total M dimension: M0 * M1 * ... * M_{NumDimM-1}
|
||||
ck_tile::index_t N_total; ///< Total N dimension: N0 * N1 * ... * N_{NumDimN-1}
|
||||
ck_tile::index_t K_total; ///< Total K dimension: K0 * K1 * ... * K_{NumDimK-1}
|
||||
|
||||
ck_tile::index_t stride_A; ///< Leading dimension stride for tensor A (row-major: K_total)
|
||||
ck_tile::index_t stride_B; ///< Leading dimension stride for tensor B (row-major: K_total)
|
||||
std::array<ck_tile::index_t, NumDTensor>
|
||||
stride_Ds; ///< Leading dimension strides for D tensors (row-major: N_total)
|
||||
ck_tile::index_t stride_E; ///< Leading dimension stride for tensor E (row-major: N_total)
|
||||
};
|
||||
|
||||
/// @brief GPU kernel for batched tensor contraction operations.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This kernel performs batched tensor contraction operations using the underlying
|
||||
/// UniversalGemmKernel. It supports arbitrary tensor dimensionalities (G, M, N, K) and
|
||||
/// processes multiple batch instances in parallel. Each batch performs: E =
|
||||
/// epilogue_op(contraction(A, B), D0, D1, ...).
|
||||
///
|
||||
/// @tparam Problem_ Tensor contraction problem specification defining data types and dimensions
|
||||
/// @tparam TilePartitioner_ Tile partitioning strategy for workload distribution
|
||||
/// @tparam GemmPipeline_ GEMM computation pipeline for core matrix operations
|
||||
/// @tparam EpiloguePipeline_ Epilogue pipeline for post-GEMM operations and tensor fusion
|
||||
|
||||
template <typename Problem_,
|
||||
typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_>
|
||||
struct BatchedContractionKernel
|
||||
{
|
||||
// Type aliases for cleaner code and better readability
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>; ///< Tensor contraction problem specification
|
||||
using ADataType =
|
||||
ck_tile::remove_cvref_t<typename Problem::ADataType>; ///< Data type for input tensor A
|
||||
using BDataType =
|
||||
ck_tile::remove_cvref_t<typename Problem::BDataType>; ///< Data type for input tensor B
|
||||
using DsDataType =
|
||||
ck_tile::remove_cvref_t<typename Problem::DsDataType>; ///< Data types for auxiliary input
|
||||
///< tensors D
|
||||
using EDataType =
|
||||
ck_tile::remove_cvref_t<typename Problem::EDataType>; ///< Data type for output tensor E
|
||||
|
||||
// Compile-time dimension constants extracted from problem specification
|
||||
static constexpr ck_tile::index_t NumDimG = Problem::NumDimG; ///< Number of batch dimensions
|
||||
static constexpr ck_tile::index_t NumDimM =
|
||||
Problem::NumDimM; ///< Number of M (output row) dimensions
|
||||
static constexpr ck_tile::index_t NumDimN =
|
||||
Problem::NumDimN; ///< Number of N (output column) dimensions
|
||||
static constexpr ck_tile::index_t NumDimK =
|
||||
Problem::NumDimK; ///< Number of K (contraction) dimensions
|
||||
static constexpr ck_tile::index_t NumDTensor =
|
||||
Problem::NumDTensor; ///< Number of auxiliary input D tensors
|
||||
|
||||
// Pipeline and partitioning strategy types
|
||||
using TilePartitioner =
|
||||
ck_tile::remove_cvref_t<TilePartitioner_>; ///< Tile partitioning strategy for workload
|
||||
///< distribution
|
||||
using GemmPipeline = ck_tile::remove_cvref_t<GemmPipeline_>; ///< GEMM computation pipeline
|
||||
using EpiloguePipeline =
|
||||
ck_tile::remove_cvref_t<EpiloguePipeline_>; ///< Epilogue pipeline for post-GEMM operations
|
||||
|
||||
// Underlying GEMM kernel that performs the actual computation
|
||||
using UniversalGemmKernel =
|
||||
ck_tile::UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
static constexpr ck_tile::index_t kBlockSize =
|
||||
UniversalGemmKernel::kBlockSize; ///< GPU block size inherited from GEMM kernel
|
||||
|
||||
using KernelArgs =
|
||||
BatchedContractionKernelArgs<NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor>; ///< Kernel
|
||||
///< argument
|
||||
///< structure
|
||||
|
||||
/// @brief Returns the kernel name for debugging and profiling purposes.
|
||||
/// @return Constant string identifier for this kernel
|
||||
CK_TILE_HOST static constexpr auto GetKernelName() { return "batched_contraction_kernel"; }
|
||||
|
||||
/// @brief Validates whether the given kernel arguments are supported.
|
||||
/// @param kargs Kernel arguments to validate
|
||||
/// @return True if arguments are supported, false otherwise
|
||||
/// @details Checks underlying GEMM kernel support and ensures valid batch dimensions
|
||||
CK_TILE_HOST static constexpr bool IsSupportedArguments(const KernelArgs& kargs)
|
||||
{
|
||||
typename UniversalGemmKernel::KernelArgs gemm_kargs{{kargs.a_ptr},
|
||||
{kargs.b_ptr},
|
||||
kargs.ds_ptr,
|
||||
kargs.e_ptr,
|
||||
kargs.M_total,
|
||||
kargs.N_total,
|
||||
kargs.K_total,
|
||||
{kargs.stride_A},
|
||||
{kargs.stride_B},
|
||||
kargs.stride_Ds,
|
||||
kargs.stride_E,
|
||||
kargs.k_batch};
|
||||
|
||||
return UniversalGemmKernel::IsSupportedArgument(gemm_kargs) && kargs.G_total > 0;
|
||||
}
|
||||
|
||||
/// @brief Returns the shared memory size required by the kernel.
|
||||
/// @return Shared memory size in bytes
|
||||
/// @details Delegates to underlying GEMM kernel's shared memory requirements
|
||||
CK_TILE_HOST static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return UniversalGemmKernel::GetSmemSize();
|
||||
}
|
||||
|
||||
/// @brief Returns the GPU block size for kernel launch.
|
||||
/// @return 3D block dimensions for GPU kernel execution
|
||||
CK_TILE_HOST static constexpr auto GetBlockSize()
|
||||
{
|
||||
return dim3(UniversalGemmKernel::kBlockSize);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const KernelArgs& kargs)
|
||||
{
|
||||
return dim3(
|
||||
TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr KernelArgs
|
||||
MakeKernelArgs(const BatchedContractionHostArgs<NumDTensor>& host_args)
|
||||
{
|
||||
const auto expected_A_dims = NumDimG + NumDimM + NumDimK;
|
||||
const auto expected_B_dims = NumDimG + NumDimN + NumDimK;
|
||||
const auto expected_E_dims = NumDimG + NumDimM + NumDimN;
|
||||
|
||||
if(host_args.A_dims.size() != expected_A_dims ||
|
||||
host_args.A_strides.size() != expected_A_dims)
|
||||
{
|
||||
throw std::invalid_argument("A dimension size mismatch");
|
||||
}
|
||||
if(host_args.B_dims.size() != expected_B_dims ||
|
||||
host_args.B_strides.size() != expected_B_dims)
|
||||
{
|
||||
throw std::invalid_argument("B dimension size mismatch");
|
||||
}
|
||||
if(host_args.E_dims.size() != expected_E_dims ||
|
||||
host_args.E_strides.size() != expected_E_dims)
|
||||
{
|
||||
throw std::invalid_argument("E dimension size mismatch");
|
||||
}
|
||||
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
if(host_args.Ds_dims[d].size() != expected_E_dims ||
|
||||
host_args.Ds_strides[d].size() != expected_E_dims)
|
||||
{
|
||||
throw std::invalid_argument("D dimension size mismatch");
|
||||
}
|
||||
}
|
||||
|
||||
KernelArgs kargs;
|
||||
kargs.a_ptr = host_args.a_ptr;
|
||||
kargs.b_ptr = host_args.b_ptr;
|
||||
kargs.ds_ptr = host_args.ds_ptr;
|
||||
kargs.e_ptr = host_args.e_ptr;
|
||||
kargs.k_batch = host_args.k_batch;
|
||||
|
||||
// Validate and set G dimensions (must be identical across all tensors)
|
||||
for(ck_tile::index_t i = 0; i < NumDimG; ++i)
|
||||
{
|
||||
// All tensors must have same G dimensions for valid contraction
|
||||
if(host_args.A_dims[i] != host_args.B_dims[i] ||
|
||||
host_args.A_dims[i] != host_args.E_dims[i])
|
||||
{
|
||||
throw std::invalid_argument(
|
||||
"All tensors must have identical G dimensions for valid contraction");
|
||||
}
|
||||
|
||||
// Store G dimensions (same for all tensors)
|
||||
kargs.G_dims[i] = host_args.A_dims[i];
|
||||
}
|
||||
|
||||
// Set batch strides from the stride of last G dimension
|
||||
kargs.batch_stride_A = host_args.A_strides[NumDimG - 1];
|
||||
kargs.batch_stride_B = host_args.B_strides[NumDimG - 1];
|
||||
kargs.batch_stride_E = host_args.E_strides[NumDimG - 1];
|
||||
|
||||
for(ck_tile::index_t i = 0; i < NumDimM; ++i)
|
||||
{
|
||||
kargs.M_dims[i] = host_args.A_dims[NumDimG + i];
|
||||
if(kargs.M_dims[i] != host_args.E_dims[NumDimG + i])
|
||||
{
|
||||
throw std::invalid_argument("M dimension mismatch between A and E tensors");
|
||||
}
|
||||
}
|
||||
for(ck_tile::index_t i = 0; i < NumDimN; ++i)
|
||||
{
|
||||
kargs.N_dims[i] = host_args.B_dims[NumDimG + i];
|
||||
if(kargs.N_dims[i] != host_args.E_dims[NumDimG + NumDimM + i])
|
||||
{
|
||||
throw std::invalid_argument("N dimension mismatch between B and E tensors");
|
||||
}
|
||||
}
|
||||
for(ck_tile::index_t i = 0; i < NumDimK; ++i)
|
||||
{
|
||||
kargs.K_dims[i] = host_args.A_dims[NumDimG + NumDimM + i];
|
||||
if(kargs.K_dims[i] != host_args.B_dims[NumDimG + NumDimN + i])
|
||||
{
|
||||
throw std::invalid_argument("K dimension mismatch between A and B tensors");
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate total dimensions from individual dimension arrays
|
||||
kargs.G_total = 1;
|
||||
for(ck_tile::index_t i = 0; i < NumDimG; ++i)
|
||||
{
|
||||
kargs.G_total *= kargs.G_dims[i];
|
||||
}
|
||||
|
||||
kargs.M_total = 1;
|
||||
for(ck_tile::index_t i = 0; i < NumDimM; ++i)
|
||||
{
|
||||
kargs.M_total *= kargs.M_dims[i];
|
||||
}
|
||||
|
||||
kargs.N_total = 1;
|
||||
for(ck_tile::index_t i = 0; i < NumDimN; ++i)
|
||||
{
|
||||
kargs.N_total *= kargs.N_dims[i];
|
||||
}
|
||||
|
||||
kargs.K_total = 1;
|
||||
for(ck_tile::index_t i = 0; i < NumDimK; ++i)
|
||||
{
|
||||
kargs.K_total *= kargs.K_dims[i];
|
||||
}
|
||||
|
||||
kargs.stride_A = kargs.K_total;
|
||||
kargs.stride_B = kargs.K_total;
|
||||
kargs.stride_E = kargs.N_total;
|
||||
|
||||
// Validate D tensors have same G dimensions and set their batch strides
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
for(ck_tile::index_t i = 0; i < NumDimG; ++i)
|
||||
{
|
||||
if(host_args.Ds_dims[d][i] != host_args.A_dims[i])
|
||||
{
|
||||
throw std::invalid_argument(
|
||||
"D tensor G dimensions must match A/B/E tensor G dimensions");
|
||||
}
|
||||
}
|
||||
// Set batch stride for D tensor
|
||||
kargs.batch_stride_Ds[d] = host_args.Ds_strides[d][NumDimG - 1];
|
||||
kargs.stride_Ds[d] = kargs.N_total; // D tensors same shape as E
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(const KernelArgs& kargs) const
|
||||
{
|
||||
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.M_total, kargs.N_total}.GetOutputTileIndex(blockIdx.x);
|
||||
const ck_tile::index_t i_m =
|
||||
__builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const ck_tile::index_t i_n =
|
||||
__builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
|
||||
// Calculate batch offsets for each tensor
|
||||
const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A;
|
||||
const auto batch_offset_B = i_batch_flat * kargs.batch_stride_B;
|
||||
const auto batch_offset_E = i_batch_flat * kargs.batch_stride_E;
|
||||
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr) + batch_offset_E;
|
||||
|
||||
std::array<const void*, NumDTensor> ds_batch_ptr;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = typename std::tuple_element<i.value, DsDataType>::type;
|
||||
const auto batch_offset_D = i_batch_flat * kargs.batch_stride_Ds[i];
|
||||
ds_batch_ptr[i] = static_cast<const DDataType*>(kargs.ds_ptr[i]) + batch_offset_D;
|
||||
});
|
||||
|
||||
typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
|
||||
{b_ptr},
|
||||
ds_batch_ptr,
|
||||
e_ptr,
|
||||
kargs.M_total,
|
||||
kargs.N_total,
|
||||
kargs.K_total,
|
||||
{kargs.stride_A},
|
||||
{kargs.stride_B},
|
||||
kargs.stride_Ds,
|
||||
kargs.stride_E,
|
||||
kargs.k_batch};
|
||||
|
||||
const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
|
||||
i_splitk);
|
||||
|
||||
const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
|
||||
const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
UniversalGemmKernel::RunGemm({a_ptr_final},
|
||||
{b_ptr_final},
|
||||
ds_batch_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
gemm_kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename DsDataType_,
|
||||
typename EDataType_,
|
||||
ck_tile::index_t NumDimG_,
|
||||
ck_tile::index_t NumDimM_,
|
||||
ck_tile::index_t NumDimN_,
|
||||
ck_tile::index_t NumDimK_,
|
||||
ck_tile::index_t NumDTensor_>
|
||||
struct BatchedContractionProblem
|
||||
{
|
||||
using ADataType = ck_tile::remove_cvref_t<ADataType_>;
|
||||
using BDataType = ck_tile::remove_cvref_t<BDataType_>;
|
||||
using DsDataType = ck_tile::remove_cvref_t<DsDataType_>;
|
||||
using EDataType = ck_tile::remove_cvref_t<EDataType_>;
|
||||
|
||||
static constexpr ck_tile::index_t NumDimG = NumDimG_;
|
||||
static constexpr ck_tile::index_t NumDimM = NumDimM_;
|
||||
static constexpr ck_tile::index_t NumDimN = NumDimN_;
|
||||
static constexpr ck_tile::index_t NumDimK = NumDimK_;
|
||||
static constexpr ck_tile::index_t NumDTensor = NumDTensor_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,169 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
/**
|
||||
* @file tensor_descriptor_utils.hpp
|
||||
* @brief Utility functions for creating tensor descriptors in batched contraction operations
|
||||
*
|
||||
* @details This file contains utility functions for creating tensor descriptors with flattened
|
||||
* dimensions for GEMM operations. These functions transform multi-dimensional tensors into
|
||||
* 2D matrix descriptors by removing batch dimensions and flattening the remaining dimensions.
|
||||
*
|
||||
* These utilities are currently not used in the main batched contraction kernel but are preserved
|
||||
* for future implementations that may require explicit tensor descriptor creation.
|
||||
*/
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Utility class for creating tensor descriptors in batched contraction operations
|
||||
*
|
||||
* @tparam NumDimG Number of batch dimensions
|
||||
* @tparam NumDimM Number of M (output row) dimensions
|
||||
* @tparam NumDimN Number of N (output column) dimensions
|
||||
* @tparam NumDimK Number of K (contraction) dimensions
|
||||
*/
|
||||
template <ck_tile::index_t NumDimG,
|
||||
ck_tile::index_t NumDimM,
|
||||
ck_tile::index_t NumDimN,
|
||||
ck_tile::index_t NumDimK>
|
||||
struct TensorDescriptorUtils
|
||||
{
|
||||
/// @brief Creates a tensor descriptor for input tensor A with batch dimensions removed.
|
||||
/// @param A_dims Dimension vector for tensor A: [G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
/// @param A_strides Stride vector for tensor A: [G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
/// @return Flattened tensor descriptor: [M_total, K_total] for GEMM computation
|
||||
/// @details Removes batch dimensions and flattens M and K dimensions for efficient GEMM
|
||||
/// execution
|
||||
CK_TILE_HOST static constexpr auto
|
||||
Make_A_GridDescriptor_M_K(const std::vector<ck_tile::index_t>& A_dims = {},
|
||||
const std::vector<ck_tile::index_t>& A_strides = {})
|
||||
{
|
||||
const auto to_tuple = [&](auto& vec, auto start, auto end) {
|
||||
return generate_tuple([&](auto i) { return vec[start + i]; }, number<end - start>{});
|
||||
};
|
||||
|
||||
// Remove G Dimensions
|
||||
const auto A_dims_M_K =
|
||||
to_tuple(A_dims, number<NumDimG>{}, number<NumDimG + NumDimM + NumDimK>{});
|
||||
const auto A_strides_M_K =
|
||||
to_tuple(A_strides, number<NumDimG>{}, number<NumDimG + NumDimM + NumDimK>{});
|
||||
|
||||
// dimension Ids for M and K
|
||||
constexpr auto A_dims_M_ids = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
|
||||
constexpr auto A_dims_K_ids =
|
||||
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
|
||||
|
||||
// Dimensions for M [M0, M1, ...] and K [K0, K1, ...]
|
||||
const auto dims_M = get_container_subset(A_dims_M_K, A_dims_M_ids);
|
||||
const auto dims_K = get_container_subset(A_dims_M_K, A_dims_K_ids);
|
||||
|
||||
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Discriptor
|
||||
const auto A_grid_desc_Ms_Ks =
|
||||
ck_tile::make_naive_tensor_descriptor(A_dims_M_K, A_strides_M_K);
|
||||
|
||||
// transformed tensor to flatten M and K dimensions [M_total = M0 * M1 * M2 * ... , K_total
|
||||
// = K0 * K1 * K2 * ...]
|
||||
const auto A_grid_desc_Mflat_Kflat = ck_tile::transform_tensor_descriptor(
|
||||
A_grid_desc_Ms_Ks,
|
||||
make_tuple(make_merge_transform(dims_M), make_merge_transform(dims_K)),
|
||||
make_tuple(A_dims_M_ids, A_dims_K_ids),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return A_grid_desc_Mflat_Kflat;
|
||||
}
|
||||
|
||||
/// @brief Creates a tensor descriptor for input tensor B with batch dimensions removed.
|
||||
/// @param B_dims Dimension vector for tensor B: [G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
/// @param B_strides Stride vector for tensor B: [G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
/// @return Flattened tensor descriptor: [N_total, K_total] for GEMM computation
|
||||
/// @details Removes batch dimensions and flattens N and K dimensions for efficient GEMM
|
||||
/// execution
|
||||
CK_TILE_HOST static constexpr auto
|
||||
Make_B_GridDescriptor_N_K(const std::vector<ck_tile::index_t>& B_dims = {},
|
||||
const std::vector<ck_tile::index_t>& B_strides = {})
|
||||
{
|
||||
const auto to_tuple = [&](auto& vec, auto start, auto end) {
|
||||
return generate_tuple([&](auto i) { return vec[start + i]; }, number<end - start>{});
|
||||
};
|
||||
|
||||
// Remove G Dimensions
|
||||
const auto B_dims_N_K =
|
||||
to_tuple(B_dims, number<NumDimG>{}, number<NumDimG + NumDimN + NumDimK>{});
|
||||
const auto B_strides_N_K =
|
||||
to_tuple(B_strides, number<NumDimG>{}, number<NumDimG + NumDimN + NumDimK>{});
|
||||
|
||||
// dimension Ids for N and K
|
||||
constexpr auto B_dims_N_ids = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
|
||||
constexpr auto B_dims_K_ids =
|
||||
typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};
|
||||
|
||||
// Dimensions for N [N0, N1, ...] and K [K0, K1, ...]
|
||||
const auto dims_N = get_container_subset(B_dims_N_K, B_dims_N_ids);
|
||||
const auto dims_K = get_container_subset(B_dims_N_K, B_dims_K_ids);
|
||||
|
||||
// naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Discriptor
|
||||
const auto B_grid_desc_Ns_Ks =
|
||||
ck_tile::make_naive_tensor_descriptor(B_dims_N_K, B_strides_N_K);
|
||||
|
||||
// transformed tensor to flatten N and K dimensions [N_total = N0 * N1 * N2 * ... , K_total
|
||||
// = K0 * K1 * K2 * ...]
|
||||
const auto B_grid_desc_Nflat_Kflat = ck_tile::transform_tensor_descriptor(
|
||||
B_grid_desc_Ns_Ks,
|
||||
make_tuple(make_merge_transform(dims_N), make_merge_transform(dims_K)),
|
||||
make_tuple(B_dims_N_ids, B_dims_K_ids),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return B_grid_desc_Nflat_Kflat;
|
||||
}
|
||||
|
||||
/// @brief Creates a tensor descriptor for output tensor E with batch dimensions removed.
|
||||
/// @param E_dims Dimension vector for tensor E: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
/// @param E_strides Stride vector for tensor E: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
/// @return Flattened tensor descriptor: [M_total, N_total] for GEMM computation
|
||||
/// @details Removes batch dimensions and flattens M and N dimensions for efficient GEMM
|
||||
/// execution
|
||||
CK_TILE_HOST static constexpr auto
|
||||
Make_E_GridDescriptor_M_N(const std::vector<ck_tile::index_t>& E_dims = {},
|
||||
const std::vector<ck_tile::index_t>& E_strides = {})
|
||||
{
|
||||
const auto to_tuple = [&](auto& vec, auto start, auto end) {
|
||||
return generate_tuple([&](auto i) { return vec[start + i]; }, number<end - start>{});
|
||||
};
|
||||
|
||||
// Remove G dimensions
|
||||
const auto E_dims_M_N =
|
||||
to_tuple(E_dims, number<NumDimG>{}, number<NumDimG + NumDimM + NumDimN>{});
|
||||
const auto E_strides_M_N =
|
||||
to_tuple(E_strides, number<NumDimG>{}, number<NumDimG + NumDimM + NumDimN>{});
|
||||
|
||||
// dimension Ids for M and N
|
||||
constexpr auto E_dims_M_ids = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
|
||||
constexpr auto E_dims_N_ids =
|
||||
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
|
||||
|
||||
// Dimensions for M and N
|
||||
const auto dims_M = get_container_subset(E_dims_M_N, E_dims_M_ids);
|
||||
const auto dims_N = get_container_subset(E_dims_M_N, E_dims_N_ids);
|
||||
|
||||
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Discriptor
|
||||
const auto E_grid_desc_Ms_Ns =
|
||||
ck_tile::make_naive_tensor_descriptor(E_dims_M_N, E_strides_M_N);
|
||||
|
||||
// transformed tensor to flatten M and N dimensions [M_total = M0 * M1 * M2 * ... ,
|
||||
// N_total = N0 * N1 * N2 * ...]
|
||||
const auto E_grid_desc_Mflat_Nflat = ck_tile::transform_tensor_descriptor(
|
||||
E_grid_desc_Ms_Ns,
|
||||
make_tuple(make_merge_transform(dims_M), make_merge_transform(dims_N)),
|
||||
make_tuple(E_dims_M_ids, E_dims_N_ids),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return E_grid_desc_Mflat_Nflat;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user