mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK TILE] GEMM and Batched GEMM SplitK support (#1724)
* [CK TILE] Add split K support in GEMM * Updates * Fixes * rebase * fix * Fix * fixes * support for batched gemm
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -35,21 +35,39 @@ struct Default2DEpilogue
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; }
|
||||
|
||||
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
|
||||
// how do we fix this ?
|
||||
template <typename ODramWindowTmp, typename OAccTile>
|
||||
template <typename ODramWindowTmp,
|
||||
typename OAccTile,
|
||||
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile)
|
||||
{
|
||||
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
buffer_store_fence();
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user