mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
CK-Tile Grouped GEMM refactor and post PR fixes (#1756)
* Grouped gemm simple code refactor
* Offset invoker
* Invoke generic Run, and replace name of parrtitioner variable
* Tests fix type
* Removed namespaces
* Add template param to avoid implicit cast
* Remove generic function
* Constant value
* underline enum to int16_t
* Generalize partitioner function
* Remove whitespaces
* Rename function
* Using support
* Clang-format
* Clang-format
* Fn-partitioner description fn
* Typo
* Typo 2
* Better description
* Better description
* Refactor after review
* Use ctr instead of set fn
* Inovke ctr and typo
* Comments
* Remove unnecessary comment
* Review, remove modulo
[ROCm/composable_kernel commit: 3c93d3c444]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
@@ -49,7 +49,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
|
||||
using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
CShuffleEpilogue,
|
||||
@@ -61,8 +61,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
kOutputRank,
|
||||
1,
|
||||
0,
|
||||
TilePartitioner::kM,
|
||||
TilePartitioner::kN>>,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock>>,
|
||||
ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
|
||||
using TilePartitioner = ck_tile::GemmTile2DPartitioner<GemmShape>;
|
||||
|
||||
using GemmEpilogue = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
@@ -51,7 +51,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
|
||||
using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
CShuffleEpilogue,
|
||||
@@ -63,8 +63,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
kOutputRank,
|
||||
1,
|
||||
0,
|
||||
TilePartitioner::kM,
|
||||
TilePartitioner::kN>>,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock>>,
|
||||
ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_gemm.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -102,7 +101,7 @@ using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
|
||||
GemmEpilogue<CLayout>>;
|
||||
}; // namespace
|
||||
|
||||
std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return ::Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
|
||||
}
|
||||
|
||||
@@ -52,8 +52,8 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs);
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs);
|
||||
|
||||
float grouped_gemm_calc(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* p_workspace_);
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* p_workspace_);
|
||||
|
||||
@@ -31,7 +31,7 @@ float invoke_gemm(int n_warmup,
|
||||
{
|
||||
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(GetWorkspaceSize(args));
|
||||
gemm_workspace.Realloc(get_workspace_size(args));
|
||||
|
||||
float ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
args,
|
||||
@@ -128,16 +128,16 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
|
||||
stride_As[i] = f_get_default_stride(M, N, stride_As[i], a_layout);
|
||||
stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], b_layout);
|
||||
stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
|
||||
stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], a_layout);
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], b_layout);
|
||||
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], CLayout{});
|
||||
|
||||
a_m_k_tensors.push_back(
|
||||
ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)));
|
||||
b_k_n_tensors.push_back(
|
||||
ck_tile::HostTensor<BDataType>(f_host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_As[i], a_layout)));
|
||||
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
|
||||
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
|
||||
f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
|
||||
|
||||
std::cout << "gemm[" << i << "]"
|
||||
<< " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
@@ -178,7 +178,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
|
||||
ck_tile::host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
|
||||
c_m_n_host_ref.SetZero();
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename TLayout>
|
||||
constexpr auto
|
||||
f_host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
|
||||
{
|
||||
using namespace ck_tile::literals;
|
||||
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
}
|
||||
template <typename TLayout>
|
||||
constexpr auto
|
||||
f_get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
|
||||
{
|
||||
if(stride == 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
}
|
||||
Reference in New Issue
Block a user