mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
CK_TILE: Implement two-stage split-K GEMM with workspace reduction (LWPCK-2966) (#2632)
* CK_TILE: Implement two-stage split-K GEMM with reduction - Added split-K GEMM with reduction example * comment resolutions
This commit is contained in:
committed by
GitHub
parent
e5623d3825
commit
7f14772406
@@ -1,6 +1,7 @@
|
||||
# GEMM example executables. EXCLUDE_FROM_ALL keeps each target out of the
# default "all" build, so an example is compiled only when requested by name.
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
|
||||
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
|
||||
add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp)
|
||||
# Two-stage split-K example, built from gemm_splitk_two_stage_reduce.cpp.
add_executable(tile_example_gemm_reduce EXCLUDE_FROM_ALL gemm_splitk_two_stage_reduce.cpp)
|
||||
# Clear the per-example option lists (set() with no value) before options are
# appended to them further down.
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
@@ -14,3 +15,4 @@ list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-rev
|
||||
# Attach the collected option lists to the example targets. The split-K reduce
# example shares EXAMPLE_GEMM_COMPILE_OPTIONS with the basic and universal
# examples; only the weight-preshuffle example uses its own list.
target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
1009
example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp
Normal file
1009
example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -213,6 +213,23 @@ struct UniversalGemmKernel
|
||||
};
|
||||
static constexpr bool PersistentKernel = has_persistent_kernel::value;
|
||||
|
||||
// Check if TilePartitioner has GetOutputOffset method with kargs and k_id
|
||||
struct has_tile_partitioner_output_offset_impl
|
||||
{
|
||||
template <typename T, typename KernelArgs>
|
||||
using has_get_output_offset_t =
|
||||
decltype(T::GetOutputOffset(std::declval<KernelArgs>(), std::declval<index_t>()));
|
||||
|
||||
static constexpr bool value = []() {
|
||||
if constexpr(is_detected<has_get_output_offset_t, TilePartitioner>{})
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}();
|
||||
};
|
||||
static constexpr bool has_tile_partitioner_output_offset =
|
||||
has_tile_partitioner_output_offset_impl::value;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
@@ -1032,7 +1049,13 @@ struct UniversalGemmKernel
|
||||
splitk_batch_offset.bs_k_split_offset[i];
|
||||
});
|
||||
|
||||
// Calculate output offset from tile partitioner and apply to output pointer
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
if constexpr(has_tile_partitioner_output_offset)
|
||||
{
|
||||
const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, blockIdx.z);
|
||||
e_ptr += output_offset;
|
||||
}
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
@@ -1110,7 +1133,13 @@ struct UniversalGemmKernel
|
||||
splitk_batch_offset.bs_k_split_offset[i];
|
||||
});
|
||||
|
||||
// Calculate output offset from tile partitioner and apply to output pointer
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
if constexpr(has_tile_partitioner_output_offset)
|
||||
{
|
||||
const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, k_batch);
|
||||
e_ptr += output_offset;
|
||||
}
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
Reference in New Issue
Block a user