mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Add multi-dimensional non-contiguous stride support to batched contraction, num_d = 0
This commit is contained in:
@@ -352,6 +352,63 @@ struct BatchedContractionKernel
|
||||
TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
|
||||
}
|
||||
|
||||
/// @brief Custom RunGemm implementation using tensor descriptors for multi-dimensional
|
||||
/// non-contiguous support
|
||||
/// @details This function creates tensor views from descriptors and runs GEMM pipeline,
|
||||
/// similar to UniversalGemmKernel::RunGemm but with descriptor-based tensor views
|
||||
CK_TILE_DEVICE static void
|
||||
RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
[[maybe_unused]] const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Create tensor views from descriptors (handles multi-dimensional strides)
|
||||
auto a_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
|
||||
auto b_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
|
||||
auto e_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);
|
||||
|
||||
// Create tile windows for this block's work
|
||||
auto a_block_window = make_tile_window(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
|
||||
auto b_block_window = make_tile_window(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
|
||||
auto e_block_window = make_tile_window(
|
||||
e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
// Calculate number of K loops
|
||||
const index_t num_loop =
|
||||
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.K_total));
|
||||
|
||||
// Run GEMM Pipeline (same as UniversalGemmKernel, but with descriptor-based windows)
|
||||
using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}(
|
||||
a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);
|
||||
|
||||
// Create empty D windows tuple (for NumDTensor=0 case)
|
||||
// TODO: Create actual D windows from descriptors when NumDTensor > 0
|
||||
auto ds_block_windows = make_tuple();
|
||||
|
||||
// Run Epilogue Pipeline (same as UniversalGemmKernel)
|
||||
EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr KernelArgs
|
||||
MakeKernelArgs(const BatchedContractionHostArgs<NumDTensor>& host_args)
|
||||
{
|
||||
@@ -503,8 +560,8 @@ struct BatchedContractionKernel
|
||||
const ck_tile::index_t i_n =
|
||||
__builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
[[maybe_unused]] const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
|
||||
// Calculate batch offsets for each tensor
|
||||
const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A;
|
||||
@@ -522,35 +579,46 @@ struct BatchedContractionKernel
|
||||
ds_batch_ptr[i] = static_cast<const DDataType*>(kargs.ds_ptr[i]) + batch_offset_D;
|
||||
});
|
||||
|
||||
typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
|
||||
{b_ptr},
|
||||
ds_batch_ptr,
|
||||
e_ptr,
|
||||
kargs.M_total,
|
||||
kargs.N_total,
|
||||
kargs.K_total,
|
||||
{kargs.stride_A},
|
||||
{kargs.stride_B},
|
||||
kargs.stride_Ds,
|
||||
kargs.stride_E,
|
||||
kargs.k_batch};
|
||||
|
||||
const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
|
||||
i_splitk);
|
||||
|
||||
const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
|
||||
const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
|
||||
// Allocate shared memory
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
UniversalGemmKernel::RunGemm({a_ptr_final},
|
||||
{b_ptr_final},
|
||||
ds_batch_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
gemm_kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
// OLD PATH: Use UniversalGemmKernel for cases with D tensors
|
||||
typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
|
||||
{b_ptr},
|
||||
ds_batch_ptr,
|
||||
e_ptr,
|
||||
kargs.M_total,
|
||||
kargs.N_total,
|
||||
kargs.K_total,
|
||||
{kargs.stride_A},
|
||||
{kargs.stride_B},
|
||||
kargs.stride_Ds,
|
||||
kargs.stride_E,
|
||||
kargs.k_batch};
|
||||
|
||||
const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
|
||||
i_splitk);
|
||||
|
||||
const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
|
||||
const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
|
||||
|
||||
UniversalGemmKernel::RunGemm({a_ptr_final},
|
||||
{b_ptr_final},
|
||||
ds_batch_ptr,
|
||||
e_ptr,
|
||||
smem_ptr,
|
||||
gemm_kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
// NEW PATH: Use descriptor-based RunGemm for num_d=0
|
||||
RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, i_m, i_n);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user