mirror of https://github.com/ROCm/composable_kernel.git
Add complete multi-dimensional stride support via descriptors
@@ -232,7 +232,8 @@ struct BatchedContractionKernelArgs
     AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides
     BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides
     EGridDesc_M_N_ e_grid_desc_m_n; ///< Tensor descriptor for E[M, N] with actual strides
-    // TODO: Add D descriptors array when needed
+    std::array<EGridDesc_M_N_, NumDTensor>
+        ds_grid_desc_m_n; ///< Descriptors for D tensors (same shape as E, independent strides)
 };

 /// @brief GPU kernel for batched tensor contraction operations.
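
The hunk above replaces the TODO with real per-tensor descriptors: each descriptor pairs an extent with an independent stride per dimension, so a D tensor can share E's [M, N] shape while living in a different memory layout. A minimal plain-C++ sketch of the addressing such a descriptor encodes (Desc2D and its members are illustrative stand-ins, not the CK-Tile descriptor API):

    #include <cstdint>

    // Illustrative stand-in for a 2D grid descriptor: extents plus
    // independent strides, so non-contiguous layouts address correctly.
    struct Desc2D
    {
        int64_t len_m, len_n;       // logical extents
        int64_t stride_m, stride_n; // element strides; need not be contiguous

        // Linear element offset of logical coordinate (m, n).
        int64_t offset(int64_t m, int64_t n) const
        {
            return m * stride_m + n * stride_n;
        }
    };

With stride_m = len_n and stride_n = 1 this reduces to the packed row-major case the old scalar strides covered; any other stride pair describes a permuted or padded layout.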
@@ -356,15 +357,14 @@ struct BatchedContractionKernel
     /// non-contiguous support
     /// @details This function creates tensor views from descriptors and runs GEMM pipeline,
     /// similar to UniversalGemmKernel::RunGemm but with descriptor-based tensor views
-    CK_TILE_DEVICE static void
-    RunGemm(const ADataType* a_ptr,
-            const BDataType* b_ptr,
-            [[maybe_unused]] const std::array<const void*, NumDTensor>& ds_ptr,
-            EDataType* e_ptr,
-            void* smem_ptr,
-            const KernelArgs& kargs,
-            const index_t i_m,
-            const index_t i_n)
+    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
+                                       const BDataType* b_ptr,
+                                       const std::array<const void*, NumDTensor>& ds_ptr,
+                                       EDataType* e_ptr,
+                                       void* smem_ptr,
+                                       const KernelArgs& kargs,
+                                       const index_t i_m,
+                                       const index_t i_n)
     {
         // Create tensor views from descriptors (handles multi-dimensional strides)
         auto a_tensor_view =
@@ -401,11 +401,23 @@ struct BatchedContractionKernel
         const auto& c_block_tile = GemmPipeline{}(
             a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);

-        // Create empty D windows tuple (for NumDTensor=0 case)
-        // TODO: Create actual D windows from descriptors when NumDTensor > 0
-        auto ds_block_windows = make_tuple();
+        // Create D windows from descriptors (for each D tensor)
+        auto ds_block_windows = generate_tuple(
+            [&](auto i) {
+                using DDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
+                const DDataType* d_ptr = static_cast<const DDataType*>(ds_ptr[i]);
+
+                auto d_tensor_view =
+                    make_tensor_view<address_space_enum::global>(d_ptr, kargs.ds_grid_desc_m_n[i]);
+
+                return make_tile_window(d_tensor_view,
+                                        make_tuple(number<TilePartitioner::MPerBlock>{},
+                                                   number<TilePartitioner::NPerBlock>{}),
+                                        {i_m, i_n});
+            },
+            number<NumDTensor>{});

-        // Run Epilogue Pipeline (same as UniversalGemmKernel)
+        // Run Epilogue Pipeline with descriptor-based D windows
         EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
     }

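Together with the signature change in the previous hunk (ds_ptr is no longer [[maybe_unused]]), the generate_tuple call produces one tile window per D tensor at compile time; the tuple is heterogeneous because each window's type depends on its data type and descriptor. A self-contained sketch of that pattern with std::index_sequence (generate_tuple_like and make_window_for are hypothetical names, not the CK-Tile implementation):

    #include <cstddef>
    #include <tuple>
    #include <utility>

    // Call f with integral_constant indices 0..N-1 and collect the
    // (possibly differently typed) results into a std::tuple.
    template <typename F, std::size_t... Is>
    auto generate_tuple_like(F&& f, std::index_sequence<Is...>)
    {
        return std::make_tuple(f(std::integral_constant<std::size_t, Is>{})...);
    }

    // Usage sketch, mirroring the lambda in the diff:
    //   auto ds_windows = generate_tuple_like(
    //       [&](auto i) { return make_window_for<i.value>(...); },
    //       std::make_index_sequence<NumDTensor>{});

Because the index arrives as an integral_constant, i.value is usable in constant expressions such as std::tuple_element_t<i.value, DsDataType>, exactly as the patched lambda requires.
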
@@ -526,6 +538,13 @@ struct BatchedContractionKernel
         kargs.e_grid_desc_m_n =
             DescriptorUtils::Make_E_GridDescriptor_M_N(host_args.E_dims, host_args.E_strides);

+        // Create D descriptors with their own strides (same shape as E, independent strides)
+        for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
+        {
+            kargs.ds_grid_desc_m_n[d] = DescriptorUtils::Make_E_GridDescriptor_M_N(
+                host_args.Ds_dims[d], host_args.Ds_strides[d]);
+        }
+
         // Keep simple strides for backward compatibility
         kargs.stride_A = kargs.K_total;
         kargs.stride_B = kargs.K_total;
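
The host-side loop reuses Make_E_GridDescriptor_M_N for every D because the Ds are constrained to E's shape while keeping their own strides. A tiny sketch of why the per-tensor strides matter (plain data with hypothetical values, not the DescriptorUtils API):

    #include <cstdint>
    #include <vector>

    // Same logical [M, N] = [4, 8] shape, two different physical layouts;
    // each tensor therefore needs its own (dims, strides) descriptor.
    struct DimsStrides
    {
        std::vector<int64_t> dims;
        std::vector<int64_t> strides;
    };

    const DimsStrides e_layout{{4, 8}, {8, 1}}; // row-major E: stride_m = N
    const DimsStrides d_layout{{4, 8}, {1, 4}}; // column-major D, same shape

Passing host_args.Ds_strides[d] through unchanged is what lets a transposed or padded D operand be consumed without a repacking copy.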
@@ -582,43 +601,39 @@ struct BatchedContractionKernel
         // Allocate shared memory
         __shared__ char smem_ptr[GetSmemSize()];

-        if constexpr(NumDTensor > 0)
-        {
-            // OLD PATH: Use UniversalGemmKernel for cases with D tensors
-            typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
-                                                                {b_ptr},
-                                                                ds_batch_ptr,
-                                                                e_ptr,
-                                                                kargs.M_total,
-                                                                kargs.N_total,
-                                                                kargs.K_total,
-                                                                {kargs.stride_A},
-                                                                {kargs.stride_B},
-                                                                kargs.stride_Ds,
-                                                                kargs.stride_E,
-                                                                kargs.k_batch};
-
-            const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
-                                                                                      i_splitk);
-
-            const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
-            const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
-
-            UniversalGemmKernel::RunGemm({a_ptr_final},
-                                         {b_ptr_final},
-                                         ds_batch_ptr,
-                                         e_ptr,
-                                         smem_ptr,
-                                         gemm_kargs,
-                                         splitk_batch_offset,
-                                         i_m,
-                                         i_n);
-        }
-        else
-        {
-            // NEW PATH: Use descriptor-based RunGemm for num_d=0
-            RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, i_m, i_n);
-        }
+#if 0 // OLD PATH: UniversalGemmKernel (kept for reference, can be deleted after full validation)
+        typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
+                                                            {b_ptr},
+                                                            ds_batch_ptr,
+                                                            e_ptr,
+                                                            kargs.M_total,
+                                                            kargs.N_total,
+                                                            kargs.K_total,
+                                                            {kargs.stride_A},
+                                                            {kargs.stride_B},
+                                                            kargs.stride_Ds,
+                                                            kargs.stride_E,
+                                                            kargs.k_batch};
+
+        const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
+                                                                                  i_splitk);
+
+        const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
+        const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
+
+        UniversalGemmKernel::RunGemm({a_ptr_final},
+                                     {b_ptr_final},
+                                     ds_batch_ptr,
+                                     e_ptr,
+                                     smem_ptr,
+                                     gemm_kargs,
+                                     splitk_batch_offset,
+                                     i_m,
+                                     i_n);
+#else // NEW PATH: Descriptor-based RunGemm
+        // custom descriptor-based RunGemm with full multi-dimensional stride support
+        RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, i_m, i_n);
+#endif
     }
 };
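
The old path, now parked under #if 0, offset A and B by a per-split slice of K before calling UniversalGemmKernel::RunGemm, while the new path hands the unadjusted pointers plus descriptors to the custom RunGemm shown earlier. A sketch of the arithmetic behind as_k_split_offset / bs_k_split_offset (a simplified illustration; parameter names here are hypothetical, not the SplitKBatchOffset fields):

    #include <cstdint>

    // Element offset at which split-K workgroup k_batch_id starts
    // reading along K, assuming K divides evenly across k_batch splits.
    inline int64_t k_split_offset(int64_t k_total, int64_t k_batch,
                                  int64_t k_batch_id, int64_t stride_k)
    {
        const int64_t k_per_split = k_total / k_batch;
        return k_batch_id * k_per_split * stride_k;
    }

    // e.g. a_ptr_final = a_ptr + k_split_offset(K_total, k_batch, i_splitk, stride_Ak);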