Add complete multi-dimensional stride support via descriptors

Mohsen Saffari
2025-10-20 14:43:32 +00:00
parent b8b56d5cc6
commit bbfe4501fa
2 changed files with 64 additions and 50 deletions

View File

@@ -219,9 +219,8 @@ float batched_contraction(const ck_tile::BatchedContractionHostArgs<DsDataType::
     HANDLE_CASE(2, 1, 1, 1);
     HANDLE_CASE(2, 2, 2, 1);
     HANDLE_CASE(1, 2, 1, 1);
-    HANDLE_CASE(1, 1, 1, 2);
+    HANDLE_CASE(2, 1, 1, 1);
     HANDLE_CASE(2, 2, 2, 2);
-    HANDLE_CASE(4, 4, 4, 4);
     throw std::runtime_error(
         "Unsupported dimension combination: G=" + std::to_string(num_g_dims) +

View File

@@ -232,7 +232,8 @@ struct BatchedContractionKernelArgs
    AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides
    BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides
    EGridDesc_M_N_ e_grid_desc_m_n; ///< Tensor descriptor for E[M, N] with actual strides
-    // TODO: Add D descriptors array when needed
+    std::array<EGridDesc_M_N_, NumDTensor>
+        ds_grid_desc_m_n; ///< Descriptors for D tensors (same shape as E, independent strides)
 };

 /// @brief GPU kernel for batched tensor contraction operations.
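The new ds_grid_desc_m_n member is an array of E-shaped descriptors rather than a single shared one because each D operand may be laid out differently in memory. A standalone illustration of the idea in plain C++ (not the ck_tile descriptor API):

    #include <array>
    #include <cstdint>

    // Two tensors can share the logical shape [M, N] yet use different
    // strides; a per-tensor descriptor captures that layout.
    struct Desc2D
    {
        std::array<int64_t, 2> lengths; // [M, N]
        std::array<int64_t, 2> strides; // element stride per dimension
        int64_t offset(int64_t m, int64_t n) const
        {
            return m * strides[0] + n * strides[1];
        }
    };

    // E contiguous row-major, D0 column-major, D1 a bias broadcast along M:
    // Desc2D e {{M, N}, {N, 1}};
    // Desc2D d0{{M, N}, {1, M}};
    // Desc2D d1{{M, N}, {0, 1}};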
@@ -356,15 +357,14 @@ struct BatchedContractionKernel
    ///          non-contiguous support
    /// @details This function creates tensor views from descriptors and runs GEMM pipeline,
    ///          similar to UniversalGemmKernel::RunGemm but with descriptor-based tensor views
-    CK_TILE_DEVICE static void
-    RunGemm(const ADataType* a_ptr,
-            const BDataType* b_ptr,
-            [[maybe_unused]] const std::array<const void*, NumDTensor>& ds_ptr,
-            EDataType* e_ptr,
-            void* smem_ptr,
-            const KernelArgs& kargs,
-            const index_t i_m,
-            const index_t i_n)
+    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
+                                       const BDataType* b_ptr,
+                                       const std::array<const void*, NumDTensor>& ds_ptr,
+                                       EDataType* e_ptr,
+                                       void* smem_ptr,
+                                       const KernelArgs& kargs,
+                                       const index_t i_m,
+                                       const index_t i_n)
     {
         // Create tensor views from descriptors (handles multi-dimensional strides)
         auto a_tensor_view =
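RunGemm now builds every tensor view from a descriptor, so a block's tile is addressed through the descriptor's strides rather than an assumed row-major layout. A conceptual sketch of what a tile window provides (illustrative only, not ck_tile's implementation):

    #include <cstdint>

    // A tile window pairs a strided view with a block-sized extent and an
    // origin (i_m, i_n), so each workgroup reads its MPerBlock x NPerBlock
    // tile through whatever strides the descriptor carries.
    struct TileWindow2D
    {
        const float* base;          // tensor data
        int64_t stride_m, stride_n; // strides taken from the descriptor
        int64_t origin_m, origin_n; // top-left corner of this block's tile
        float load(int64_t lm, int64_t ln) const
        {
            return base[(origin_m + lm) * stride_m + (origin_n + ln) * stride_n];
        }
    };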
@@ -401,11 +401,23 @@ struct BatchedContractionKernel
         const auto& c_block_tile = GemmPipeline{}(
             a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);

-        // Create empty D windows tuple (for NumDTensor=0 case)
-        // TODO: Create actual D windows from descriptors when NumDTensor > 0
-        auto ds_block_windows = make_tuple();
+        // Create D windows from descriptors (for each D tensor)
+        auto ds_block_windows = generate_tuple(
+            [&](auto i) {
+                using DDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
+                const DDataType* d_ptr = static_cast<const DDataType*>(ds_ptr[i]);
+                auto d_tensor_view =
+                    make_tensor_view<address_space_enum::global>(d_ptr, kargs.ds_grid_desc_m_n[i]);
+                return make_tile_window(d_tensor_view,
+                                        make_tuple(number<TilePartitioner::MPerBlock>{},
+                                                   number<TilePartitioner::NPerBlock>{}),
+                                        {i_m, i_n});
+            },
+            number<NumDTensor>{});

-        // Run Epilogue Pipeline (same as UniversalGemmKernel)
+        // Run Epilogue Pipeline with descriptor-based D windows
         EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
     }
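The lambda passed to generate_tuple receives a compile-time index (number<I>), which is what lets each D window carry its own static type even though the windows are produced by one expression. The same idiom in standard C++, for readers unfamiliar with ck_tile's generate_tuple and number (this is just the shape of the trick, not the library's definition):

    #include <cstddef>
    #include <tuple>
    #include <utility>

    template <class F, std::size_t... Is>
    auto generate_tuple_impl(F&& f, std::index_sequence<Is...>)
    {
        // f is invoked with integral_constant<size_t, I>, so each call can
        // yield a differently-typed element of the resulting tuple.
        return std::make_tuple(f(std::integral_constant<std::size_t, Is>{})...);
    }

    template <std::size_t N, class F>
    auto generate_tuple(F&& f)
    {
        return generate_tuple_impl(std::forward<F>(f), std::make_index_sequence<N>{});
    }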
@@ -526,6 +538,13 @@ struct BatchedContractionKernel
         kargs.e_grid_desc_m_n =
             DescriptorUtils::Make_E_GridDescriptor_M_N(host_args.E_dims, host_args.E_strides);

+        // Create D descriptors with their own strides (same shape as E, independent strides)
+        for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
+        {
+            kargs.ds_grid_desc_m_n[d] = DescriptorUtils::Make_E_GridDescriptor_M_N(
+                host_args.Ds_dims[d], host_args.Ds_strides[d]);
+        }
+
         // Keep simple strides for backward compatibility
         kargs.stride_A = kargs.K_total;
         kargs.stride_B = kargs.K_total;
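Host-side, each D now gets a descriptor built from its own dims and strides. The point of carrying strides into the kernel is that a flattened coordinate can be mapped back onto a non-contiguous multi-dimensional layout. An illustrative helper showing that mapping (not the DescriptorUtils API):

    #include <cstdint>
    #include <vector>

    // Unflatten a linear index against per-dimension lengths, then dot the
    // per-dimension indices with the tensor's real strides.
    inline int64_t offset_from_flat(int64_t flat,
                                    const std::vector<int64_t>& lengths,
                                    const std::vector<int64_t>& strides)
    {
        int64_t off = 0;
        for(int i = static_cast<int>(lengths.size()) - 1; i >= 0; --i)
        {
            off += (flat % lengths[i]) * strides[i];
            flat /= lengths[i];
        }
        return off;
    }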
@@ -582,43 +601,39 @@ struct BatchedContractionKernel
         // Allocate shared memory
         __shared__ char smem_ptr[GetSmemSize()];

-        if constexpr(NumDTensor > 0)
-        {
-            // OLD PATH: Use UniversalGemmKernel for cases with D tensors
-            typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
-                                                                {b_ptr},
-                                                                ds_batch_ptr,
-                                                                e_ptr,
-                                                                kargs.M_total,
-                                                                kargs.N_total,
-                                                                kargs.K_total,
-                                                                {kargs.stride_A},
-                                                                {kargs.stride_B},
-                                                                kargs.stride_Ds,
-                                                                kargs.stride_E,
-                                                                kargs.k_batch};
-
-            const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
-                                                                                      i_splitk);
-
-            const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
-            const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
-
-            UniversalGemmKernel::RunGemm({a_ptr_final},
-                                         {b_ptr_final},
-                                         ds_batch_ptr,
-                                         e_ptr,
-                                         smem_ptr,
-                                         gemm_kargs,
-                                         splitk_batch_offset,
-                                         i_m,
-                                         i_n);
-        }
-        else
-        {
-            // NEW PATH: Use descriptor-based RunGemm for num_d=0
-            RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, i_m, i_n);
-        }
+#if 0 // OLD PATH: UniversalGemmKernel (kept for reference, can be deleted after full validation)
+        typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
+                                                            {b_ptr},
+                                                            ds_batch_ptr,
+                                                            e_ptr,
+                                                            kargs.M_total,
+                                                            kargs.N_total,
+                                                            kargs.K_total,
+                                                            {kargs.stride_A},
+                                                            {kargs.stride_B},
+                                                            kargs.stride_Ds,
+                                                            kargs.stride_E,
+                                                            kargs.k_batch};
+
+        const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
+                                                                                  i_splitk);
+
+        const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
+        const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
+
+        UniversalGemmKernel::RunGemm({a_ptr_final},
+                                     {b_ptr_final},
+                                     ds_batch_ptr,
+                                     e_ptr,
+                                     smem_ptr,
+                                     gemm_kargs,
+                                     splitk_batch_offset,
+                                     i_m,
+                                     i_n);
+#else // NEW PATH: Descriptor-based RunGemm
+        // custom descriptor-based RunGemm with full multi-dimensional stride support
+        RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, i_m, i_n);
+#endif
     }
 };
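The retired path's SplitKBatchOffset splits the K loop across k_batch workgroups, each starting its A/B pointers part-way into K; the descriptor-based call at the end of the hunk passes the raw pointers and kargs straight through instead. A conceptual sketch of those offsets (assumed semantics, not the UniversalGemmKernel definition; assumes K divides evenly for brevity):

    #include <cstdint>

    struct SplitKOffsetSketch
    {
        int64_t a_offset, b_offset; // element offsets added to a_ptr / b_ptr
        SplitKOffsetSketch(int64_t K, int64_t k_batch, int64_t k_id,
                           int64_t stride_ak, int64_t stride_bk)
        {
            // each of k_batch workgroups reduces its own K-slice
            const int64_t k_per_split = K / k_batch;
            a_offset = k_id * k_per_split * stride_ak;
            b_offset = k_id * k_per_split * stride_bk;
        }
    };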