mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Refactor the reference batched contraction computation to use stride-aware indexing, plus some code cleanup
This commit is contained in:
@@ -48,30 +48,34 @@ void print_help(const char* program_name)
|
|||||||
std::cout << "Batched Tensor Contraction with element-wise fusion\n";
|
std::cout << "Batched Tensor Contraction with element-wise fusion\n";
|
||||||
std::cout << "E[G,M,N] = element_wise_op(contraction(A[G,M,K], B[G,N,K]), D0, D1, ...)\n";
|
std::cout << "E[G,M,N] = element_wise_op(contraction(A[G,M,K], B[G,N,K]), D0, D1, ...)\n";
|
||||||
std::cout << "(Supports multiple D tensors with configurable element-wise operations)\n\n";
|
std::cout << "(Supports multiple D tensors with configurable element-wise operations)\n\n";
|
||||||
|
|
||||||
std::cout << "Usage: " << program_name << " [OPTIONS]\n\n";
|
std::cout << "Usage: " << program_name << " [OPTIONS]\n\n";
|
||||||
|
|
||||||
std::cout << "Dimension Arguments (comma-separated, no spaces):\n";
|
std::cout << "Dimension Arguments (comma-separated, no spaces):\n";
|
||||||
std::cout << " -g_dims=<dims> Batch dimensions (default: \"1,2\")\n";
|
std::cout << " -g_dims=<dims> Batch dimensions (default: \"1,2\")\n";
|
||||||
std::cout << " -m_dims=<dims> M (row) dimensions (default: \"4,256\")\n";
|
std::cout << " -m_dims=<dims> M (row) dimensions (default: \"4,256\")\n";
|
||||||
std::cout << " -n_dims=<dims> N (column) dimensions (default: \"16,128\")\n";
|
std::cout << " -n_dims=<dims> N (column) dimensions (default: \"16,128\")\n";
|
||||||
std::cout << " -k_dims=<dims> K (contract) dims (default: \"64\")\n\n";
|
std::cout << " -k_dims=<dims> K (contract) dims (default: \"64\")\n\n";
|
||||||
|
|
||||||
std::cout << "Layout Arguments:\n";
|
std::cout << "Layout Arguments:\n";
|
||||||
std::cout << " -a_layout=<R|C> A tensor layout (R=Row-major, C=Column-major, default: \"R\")\n";
|
std::cout
|
||||||
|
<< " -a_layout=<R|C> A tensor layout (R=Row-major, C=Column-major, default: \"R\")\n";
|
||||||
std::cout << " -b_layout=<R|C> B tensor layout (default: \"C\")\n";
|
std::cout << " -b_layout=<R|C> B tensor layout (default: \"C\")\n";
|
||||||
std::cout << " -e_layout=<R|C> E tensor layout (default: \"R\")\n\n";
|
std::cout << " -e_layout=<R|C> E tensor layout (default: \"R\")\n\n";
|
||||||
|
|
||||||
std::cout << "Examples:\n";
|
std::cout << "Examples:\n";
|
||||||
std::cout << " Single batch (12 batches of 256×128):\n";
|
std::cout << " Single batch (12 batches of 256×128):\n";
|
||||||
std::cout << " " << program_name << " -g_dims=\"12\" -m_dims=\"256\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
|
std::cout << " " << program_name
|
||||||
|
<< " -g_dims=\"12\" -m_dims=\"256\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
|
||||||
|
|
||||||
std::cout << " 2D batch grid (2×3=6 batches):\n";
|
std::cout << " 2D batch grid (2×3=6 batches):\n";
|
||||||
std::cout << " " << program_name << " -g_dims=\"2,3\" -m_dims=\"128\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
|
std::cout << " " << program_name
|
||||||
|
<< " -g_dims=\"2,3\" -m_dims=\"128\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
|
||||||
|
|
||||||
std::cout << " Multi-dimensional (flattened to M=128, N=128, K=128):\n";
|
std::cout << " Multi-dimensional (flattened to M=128, N=128, K=128):\n";
|
||||||
std::cout << " " << program_name << " -g_dims=\"4\" -m_dims=\"8,16\" -n_dims=\"32,4\" -k_dims=\"16,8\"\n\n";
|
std::cout << " " << program_name
|
||||||
|
<< " -g_dims=\"4\" -m_dims=\"8,16\" -n_dims=\"32,4\" -k_dims=\"16,8\"\n\n";
|
||||||
|
|
||||||
std::cout << "Other Options:\n";
|
std::cout << "Other Options:\n";
|
||||||
std::cout << " -v=<0|1> Validation (0=off, 1=on, default: 1)\n";
|
std::cout << " -v=<0|1> Validation (0=off, 1=on, default: 1)\n";
|
||||||
std::cout << " -split_k=<int> Split-K value (default: 1)\n";
|
std::cout << " -split_k=<int> Split-K value (default: 1)\n";
|
||||||
@@ -93,7 +97,7 @@ auto create_args(int argc, char* argv[])
|
|||||||
std::exit(0);
|
std::exit(0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ck_tile::ArgParser arg_parser;
|
ck_tile::ArgParser arg_parser;
|
||||||
arg_parser.insert("m_dims", "4,256", "M dimensions separated by comma (e.g., '16,32' for 2D M)")
|
arg_parser.insert("m_dims", "4,256", "M dimensions separated by comma (e.g., '16,32' for 2D M)")
|
||||||
.insert("n_dims", "16,128", "N dimensions separated by comma (e.g., '32,32' for 2D N)")
|
.insert("n_dims", "16,128", "N dimensions separated by comma (e.g., '32,32' for 2D N)")
|
||||||
|
|||||||
@@ -201,11 +201,14 @@ int run_batched_contraction_example_with_layouts(
|
|||||||
ck_tile::HostTensor<::BDataType> b_full_dims_host(b_desc);
|
ck_tile::HostTensor<::BDataType> b_full_dims_host(b_desc);
|
||||||
ck_tile::HostTensor<::EDataType> e_full_dims_host(e_desc);
|
ck_tile::HostTensor<::EDataType> e_full_dims_host(e_desc);
|
||||||
|
|
||||||
std::vector<ck_tile::HostTensor<::DDataType>> ds_full_dims_host;
|
// Helper to construct array of HostTensors using index_sequence
|
||||||
for(int d = 0; d < NumDTensor; ++d)
|
auto make_ds_host_tensors = []<std::size_t... Is>(const auto& descs,
|
||||||
{
|
std::index_sequence<Is...>) {
|
||||||
ds_full_dims_host.emplace_back(ck_tile::HostTensor<::DDataType>(ds_descs[d]));
|
return std::array<ck_tile::HostTensor<::DDataType>, sizeof...(Is)>{
|
||||||
}
|
ck_tile::HostTensor<::DDataType>(descs[Is])...};
|
||||||
|
};
|
||||||
|
|
||||||
|
auto ds_full_dims_host = make_ds_host_tensors(ds_descs, std::make_index_sequence<NumDTensor>{});
|
||||||
|
|
||||||
ck_tile::FillUniformDistribution<::ADataType>{-5.f, 5.f, std::nullopt}(a_full_dims_host);
|
ck_tile::FillUniformDistribution<::ADataType>{-5.f, 5.f, std::nullopt}(a_full_dims_host);
|
||||||
ck_tile::FillUniformDistribution<::BDataType>{-5.f, 5.f, std::nullopt}(b_full_dims_host);
|
ck_tile::FillUniformDistribution<::BDataType>{-5.f, 5.f, std::nullopt}(b_full_dims_host);
|
||||||
@@ -316,12 +319,13 @@ int run_batched_contraction_example_with_layouts(
|
|||||||
|
|
||||||
auto start_time = std::chrono::high_resolution_clock::now();
|
auto start_time = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
calculate_reference_flat_indexing<ADataType,
|
compute_reference_batched_contraction<ADataType,
|
||||||
BDataType,
|
BDataType,
|
||||||
DDataType,
|
DDataType,
|
||||||
EDataType,
|
EDataType,
|
||||||
AccDataType,
|
AccDataType,
|
||||||
CDEElementWise>(a_full_dims_host,
|
CDEElementWise,
|
||||||
|
NumDTensor>(a_full_dims_host,
|
||||||
b_full_dims_host,
|
b_full_dims_host,
|
||||||
ds_full_dims_host,
|
ds_full_dims_host,
|
||||||
e_full_dims_host_ref,
|
e_full_dims_host_ref,
|
||||||
@@ -329,7 +333,11 @@ int run_batched_contraction_example_with_layouts(
|
|||||||
M_total,
|
M_total,
|
||||||
N_total,
|
N_total,
|
||||||
K_total,
|
K_total,
|
||||||
CDEElementWise{});
|
CDEElementWise{},
|
||||||
|
G_dims,
|
||||||
|
M_dims,
|
||||||
|
N_dims,
|
||||||
|
K_dims);
|
||||||
|
|
||||||
auto end_time = std::chrono::high_resolution_clock::now();
|
auto end_time = std::chrono::high_resolution_clock::now();
|
||||||
auto duration =
|
auto duration =
|
||||||
|
|||||||
@@ -11,110 +11,210 @@
|
|||||||
|
|
||||||
namespace ck_tile {
|
namespace ck_tile {
|
||||||
|
|
||||||
|
// Helper to apply elementwise operation with variable number of D tensors
|
||||||
|
template <typename EDataType, typename AccDataType, typename CDEElementWise>
|
||||||
|
struct ApplyCDEElementWise
|
||||||
|
{
|
||||||
|
template <typename... DValues>
|
||||||
|
CK_TILE_HOST_DEVICE static void apply(EDataType& result,
|
||||||
|
AccDataType sum,
|
||||||
|
const CDEElementWise& cde_elementwise,
|
||||||
|
DValues... d_vals)
|
||||||
|
{
|
||||||
|
if constexpr(sizeof...(DValues) == 0)
|
||||||
|
{
|
||||||
|
result = static_cast<EDataType>(sum);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
cde_elementwise(
|
||||||
|
result, ck_tile::type_convert<float>(sum), ck_tile::type_convert<float>(d_vals)...);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Helper to extract D values at a given offset using index sequence
|
||||||
|
template <typename DDataType,
|
||||||
|
ck_tile::index_t NumDTensor,
|
||||||
|
typename Indices = std::make_index_sequence<NumDTensor>>
|
||||||
|
struct ExtractDValues;
|
||||||
|
|
||||||
|
template <typename DDataType, ck_tile::index_t NumDTensor, std::size_t... Is>
|
||||||
|
struct ExtractDValues<DDataType, NumDTensor, std::index_sequence<Is...>>
|
||||||
|
{
|
||||||
|
template <typename EDataType, typename AccDataType, typename CDEElementWise>
|
||||||
|
CK_TILE_HOST static void
|
||||||
|
apply_at_offset(EDataType& result,
|
||||||
|
AccDataType sum,
|
||||||
|
const CDEElementWise& cde_elementwise,
|
||||||
|
const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_tensors,
|
||||||
|
std::size_t offset)
|
||||||
|
{
|
||||||
|
ApplyCDEElementWise<EDataType, AccDataType, CDEElementWise>::apply(
|
||||||
|
result, sum, cde_elementwise, ds_tensors[Is].mData[offset]...);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename ADataType,
|
template <typename ADataType,
|
||||||
typename BDataType,
|
typename BDataType,
|
||||||
typename DDataType,
|
typename DDataType,
|
||||||
typename EDataType,
|
typename EDataType,
|
||||||
typename AccDataType,
|
typename AccDataType,
|
||||||
typename CDEElementWise>
|
typename CDEElementWise,
|
||||||
|
ck_tile::index_t NumDTensor>
|
||||||
|
|
||||||
void calculate_reference_flat_indexing(
|
void compute_reference_batched_contraction(
|
||||||
const ck_tile::HostTensor<ADataType>& a_full_dims,
|
const ck_tile::HostTensor<ADataType>& a_full_dims,
|
||||||
const ck_tile::HostTensor<BDataType>& b_full_dims,
|
const ck_tile::HostTensor<BDataType>& b_full_dims,
|
||||||
const std::vector<ck_tile::HostTensor<DDataType>>& ds_full_dims_host,
|
const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_full_dims_host,
|
||||||
ck_tile::HostTensor<EDataType>& e_full_dims_host_ref,
|
ck_tile::HostTensor<EDataType>& e_full_dims_host_ref,
|
||||||
ck_tile::index_t G_total,
|
ck_tile::index_t G_total,
|
||||||
ck_tile::index_t M_total,
|
ck_tile::index_t M_total,
|
||||||
ck_tile::index_t N_total,
|
ck_tile::index_t N_total,
|
||||||
ck_tile::index_t K_total,
|
ck_tile::index_t K_total,
|
||||||
const CDEElementWise& cde_elementwise)
|
const CDEElementWise& cde_elementwise,
|
||||||
|
const std::vector<ck_tile::index_t>& G_dims,
|
||||||
|
const std::vector<ck_tile::index_t>& M_dims,
|
||||||
|
const std::vector<ck_tile::index_t>& N_dims,
|
||||||
|
const std::vector<ck_tile::index_t>& K_dims)
|
||||||
{
|
{
|
||||||
std::cout << "Calculating reference using optimized flat indexing with parallel processing..."
|
std::cout << "Calculating reference using stride-aware indexing with parallel processing..."
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
|
|
||||||
// Parallel computation over G and M dimensions using pattern from reference_batched_gemm.hpp
|
// Extract stride information from tensor descriptors
|
||||||
|
const auto a_strides = a_full_dims.get_strides();
|
||||||
|
const auto b_strides = b_full_dims.get_strides();
|
||||||
|
const auto e_strides = e_full_dims_host_ref.get_strides();
|
||||||
|
|
||||||
|
const ck_tile::index_t num_g_dims = G_dims.size();
|
||||||
|
const ck_tile::index_t num_m_dims = M_dims.size();
|
||||||
|
const ck_tile::index_t num_n_dims = N_dims.size();
|
||||||
|
const ck_tile::index_t num_k_dims = K_dims.size();
|
||||||
|
|
||||||
|
// Helper lambda to compute linear index from flat indices using strides
|
||||||
|
auto compute_a_offset = [&](ck_tile::index_t g_flat,
|
||||||
|
ck_tile::index_t m_flat,
|
||||||
|
ck_tile::index_t k_flat) -> std::size_t {
|
||||||
|
std::size_t offset = 0;
|
||||||
|
|
||||||
|
// Decode G dimensions
|
||||||
|
ck_tile::index_t temp = g_flat;
|
||||||
|
for(ck_tile::index_t i = num_g_dims - 1; i >= 0; --i)
|
||||||
|
{
|
||||||
|
offset += (temp % G_dims[i]) * a_strides[i];
|
||||||
|
temp /= G_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode M dimensions
|
||||||
|
temp = m_flat;
|
||||||
|
for(ck_tile::index_t i = num_m_dims - 1; i >= 0; --i)
|
||||||
|
{
|
||||||
|
offset += (temp % M_dims[i]) * a_strides[num_g_dims + i];
|
||||||
|
temp /= M_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode K dimensions
|
||||||
|
temp = k_flat;
|
||||||
|
for(ck_tile::index_t i = num_k_dims - 1; i >= 0; --i)
|
||||||
|
{
|
||||||
|
offset += (temp % K_dims[i]) * a_strides[num_g_dims + num_m_dims + i];
|
||||||
|
temp /= K_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return offset;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto compute_b_offset = [&](ck_tile::index_t g_flat,
|
||||||
|
ck_tile::index_t n_flat,
|
||||||
|
ck_tile::index_t k_flat) -> std::size_t {
|
||||||
|
std::size_t offset = 0;
|
||||||
|
|
||||||
|
// Decode G dimensions
|
||||||
|
ck_tile::index_t temp = g_flat;
|
||||||
|
for(ck_tile::index_t i = num_g_dims - 1; i >= 0; --i)
|
||||||
|
{
|
||||||
|
offset += (temp % G_dims[i]) * b_strides[i];
|
||||||
|
temp /= G_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode N dimensions
|
||||||
|
temp = n_flat;
|
||||||
|
for(ck_tile::index_t i = num_n_dims - 1; i >= 0; --i)
|
||||||
|
{
|
||||||
|
offset += (temp % N_dims[i]) * b_strides[num_g_dims + i];
|
||||||
|
temp /= N_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode K dimensions
|
||||||
|
temp = k_flat;
|
||||||
|
for(ck_tile::index_t i = num_k_dims - 1; i >= 0; --i)
|
||||||
|
{
|
||||||
|
offset += (temp % K_dims[i]) * b_strides[num_g_dims + num_n_dims + i];
|
||||||
|
temp /= K_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return offset;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto compute_e_offset = [&](ck_tile::index_t g_flat,
|
||||||
|
ck_tile::index_t m_flat,
|
||||||
|
ck_tile::index_t n_flat) -> std::size_t {
|
||||||
|
std::size_t offset = 0;
|
||||||
|
|
||||||
|
// Decode G dimensions
|
||||||
|
ck_tile::index_t temp = g_flat;
|
||||||
|
for(ck_tile::index_t i = num_g_dims - 1; i >= 0; --i)
|
||||||
|
{
|
||||||
|
offset += (temp % G_dims[i]) * e_strides[i];
|
||||||
|
temp /= G_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode M dimensions
|
||||||
|
temp = m_flat;
|
||||||
|
for(ck_tile::index_t i = num_m_dims - 1; i >= 0; --i)
|
||||||
|
{
|
||||||
|
offset += (temp % M_dims[i]) * e_strides[num_g_dims + i];
|
||||||
|
temp /= M_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode N dimensions
|
||||||
|
temp = n_flat;
|
||||||
|
for(ck_tile::index_t i = num_n_dims - 1; i >= 0; --i)
|
||||||
|
{
|
||||||
|
offset += (temp % N_dims[i]) * e_strides[num_g_dims + num_m_dims + i];
|
||||||
|
temp /= N_dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return offset;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parallel computation over G and M dimensions
|
||||||
auto f_gm = [&](auto g_flat, auto m_flat) {
|
auto f_gm = [&](auto g_flat, auto m_flat) {
|
||||||
for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
|
for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
|
||||||
{
|
{
|
||||||
AccDataType sum = 0;
|
AccDataType sum = 0;
|
||||||
|
|
||||||
// Compute dot product over K dimension
|
// Compute dot product over K dimension using stride-aware indexing
|
||||||
for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat)
|
for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat)
|
||||||
{
|
{
|
||||||
auto a_val =
|
const std::size_t a_offset = compute_a_offset(g_flat, m_flat, k_flat);
|
||||||
a_full_dims.mData[g_flat * M_total * K_total + m_flat * K_total + k_flat];
|
const std::size_t b_offset = compute_b_offset(g_flat, n_flat, k_flat);
|
||||||
auto b_val =
|
|
||||||
b_full_dims.mData[g_flat * N_total * K_total + n_flat * K_total + k_flat];
|
auto a_val = a_full_dims.mData[a_offset];
|
||||||
|
auto b_val = b_full_dims.mData[b_offset];
|
||||||
sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
|
sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply elementwise operation with D tensors
|
// Compute output offset using strides
|
||||||
EDataType result = static_cast<EDataType>(sum);
|
const std::size_t e_offset = compute_e_offset(g_flat, m_flat, n_flat);
|
||||||
if(ds_full_dims_host.size() == 0)
|
|
||||||
{
|
|
||||||
;
|
|
||||||
}
|
|
||||||
else if(ds_full_dims_host.size() == 1)
|
|
||||||
{
|
|
||||||
cde_elementwise(result,
|
|
||||||
ck_tile::type_convert<float>(sum),
|
|
||||||
ck_tile::type_convert<float>(
|
|
||||||
ds_full_dims_host[0].mData[g_flat * M_total * N_total +
|
|
||||||
m_flat * N_total + n_flat]));
|
|
||||||
}
|
|
||||||
else if(ds_full_dims_host.size() == 2)
|
|
||||||
{
|
|
||||||
cde_elementwise(
|
|
||||||
result,
|
|
||||||
ck_tile::type_convert<float>(sum),
|
|
||||||
ck_tile::type_convert<float>(
|
|
||||||
ds_full_dims_host[0]
|
|
||||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
|
||||||
ck_tile::type_convert<float>(
|
|
||||||
ds_full_dims_host[1]
|
|
||||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
|
|
||||||
}
|
|
||||||
else if(ds_full_dims_host.size() == 3)
|
|
||||||
{
|
|
||||||
cde_elementwise(
|
|
||||||
result,
|
|
||||||
ck_tile::type_convert<float>(sum),
|
|
||||||
ck_tile::type_convert<float>(
|
|
||||||
ds_full_dims_host[0]
|
|
||||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
|
||||||
ck_tile::type_convert<float>(
|
|
||||||
ds_full_dims_host[1]
|
|
||||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
|
||||||
ck_tile::type_convert<float>(
|
|
||||||
ds_full_dims_host[2]
|
|
||||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
|
|
||||||
}
|
|
||||||
else if(ds_full_dims_host.size() == 4)
|
|
||||||
{
|
|
||||||
cde_elementwise(
|
|
||||||
result,
|
|
||||||
ck_tile::type_convert<float>(sum),
|
|
||||||
ck_tile::type_convert<float>(
|
|
||||||
ds_full_dims_host[0]
|
|
||||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
|
||||||
ck_tile::type_convert<float>(
|
|
||||||
ds_full_dims_host[1]
|
|
||||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
|
||||||
ck_tile::type_convert<float>(
|
|
||||||
ds_full_dims_host[2]
|
|
||||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
|
|
||||||
ck_tile::type_convert<float>(
|
|
||||||
ds_full_dims_host[3]
|
|
||||||
.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
throw std::runtime_error("Unsupported NumDTensor for reference calculation");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store result
|
// Apply elementwise operation with D tensors using compile-time dispatch
|
||||||
e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
|
EDataType result = static_cast<EDataType>(sum);
|
||||||
static_cast<EDataType>(result);
|
ExtractDValues<DDataType, NumDTensor>::apply_at_offset(
|
||||||
|
result, sum, cde_elementwise, ds_full_dims_host, e_offset);
|
||||||
|
|
||||||
|
// Store result using stride-aware indexing
|
||||||
|
e_full_dims_host_ref.mData[e_offset] = static_cast<EDataType>(result);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -123,143 +223,4 @@ void calculate_reference_flat_indexing(
|
|||||||
make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency());
|
make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename ADataType,
|
|
||||||
typename BDataType,
|
|
||||||
typename DDataType,
|
|
||||||
typename EDataType,
|
|
||||||
typename AccDataType,
|
|
||||||
typename CDEElementWise>
|
|
||||||
void calculate_reference_multi_dimensional(
|
|
||||||
const HostTensor<ADataType>& a_full_dims,
|
|
||||||
const HostTensor<BDataType>& b_full_dims,
|
|
||||||
const std::vector<HostTensor<DDataType>>& ds_full_dims_host,
|
|
||||||
HostTensor<EDataType>& e_full_dims_host_ref,
|
|
||||||
const std::vector<index_t>& G_dims,
|
|
||||||
const std::vector<index_t>& M_dims,
|
|
||||||
const std::vector<index_t>& N_dims,
|
|
||||||
const std::vector<index_t>& K_dims,
|
|
||||||
const std::vector<index_t>& A_dims,
|
|
||||||
const std::vector<index_t>& B_dims,
|
|
||||||
const std::vector<index_t>& E_dims,
|
|
||||||
const CDEElementWise& cde_elementwise)
|
|
||||||
{
|
|
||||||
std::cout << "Calculating reference using multi-dimensional indexing..." << std::endl;
|
|
||||||
|
|
||||||
std::vector<std::size_t> g_idx(G_dims.size());
|
|
||||||
std::vector<std::size_t> m_idx(M_dims.size());
|
|
||||||
std::vector<std::size_t> n_idx(N_dims.size());
|
|
||||||
std::vector<std::size_t> k_idx(K_dims.size());
|
|
||||||
std::vector<std::size_t> a_idx, b_idx, e_idx;
|
|
||||||
|
|
||||||
a_idx.reserve(A_dims.size());
|
|
||||||
b_idx.reserve(B_dims.size());
|
|
||||||
e_idx.reserve(E_dims.size());
|
|
||||||
|
|
||||||
for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
|
|
||||||
{
|
|
||||||
ck_tile::index_t temp = g_flat;
|
|
||||||
for(int i = G_dims.size() - 1; i >= 0; --i)
|
|
||||||
{
|
|
||||||
g_idx[i] = temp % G_dims[i];
|
|
||||||
temp /= G_dims[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
for(ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat)
|
|
||||||
{
|
|
||||||
temp = m_flat;
|
|
||||||
for(int i = M_dims.size() - 1; i >= 0; --i)
|
|
||||||
{
|
|
||||||
m_idx[i] = temp % M_dims[i];
|
|
||||||
temp /= M_dims[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
for(ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat)
|
|
||||||
{
|
|
||||||
temp = n_flat;
|
|
||||||
for(int i = N_dims.size() - 1; i >= 0; --i)
|
|
||||||
{
|
|
||||||
n_idx[i] = temp % N_dims[i];
|
|
||||||
temp /= N_dims[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
AccDataType sum = 0;
|
|
||||||
|
|
||||||
for(ck_tile::index_t k_flat = 0; k_flat < calculate_total_elements(K_dims);
|
|
||||||
++k_flat)
|
|
||||||
{
|
|
||||||
temp = k_flat;
|
|
||||||
for(int i = K_dims.size() - 1; i >= 0; --i)
|
|
||||||
{
|
|
||||||
k_idx[i] = temp % K_dims[i];
|
|
||||||
temp /= K_dims[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
a_idx.clear();
|
|
||||||
b_idx.clear();
|
|
||||||
|
|
||||||
a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end());
|
|
||||||
a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end());
|
|
||||||
a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end());
|
|
||||||
|
|
||||||
b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end());
|
|
||||||
b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end());
|
|
||||||
b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end());
|
|
||||||
|
|
||||||
auto a_val = a_full_dims(a_idx);
|
|
||||||
auto b_val = b_full_dims(b_idx);
|
|
||||||
|
|
||||||
sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
|
|
||||||
}
|
|
||||||
|
|
||||||
e_idx.clear();
|
|
||||||
e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end());
|
|
||||||
e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end());
|
|
||||||
e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());
|
|
||||||
|
|
||||||
EDataType result = static_cast<EDataType>(sum);
|
|
||||||
if(ds_full_dims_host.size() == 0)
|
|
||||||
{
|
|
||||||
;
|
|
||||||
}
|
|
||||||
else if(ds_full_dims_host.size() == 1)
|
|
||||||
{
|
|
||||||
cde_elementwise(result,
|
|
||||||
ck_tile::type_convert<float>(sum),
|
|
||||||
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)));
|
|
||||||
}
|
|
||||||
else if(ds_full_dims_host.size() == 2)
|
|
||||||
{
|
|
||||||
cde_elementwise(result,
|
|
||||||
ck_tile::type_convert<float>(sum),
|
|
||||||
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
|
|
||||||
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)));
|
|
||||||
}
|
|
||||||
else if(ds_full_dims_host.size() == 3)
|
|
||||||
{
|
|
||||||
cde_elementwise(result,
|
|
||||||
ck_tile::type_convert<float>(sum),
|
|
||||||
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
|
|
||||||
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
|
|
||||||
ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)));
|
|
||||||
}
|
|
||||||
else if(ds_full_dims_host.size() == 4)
|
|
||||||
{
|
|
||||||
cde_elementwise(result,
|
|
||||||
ck_tile::type_convert<float>(sum),
|
|
||||||
ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
|
|
||||||
ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
|
|
||||||
ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)),
|
|
||||||
ck_tile::type_convert<float>(ds_full_dims_host[3](e_idx)));
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
throw std::runtime_error("Unsupported NumDTensor for reference calculation");
|
|
||||||
}
|
|
||||||
|
|
||||||
e_full_dims_host_ref(e_idx) = static_cast<EDataType>(result);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace ck_tile
|
} // namespace ck_tile
|
||||||
|
|||||||
Reference in New Issue
Block a user