Add multi-dimensional non-contiguous stride support to batched contraction, num_d = 0

Author: Mohsen Saffari
Date:   2025-10-20 13:15:39 +00:00
parent 2ecb0bfb3e
commit b8b56d5cc6
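The existing batched contraction path lowers each tensor to a 2D view with a single leading stride (the kargs.stride_A / kargs.stride_B values handed to UniversalGemmKernel), which assumes the merged M/N/K dimensions are contiguous in memory. This change adds a descriptor-based RunGemm path, currently taken only when there are no D tensors (num_d = 0), that builds tensor views directly from multi-dimensional grid descriptors, so every original dimension keeps its own stride.

A minimal host-side sketch of how such a descriptor can be built; only make_naive_tensor_descriptor, transform_tensor_descriptor, make_merge_transform, and sequence come from ck_tile, while the helper name, shapes, and strides are illustrative and not part of this commit:

    #include <ck_tile/core.hpp>

    using ck_tile::index_t;

    // Hypothetical helper: collapse an A tensor of shape (M0, M1, K0, K1) with
    // arbitrary per-dimension strides into the 2D (M, K) descriptor the kernel
    // consumes as kargs.a_grid_desc_m_k.
    auto make_a_grid_desc_m_k(index_t M0, index_t M1, index_t K0, index_t K1,
                              index_t sM0, index_t sM1, index_t sK0, index_t sK1)
    {
        using namespace ck_tile;

        // Lengths and strides are specified independently per dimension;
        // nothing here assumes contiguity.
        const auto a_desc = make_naive_tensor_descriptor(make_tuple(M0, M1, K0, K1),
                                                         make_tuple(sM0, sM1, sK0, sK1));

        // Merge (M0, M1) -> M_total and (K0, K1) -> K_total while keeping the
        // original strides, so no physical 2D copy of the data is needed.
        return transform_tensor_descriptor(
            a_desc,
            make_tuple(make_merge_transform(make_tuple(M0, M1)),
                       make_merge_transform(make_tuple(K0, K1))),
            make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));
    }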


@@ -352,6 +352,63 @@ struct BatchedContractionKernel
             TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
     }
 
+    /// @brief Custom RunGemm implementation using tensor descriptors for multi-dimensional
+    /// non-contiguous stride support.
+    /// @details Creates tensor views from the grid descriptors and runs the GEMM pipeline,
+    /// similar to UniversalGemmKernel::RunGemm but with descriptor-based tensor views.
+    CK_TILE_DEVICE static void
+    RunGemm(const ADataType* a_ptr,
+            const BDataType* b_ptr,
+            [[maybe_unused]] const std::array<const void*, NumDTensor>& ds_ptr,
+            EDataType* e_ptr,
+            void* smem_ptr,
+            const KernelArgs& kargs,
+            const index_t i_m,
+            const index_t i_n)
+    {
+        // Create tensor views from the descriptors (handles multi-dimensional strides)
+        auto a_tensor_view =
+            make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
+        auto b_tensor_view =
+            make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
+        auto e_tensor_view =
+            make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);
+
+        // Create tile windows for this block's tile of work
+        auto a_block_window = make_tile_window(
+            a_tensor_view,
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
+            {i_m, 0});
+        auto b_block_window = make_tile_window(
+            b_tensor_view,
+            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
+            {i_n, 0});
+        auto e_block_window = make_tile_window(
+            e_tensor_view,
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+            {i_m, i_n});
+
+        // Calculate the number of K loops
+        const index_t num_loop =
+            __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.K_total));
+
+        // Run the GEMM pipeline (same as UniversalGemmKernel, but with descriptor-based windows)
+        using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
+        using BElementWise = remove_cvref_t<typename GemmPipeline::BElementWise>;
+        const auto& c_block_tile = GemmPipeline{}(
+            a_block_window, AElementWise{}, b_block_window, BElementWise{}, num_loop, smem_ptr);
+
+        // Create an empty tuple of D windows (for the NumDTensor == 0 case)
+        // TODO: Create actual D windows from descriptors when NumDTensor > 0
+        auto ds_block_windows = make_tuple();
+
+        // Run the epilogue pipeline (same as UniversalGemmKernel)
+        EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_windows, smem_ptr);
+    }
+
     CK_TILE_HOST static constexpr KernelArgs
     MakeKernelArgs(const BatchedContractionHostArgs<NumDTensor>& host_args)
     {
@@ -503,8 +560,8 @@ struct BatchedContractionKernel
         const ck_tile::index_t i_n =
             __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
 
-        const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
-        const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
+        const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y);
+        [[maybe_unused]] const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
 
         // Calculate batch offsets for each tensor
         const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A;
@@ -522,35 +579,46 @@ struct BatchedContractionKernel
             ds_batch_ptr[i] = static_cast<const DDataType*>(kargs.ds_ptr[i]) + batch_offset_D;
         });
 
-        typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
-                                                            {b_ptr},
-                                                            ds_batch_ptr,
-                                                            e_ptr,
-                                                            kargs.M_total,
-                                                            kargs.N_total,
-                                                            kargs.K_total,
-                                                            {kargs.stride_A},
-                                                            {kargs.stride_B},
-                                                            kargs.stride_Ds,
-                                                            kargs.stride_E,
-                                                            kargs.k_batch};
-        const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
-                                                                                  i_splitk);
-        const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
-        const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
 
         // Allocate shared memory
         __shared__ char smem_ptr[GetSmemSize()];
 
-        UniversalGemmKernel::RunGemm({a_ptr_final},
-                                     {b_ptr_final},
-                                     ds_batch_ptr,
-                                     e_ptr,
-                                     smem_ptr,
-                                     gemm_kargs,
-                                     splitk_batch_offset,
-                                     i_m,
-                                     i_n);
+        if constexpr(NumDTensor > 0)
+        {
+            // OLD PATH: use UniversalGemmKernel for cases with D tensors
+            typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr},
+                                                                {b_ptr},
+                                                                ds_batch_ptr,
+                                                                e_ptr,
+                                                                kargs.M_total,
+                                                                kargs.N_total,
+                                                                kargs.K_total,
+                                                                {kargs.stride_A},
+                                                                {kargs.stride_B},
+                                                                kargs.stride_Ds,
+                                                                kargs.stride_E,
+                                                                kargs.k_batch};
+            const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs,
+                                                                                      i_splitk);
+            const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0];
+            const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0];
+
+            UniversalGemmKernel::RunGemm({a_ptr_final},
+                                         {b_ptr_final},
+                                         ds_batch_ptr,
+                                         e_ptr,
+                                         smem_ptr,
+                                         gemm_kargs,
+                                         splitk_batch_offset,
+                                         i_m,
+                                         i_n);
+        }
+        else
+        {
+            // NEW PATH: use the descriptor-based RunGemm for num_d = 0
+            RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, i_m, i_n);
+        }
     }
 };
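Note on the remaining TODO: when NumDTensor > 0, the descriptor-based path would presumably build D-tensor tile windows the same way it builds the E window. A hedged sketch of what could replace `auto ds_block_windows = make_tuple();` in the new RunGemm, assuming the kernel args grow a tuple of per-D grid descriptors named ds_grid_desc_m_n, which this commit does not define:

    // Hypothetical follow-up (not part of this commit): one (MPerBlock, NPerBlock)
    // window per D tensor, mirroring the E-window construction above.
    // kargs.ds_grid_desc_m_n is an assumed field holding one descriptor per D.
    auto ds_block_windows = generate_tuple(
        [&](auto i) {
            auto d_tensor_view = make_tensor_view<address_space_enum::global>(
                static_cast<const DDataType*>(ds_ptr[i.value]),
                kargs.ds_grid_desc_m_n.at(i));
            return make_tile_window(d_tensor_view,
                                    make_tuple(number<TilePartitioner::MPerBlock>{},
                                               number<TilePartitioner::NPerBlock>{}),
                                    {i_m, i_n});
        },
        number<NumDTensor>{});

With that in place, the [[maybe_unused]] attribute on ds_ptr could also be dropped, since the pointers would feed the D windows consumed by the epilogue pipeline.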