mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Add complete multi-dimensional stride support via descriptors
This commit is contained in:
@@ -219,9 +219,8 @@ float batched_contraction(const ck_tile::BatchedContractionHostArgs<DsDataType::
|
|||||||
HANDLE_CASE(2, 1, 1, 1);
|
HANDLE_CASE(2, 1, 1, 1);
|
||||||
HANDLE_CASE(2, 2, 2, 1);
|
HANDLE_CASE(2, 2, 2, 1);
|
||||||
HANDLE_CASE(1, 2, 1, 1);
|
HANDLE_CASE(1, 2, 1, 1);
|
||||||
HANDLE_CASE(1, 1, 1, 2);
|
HANDLE_CASE(2, 1, 1, 1);
|
||||||
HANDLE_CASE(2, 2, 2, 2);
|
HANDLE_CASE(2, 2, 2, 2);
|
||||||
HANDLE_CASE(4, 4, 4, 4);
|
|
||||||
|
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"Unsupported dimension combination: G=" + std::to_string(num_g_dims) +
|
"Unsupported dimension combination: G=" + std::to_string(num_g_dims) +
|
||||||
|
|||||||
@@ -232,7 +232,8 @@ struct BatchedContractionKernelArgs
|
|||||||
AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides
|
AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides
|
||||||
BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides
|
BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides
|
||||||
EGridDesc_M_N_ e_grid_desc_m_n; ///< Tensor descriptor for E[M, N] with actual strides
|
EGridDesc_M_N_ e_grid_desc_m_n; ///< Tensor descriptor for E[M, N] with actual strides
|
||||||
// TODO: Add D descriptors array when needed
|
std::array<EGridDesc_M_N_, NumDTensor>
|
||||||
|
ds_grid_desc_m_n; ///< Descriptors for D tensors (same shape as E, independent strides)
|
||||||
};
|
};
|
||||||
|
|
||||||
/// @brief GPU kernel for batched tensor contraction operations.
|
/// @brief GPU kernel for batched tensor contraction operations.
|
||||||
@@ -356,15 +357,14 @@ struct BatchedContractionKernel
|
|||||||
/// non-contiguous support
|
/// non-contiguous support
|
||||||
/// @details This function creates tensor views from descriptors and runs GEMM pipeline,
|
/// @details This function creates tensor views from descriptors and runs GEMM pipeline,
|
||||||
/// similar to UniversalGemmKernel::RunGemm but with descriptor-based tensor views
|
/// similar to UniversalGemmKernel::RunGemm but with descriptor-based tensor views
|
||||||
CK_TILE_DEVICE static void
|
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||||
RunGemm(const ADataType* a_ptr,
|
const BDataType* b_ptr,
|
||||||
const BDataType* b_ptr,
|
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||||
[[maybe_unused]] const std::array<const void*, NumDTensor>& ds_ptr,
|
EDataType* e_ptr,
|
||||||
EDataType* e_ptr,
|
void* smem_ptr,
|
||||||
void* smem_ptr,
|
const KernelArgs& kargs,
|
||||||
const KernelArgs& kargs,
|
const index_t i_m,
|
||||||
const index_t i_m,
|
const index_t i_n)
|
||||||
const index_t i_n)
|
|
||||||
{
|
{
|
||||||
// Create tensor views from descriptors (handles multi-dimensional strides)
|
// Create tensor views from descriptors (handles multi-dimensional strides)
|
||||||
auto a_tensor_view =
|
auto a_tensor_view =
|
||||||
@@ -401,11 +401,23 @@ struct BatchedContractionKernel
|
|||||||
const auto& c_block_tile = GemmPipeline{}(
|
const auto& c_block_tile = GemmPipeline{}(
|
||||||
a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);
|
a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);
|
||||||
|
|
||||||
// Create empty D windows tuple (for NumDTensor=0 case)
|
// Create D windows from descriptors (for each D tensor)
|
||||||
// TODO: Create actual D windows from descriptors when NumDTensor > 0
|
auto ds_block_windows = generate_tuple(
|
||||||
auto ds_block_windows = make_tuple();
|
[&](auto i) {
|
||||||
|
using DDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||||
|
const DDataType* d_ptr = static_cast<const DDataType*>(ds_ptr[i]);
|
||||||
|
|
||||||
// Run Epilogue Pipeline (same as UniversalGemmKernel)
|
auto d_tensor_view =
|
||||||
|
make_tensor_view<address_space_enum::global>(d_ptr, kargs.ds_grid_desc_m_n[i]);
|
||||||
|
|
||||||
|
return make_tile_window(d_tensor_view,
|
||||||
|
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||||
|
number<TilePartitioner::NPerBlock>{}),
|
||||||
|
{i_m, i_n});
|
||||||
|
},
|
||||||
|
number<NumDTensor>{});
|
||||||
|
|
||||||
|
// Run Epilogue Pipeline with descriptor-based D windows
|
||||||
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
|
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -526,6 +538,13 @@ struct BatchedContractionKernel
|
|||||||
kargs.e_grid_desc_m_n =
|
kargs.e_grid_desc_m_n =
|
||||||
DescriptorUtils::Make_E_GridDescriptor_M_N(host_args.E_dims, host_args.E_strides);
|
DescriptorUtils::Make_E_GridDescriptor_M_N(host_args.E_dims, host_args.E_strides);
|
||||||
|
|
||||||
|
// Create D descriptors with their own strides (same shape as E, independent strides)
|
||||||
|
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
|
||||||
|
{
|
||||||
|
kargs.ds_grid_desc_m_n[d] = DescriptorUtils::Make_E_GridDescriptor_M_N(
|
||||||
|
host_args.Ds_dims[d], host_args.Ds_strides[d]);
|
||||||
|
}
|
||||||
|
|
||||||
// Keep simple strides for backward compatibility
|
// Keep simple strides for backward compatibility
|
||||||
kargs.stride_A = kargs.K_total;
|
kargs.stride_A = kargs.K_total;
|
||||||
kargs.stride_B = kargs.K_total;
|
kargs.stride_B = kargs.K_total;
|
||||||
@@ -582,43 +601,39 @@ struct BatchedContractionKernel
|
|||||||
// Allocate shared memory
|
// Allocate shared memory
|
||||||
__shared__ char smem_ptr[GetSmemSize()];
|
__shared__ char smem_ptr[GetSmemSize()];
|
||||||
|
|
||||||
if constexpr(NumDTensor > 0)
|
#if 0 // OLD PATH: UniversalGemmKernel (kept for reference, can be deleted after full validation)
|
||||||
{
|
typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
|
||||||
// OLD PATH: Use UniversalGemmKernel for cases with D tensors
|
{b_ptr},
|
||||||
typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
|
ds_batch_ptr,
|
||||||
{b_ptr},
|
e_ptr,
|
||||||
ds_batch_ptr,
|
kargs.M_total,
|
||||||
e_ptr,
|
kargs.N_total,
|
||||||
kargs.M_total,
|
kargs.K_total,
|
||||||
kargs.N_total,
|
{kargs.stride_A},
|
||||||
kargs.K_total,
|
{kargs.stride_B},
|
||||||
{kargs.stride_A},
|
kargs.stride_Ds,
|
||||||
{kargs.stride_B},
|
kargs.stride_E,
|
||||||
kargs.stride_Ds,
|
kargs.k_batch};
|
||||||
kargs.stride_E,
|
|
||||||
kargs.k_batch};
|
|
||||||
|
|
||||||
const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
|
const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
|
||||||
i_splitk);
|
i_splitk);
|
||||||
|
|
||||||
const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
|
const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
|
||||||
const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
|
const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
|
||||||
|
|
||||||
UniversalGemmKernel::RunGemm({a_ptr_final},
|
UniversalGemmKernel::RunGemm({a_ptr_final},
|
||||||
{b_ptr_final},
|
{b_ptr_final},
|
||||||
ds_batch_ptr,
|
ds_batch_ptr,
|
||||||
e_ptr,
|
e_ptr,
|
||||||
smem_ptr,
|
smem_ptr,
|
||||||
gemm_kargs,
|
gemm_kargs,
|
||||||
splitk_batch_offset,
|
splitk_batch_offset,
|
||||||
i_m,
|
i_m,
|
||||||
i_n);
|
i_n);
|
||||||
}
|
#else // NEW PATH: Descriptor-based RunGemm
|
||||||
else
|
// custom descriptor-based RunGemm with full multi-dimensional stride support
|
||||||
{
|
RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, i_m, i_n);
|
||||||
// NEW PATH: Use descriptor-based RunGemm for num_d=0
|
#endif
|
||||||
RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, i_m, i_n);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user