mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Tile engine for streamk (#3157)
* [CK TILE STREAMK] Introduce initial support for tile engine in streamk GEMM.
- This commit lays the groundwork for integrating the tile engine into streamk GEMM.
It focuses on creating benchmark executables for streamk GEMM.
- Additional scripts like test_benchmark.sh and gemm_benchmark.py will be added once
the streamk implementation reaches stability.
* [CK TILE STREAMK] Enable CI to execute tile engine benchmarks for StreamK GEMM
* [CK TILE STREAMK] Refactor: Extract common utility functions.
* [CK TILE STREAMK] Revise tile engine of streamk to align with the updated implementation
* Add pre-commit
* [CK TILE STREAMK] Add 'dp_persistent' and 'reduction_strategy' in output of CK TILE STREAMK
* [CK TILE STREAMK] Fix a bug about value of 'dp_persistent' of CK TILE STREAMK
* [CK TILE STREAMK] Update Jenkinsfile
* [CK TILE Engine] Update StreamK tile engine help message
Remove default value messages as they are automatically printed
* [CK TILE Engine] Update StreamK tile engine
- Remove namespace reboot
* [CK TILE Engine] Update StreamK tile engine
- Fix merge error
[ROCm/composable_kernel commit: 30727c48fc]
This commit is contained in:
@@ -86,7 +86,7 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
|
||||
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
|
||||
|
||||
if(args.reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
ave_time_and_batch = gemm<GemmConfig,
|
||||
ADataType,
|
||||
|
||||
@@ -105,13 +105,13 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
}
|
||||
|
||||
auto reset_data_buffers = [&]() {
|
||||
if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
}
|
||||
else if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// Reset sk flags to zero before each repetition of the kernel
|
||||
workspace_data.SetZero();
|
||||
|
||||
Reference in New Issue
Block a user