CK_TILE: Implement two-stage split-K GEMM with workspace reduction (LWPCK-2966) (#2632)

* CK_TILE: Implement two-stage split-K GEMM with reduction

- Added split-K GEMM with reduction example

* comment resolutions
This commit is contained in:
Yashvardhan Agarwal
2025-08-14 11:18:52 +03:00
committed by GitHub
parent e5623d3825
commit 7f14772406
3 changed files with 1040 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
# GEMM example executables. EXCLUDE_FROM_ALL keeps each one out of the default
# "all" build, so it is only built when its target is requested explicitly.
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp)
# Example target built from gemm_splitk_two_stage_reduce.cpp.
add_executable(tile_example_gemm_reduce EXCLUDE_FROM_ALL gemm_splitk_two_stage_reduce.cpp)
# Declare the per-example compile-option variables with no initial values.
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
@@ -14,3 +15,4 @@ list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-rev
# Apply the option lists to each example target as PRIVATE options.
target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS})
# The reduce example uses the same option list as the basic and universal GEMM examples.
target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

File diff suppressed because it is too large Load Diff

View File

@@ -213,6 +213,23 @@ struct UniversalGemmKernel
};
static constexpr bool PersistentKernel = has_persistent_kernel::value;
// Check if TilePartitioner has GetOutputOffset method with kargs and k_id
struct has_tile_partitioner_output_offset_impl
{
template <typename T, typename KernelArgs>
using has_get_output_offset_t =
decltype(T::GetOutputOffset(std::declval<KernelArgs>(), std::declval<index_t>()));
static constexpr bool value = []() {
if constexpr(is_detected<has_get_output_offset_t, TilePartitioner>{})
return true;
else
return false;
}();
};
static constexpr bool has_tile_partitioner_output_offset =
has_tile_partitioner_output_offset_impl::value;
// Constexpr instances of number<0>, number<1> and number<2>.
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
@@ -1032,7 +1049,13 @@ struct UniversalGemmKernel
splitk_batch_offset.bs_k_split_offset[i];
});
// Calculate output offset from tile partitioner and apply to output pointer
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
if constexpr(has_tile_partitioner_output_offset)
{
const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, blockIdx.z);
e_ptr += output_offset;
}
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
@@ -1110,7 +1133,13 @@ struct UniversalGemmKernel
splitk_batch_offset.bs_k_split_offset[i];
});
// Calculate output offset from tile partitioner and apply to output pointer
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
if constexpr(has_tile_partitioner_output_offset)
{
const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, k_batch);
e_ptr += output_offset;
}
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];