mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
Add stride-aware reference for batched contraction with independent D tensor layouts
This commit is contained in:
@@ -44,14 +44,14 @@ struct ExtractDValues<DDataType, NumDTensor, std::index_sequence<Is...>>
|
||||
{
|
||||
template <typename EDataType, typename AccDataType, typename CDEElementWise>
|
||||
CK_TILE_HOST static void
|
||||
apply_at_offset(EDataType& result,
|
||||
AccDataType sum,
|
||||
const CDEElementWise& cde_elementwise,
|
||||
const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_tensors,
|
||||
std::size_t offset)
|
||||
apply_at_offsets(EDataType& result,
|
||||
AccDataType sum,
|
||||
const CDEElementWise& cde_elementwise,
|
||||
const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_tensors,
|
||||
const std::array<std::size_t, NumDTensor>& d_offsets)
|
||||
{
|
||||
ApplyCDEElementWise<EDataType, AccDataType, CDEElementWise>::apply(
|
||||
result, sum, cde_elementwise, ds_tensors[Is].mData[offset]...);
|
||||
result, sum, cde_elementwise, ds_tensors[Is].mData[d_offsets[Is]]...);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -86,6 +86,13 @@ void compute_reference_batched_contraction(
|
||||
const auto b_strides = b_full_dims.get_strides();
|
||||
const auto e_strides = e_full_dims_host_ref.get_strides();
|
||||
|
||||
// Extract D tensor strides
|
||||
std::array<std::vector<std::size_t>, NumDTensor> ds_strides;
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
ds_strides[d] = ds_full_dims_host[d].get_strides();
|
||||
}
|
||||
|
||||
const ck_tile::index_t num_g_dims = G_dims.size();
|
||||
const ck_tile::index_t num_m_dims = M_dims.size();
|
||||
const ck_tile::index_t num_n_dims = N_dims.size();
|
||||
@@ -188,6 +195,41 @@ void compute_reference_batched_contraction(
|
||||
return offset;
|
||||
};
|
||||
|
||||
// Helper to compute D tensor offset (D tensors have same shape as E: [G, M, N])
|
||||
auto compute_d_offset = [&](ck_tile::index_t g_flat,
|
||||
ck_tile::index_t m_flat,
|
||||
ck_tile::index_t n_flat,
|
||||
ck_tile::index_t d_idx) -> std::size_t {
|
||||
std::size_t offset = 0;
|
||||
const auto& d_strides = ds_strides[d_idx];
|
||||
|
||||
// Decode G dimensions
|
||||
ck_tile::index_t temp = g_flat;
|
||||
for(ck_tile::index_t i = num_g_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % G_dims[i]) * d_strides[i];
|
||||
temp /= G_dims[i];
|
||||
}
|
||||
|
||||
// Decode M dimensions
|
||||
temp = m_flat;
|
||||
for(ck_tile::index_t i = num_m_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % M_dims[i]) * d_strides[num_g_dims + i];
|
||||
temp /= M_dims[i];
|
||||
}
|
||||
|
||||
// Decode N dimensions
|
||||
temp = n_flat;
|
||||
for(ck_tile::index_t i = num_n_dims - 1; i >= 0; --i)
|
||||
{
|
||||
offset += (temp % N_dims[i]) * d_strides[num_g_dims + num_m_dims + i];
|
||||
temp /= N_dims[i];
|
||||
}
|
||||
|
||||
return offset;
|
||||
};
|
||||
|
||||
// Parallel computation over G and M dimensions
|
||||
auto f_gm = [&](auto g_flat, auto m_flat) {
|
||||
for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
|
||||
@@ -208,10 +250,17 @@ void compute_reference_batched_contraction(
|
||||
// Compute output offset using strides
|
||||
const std::size_t e_offset = compute_e_offset(g_flat, m_flat, n_flat);
|
||||
|
||||
// Compute individual D tensor offsets using their respective strides
|
||||
std::array<std::size_t, NumDTensor> d_offsets;
|
||||
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||
{
|
||||
d_offsets[d] = compute_d_offset(g_flat, m_flat, n_flat, d);
|
||||
}
|
||||
|
||||
// Apply elementwise operation with D tensors using compile-time dispatch
|
||||
EDataType result = static_cast<EDataType>(sum);
|
||||
ExtractDValues<DDataType, NumDTensor>::apply_at_offset(
|
||||
result, sum, cde_elementwise, ds_full_dims_host, e_offset);
|
||||
ExtractDValues<DDataType, NumDTensor>::apply_at_offsets(
|
||||
result, sum, cde_elementwise, ds_full_dims_host, d_offsets);
|
||||
|
||||
// Store result using stride-aware indexing
|
||||
e_full_dims_host_ref.mData[e_offset] = static_cast<EDataType>(result);
|
||||
|
||||
Reference in New Issue
Block a user