mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
CK_TILE: Implement two-stage split-K GEMM with workspace reduction (LWPCK-2966) (#2632)
* CK_TILE: Implement two-stage split-K GEMM with reduction - Added split-K GEMM with reduction example * comment resolutions
This commit is contained in:
committed by
GitHub
parent
e5623d3825
commit
7f14772406
@@ -213,6 +213,23 @@ struct UniversalGemmKernel
|
||||
};
|
||||
static constexpr bool PersistentKernel = has_persistent_kernel::value;
|
||||
|
||||
// Check if TilePartitioner has GetOutputOffset method with kargs and k_id
|
||||
struct has_tile_partitioner_output_offset_impl
|
||||
{
|
||||
template <typename T, typename KernelArgs>
|
||||
using has_get_output_offset_t =
|
||||
decltype(T::GetOutputOffset(std::declval<KernelArgs>(), std::declval<index_t>()));
|
||||
|
||||
static constexpr bool value = []() {
|
||||
if constexpr(is_detected<has_get_output_offset_t, TilePartitioner>{})
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}();
|
||||
};
|
||||
static constexpr bool has_tile_partitioner_output_offset =
|
||||
has_tile_partitioner_output_offset_impl::value;
|
||||
|
||||
// Compile-time index constants for 0, 1 and 2, each an instance of number<N>.
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
@@ -1032,7 +1049,13 @@ struct UniversalGemmKernel
|
||||
splitk_batch_offset.bs_k_split_offset[i];
|
||||
});
|
||||
|
||||
// Calculate output offset from tile partitioner and apply to output pointer
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
if constexpr(has_tile_partitioner_output_offset)
|
||||
{
|
||||
const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, blockIdx.z);
|
||||
e_ptr += output_offset;
|
||||
}
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
@@ -1110,7 +1133,13 @@ struct UniversalGemmKernel
|
||||
splitk_batch_offset.bs_k_split_offset[i];
|
||||
});
|
||||
|
||||
// Calculate output offset from tile partitioner and apply to output pointer
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
if constexpr(has_tile_partitioner_output_offset)
|
||||
{
|
||||
const index_t output_offset = TilePartitioner::GetOutputOffset(kargs, k_batch);
|
||||
e_ptr += output_offset;
|
||||
}
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
Reference in New Issue
Block a user