mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Ck tile grouped GEMM example (#1713)
* Ck-tile, impl. grouped gemm * Workspace is allocated by user, and is passed to the function * Prepare test to new api design * Unify GemTransKernelArgs, removing N0 param * Add 1 to dim3 in paritioner * Typo: gem - > gemm --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
2
example/ck_tile/17_grouped_gemm/CMakeLists.txt
Normal file
2
example/ck_tile/17_grouped_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
|
||||
|
||||
25
example/ck_tile/17_grouped_gemm/README.md
Normal file
25
example/ck_tile/17_grouped_gemm/README.md
Normal file
@@ -0,0 +1,25 @@
|
||||
# Grouped CShuffle GEMM
|
||||
|
||||
This folder contains example for Grouped GEMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile GEMM, but creates the placeholders for the future support on different GEMM pipeline and different GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
# The basic pipeline method on the gemm calculation
|
||||
make tile_example_grouped_gemm -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_grouped_gemm`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-a_layout Tensor A layout (default:R)
|
||||
-b_layout Tensor B layout (default:R)
|
||||
-c_layout Tensor C layout (default:R)
|
||||
-v 0. No validation, 1. Validation on CPU
|
||||
-warmup number of iterations before benchmark the kernel (default:10)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
```
|
||||
151
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
Normal file
151
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
Normal file
@@ -0,0 +1,151 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_gemm.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
struct GroupedGemmKernelParam
|
||||
{
|
||||
static const bool kPadM = false;
|
||||
static const bool kPadN = false;
|
||||
static const bool kPadK = false;
|
||||
static const bool kTilePermute = false;
|
||||
|
||||
static const ck_tile::index_t kOutputRank = 2;
|
||||
|
||||
static const int kBlockPerCu = 1;
|
||||
static const ck_tile::index_t M_Tile = 128;
|
||||
static const ck_tile::index_t N_Tile = 128;
|
||||
static const ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static const ck_tile::index_t M_Warp = 2;
|
||||
static const ck_tile::index_t N_Warp = 2;
|
||||
static const ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 32;
|
||||
static const ck_tile::index_t N_Warp_Tile = 32;
|
||||
static const ck_tile::index_t K_Warp_Tile = 8;
|
||||
};
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemmKernelParam::M_Tile,
|
||||
GroupedGemmKernelParam::N_Tile,
|
||||
GroupedGemmKernelParam::K_Tile>,
|
||||
ck_tile::sequence<GroupedGemmKernelParam::M_Warp,
|
||||
GroupedGemmKernelParam::N_Warp,
|
||||
GroupedGemmKernelParam::K_Warp>,
|
||||
ck_tile::sequence<GroupedGemmKernelParam::M_Warp_Tile,
|
||||
GroupedGemmKernelParam::N_Warp_Tile,
|
||||
GroupedGemmKernelParam::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
template <typename CLayout>
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
GroupedGemmKernelParam::kPadM,
|
||||
GroupedGemmKernelParam::kPadN,
|
||||
GroupedGemmKernelParam::kTilePermute,
|
||||
GroupedGemmKernelParam::kOutputRank,
|
||||
1,
|
||||
0,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock>>,
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
GroupedGemmKernelParam::kPadM,
|
||||
GroupedGemmKernelParam::kPadN>>>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemmKernelParam::kPadM,
|
||||
GroupedGemmKernelParam::kPadN,
|
||||
GroupedGemmKernelParam::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits<ALayout, BLayout, CLayout>>;
|
||||
|
||||
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using CodegenGemmPipeline =
|
||||
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>,
|
||||
CodegenGemmPolicy>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
|
||||
CodegenGemmPipeline<ALayout, BLayout, CLayout>,
|
||||
GemmEpilogue<CLayout>>;
|
||||
}; // namespace
|
||||
|
||||
std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return ::Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* p_workspace_)
|
||||
{
|
||||
using GroupedGemmKernel = ::Kernel<ALayout, BLayout, CLayout>;
|
||||
|
||||
auto arguments = GroupedGemmKernel::MakeKargs(gemm_descs);
|
||||
|
||||
const dim3 grids = GroupedGemmKernel::GridSize(gemm_descs);
|
||||
constexpr dim3 blocks = GroupedGemmKernel::BlockSize();
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpyWithStream(
|
||||
p_workspace_,
|
||||
arguments.data(),
|
||||
arguments.size() * sizeof(typename GroupedGemmKernel::GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GroupedGemmKernelParam::kBlockPerCu>(
|
||||
GroupedGemmKernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
|
||||
gemm_descs.size()));
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
53
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
Normal file
53
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
Normal file
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmBasicTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using CDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
};
|
||||
|
||||
using Types = GemmBasicTypeConfig<ck_tile::half_t>;
|
||||
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = Types::ADataType;
|
||||
using BDataType = Types::BDataType;
|
||||
using AccDataType = Types::AccDataType;
|
||||
using CDataType = Types::CDataType;
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "R", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("group_count", "16", "group count");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs);
|
||||
|
||||
float grouped_gemm_calc(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* p_workspace_);
|
||||
191
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
Normal file
191
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
Normal file
@@ -0,0 +1,191 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float invoke_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
int group_count,
|
||||
const std::vector<grouped_gemm_kargs>& args)
|
||||
{
|
||||
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(GetWorkspaceSize(args));
|
||||
|
||||
float ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
|
||||
std::string op_name{"Grouped Gemm"};
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(int j = 0; j < group_count; ++j)
|
||||
{
|
||||
flop += std::size_t(2) * args[j].M * args[j].N * args[j].K;
|
||||
|
||||
num_btype += sizeof(ADataType) * args[j].M * args[j].K +
|
||||
sizeof(BDataType) * args[j].K * args[j].N +
|
||||
sizeof(CDataType) * args[j].M * args[j].N;
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
int run_grouped_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
};
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
|
||||
std::vector<ck_tile::index_t> Ms;
|
||||
std::vector<ck_tile::index_t> Ns;
|
||||
std::vector<ck_tile::index_t> Ks;
|
||||
std::vector<ck_tile::index_t> stride_As;
|
||||
std::vector<ck_tile::index_t> stride_Bs;
|
||||
std::vector<ck_tile::index_t> stride_Cs;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(128 + 128 * i);
|
||||
Ks.push_back(128 + 64 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
stride_Cs.push_back(Ns[i]);
|
||||
}
|
||||
|
||||
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
|
||||
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
|
||||
|
||||
a_m_k_tensors.reserve(group_count);
|
||||
b_k_n_tensors.reserve(group_count);
|
||||
c_m_n_tensors.reserve(group_count);
|
||||
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
|
||||
|
||||
a_m_k_dev_buf.reserve(group_count);
|
||||
b_k_n_dev_buf.reserve(group_count);
|
||||
c_m_n_dev_buf.reserve(group_count);
|
||||
|
||||
std::vector<grouped_gemm_kargs> gemm_descs;
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
const ck_tile::index_t M = Ms[i];
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
|
||||
stride_As[i] = f_get_default_stride(M, N, stride_As[i], a_layout);
|
||||
stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], b_layout);
|
||||
stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
|
||||
|
||||
a_m_k_tensors.push_back(
|
||||
ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)));
|
||||
b_k_n_tensors.push_back(
|
||||
ck_tile::HostTensor<BDataType>(f_host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
|
||||
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
|
||||
f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
|
||||
|
||||
std::cout << "gemm[" << i << "]"
|
||||
<< " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_tensors[i]);
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
b_k_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
c_m_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
c_m_n_dev_buf[i]->SetZero();
|
||||
c_m_n_tensors[i].SetZero();
|
||||
|
||||
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
}
|
||||
|
||||
invoke_gemm<ALayout, BLayout, CLayout>(warmup, repeat, group_count, gemm_descs);
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
|
||||
}
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("validate"))
|
||||
{
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
|
||||
c_m_n_host_ref.SetZero();
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
}
|
||||
38
example/ck_tile/17_grouped_gemm/utils.hpp
Normal file
38
example/ck_tile/17_grouped_gemm/utils.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename TLayout>
|
||||
constexpr auto
|
||||
f_host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
|
||||
{
|
||||
using namespace ck_tile::literals;
|
||||
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
}
|
||||
template <typename TLayout>
|
||||
constexpr auto
|
||||
f_get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
|
||||
{
|
||||
if(stride == 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
}
|
||||
@@ -16,3 +16,4 @@ add_subdirectory(13_moe_sorting)
|
||||
add_subdirectory(14_moe_smoothquant)
|
||||
add_subdirectory(15_fused_moe)
|
||||
add_subdirectory(16_batched_gemm)
|
||||
add_subdirectory(17_grouped_gemm)
|
||||
|
||||
37
include/ck_tile/core/utility/amd_address_space.hpp
Normal file
37
include/ck_tile/core/utility/amd_address_space.hpp
Normal file
@@ -0,0 +1,37 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
// Address Space for AMDGCN
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#address-space
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
|
||||
|
||||
template <typename T>
|
||||
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
|
||||
{
|
||||
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
|
||||
// only c-style pointer cast seems be able to be compiled
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
return (T*)p; // NOLINT(old-style-cast)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
|
||||
{
|
||||
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
|
||||
// only c-style pointer cast seems be able to be compiled
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -25,6 +25,7 @@
|
||||
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
||||
|
||||
@@ -35,4 +35,40 @@ struct GemmTilePartitioner
|
||||
return make_tuple(iM, iN);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BlockGemmShape_>
|
||||
struct GemmTile1DPartitioner
|
||||
{
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N)
|
||||
{
|
||||
index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
|
||||
index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
|
||||
return dim3(GridDimX * GridDimY, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N)
|
||||
{
|
||||
return integer_divide_ceil(N, NPerBlock);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
|
||||
{
|
||||
return integer_divide_ceil(K, KPerBlock);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(index_t blockOffset, index_t NBlockSize)
|
||||
{
|
||||
index_t iM = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) /
|
||||
GetNBlock(NBlockSize) * MPerBlock);
|
||||
index_t iN = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) %
|
||||
GetNBlock(NBlockSize) * NPerBlock);
|
||||
return make_tuple(iM, iN);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
310
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
Normal file
310
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
Normal file
@@ -0,0 +1,310 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/literals.hpp"
|
||||
#include "ck_tile/core/utility/amd_address_space.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct GroupedGemmHostArgs
|
||||
{
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
void* c_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct GroupedGemmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
struct GemmTransKernelArg
|
||||
{
|
||||
GroupedGemmHostArgs group_karg;
|
||||
ck_tile::index_t block_start;
|
||||
ck_tile::index_t block_end;
|
||||
|
||||
GemmTransKernelArg() = default;
|
||||
GemmTransKernelArg(GroupedGemmHostArgs&& karg, index_t bl_start, index_t bl_end)
|
||||
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
__host__ static size_t GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(GemmTransKernelArg);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
using Hargs = GroupedGemmHostArgs;
|
||||
|
||||
__host__ static constexpr auto GridSize(const std::vector<Hargs>& gemm_descs)
|
||||
{
|
||||
index_t grid_size = 0;
|
||||
for(const auto& it_desc : gemm_descs)
|
||||
{
|
||||
const auto dim3 = TilePartitioner::GridSize(it_desc.M, it_desc.N);
|
||||
grid_size += dim3.x * dim3.y * 1;
|
||||
}
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto MakeKargs(const std::vector<Hargs>& gemm_descs)
|
||||
{
|
||||
std::vector<GemmTransKernelArg> gemm_kernel_args_;
|
||||
index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
|
||||
index_t grid_size = 0;
|
||||
gemm_kernel_args_.reserve(group_count);
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
|
||||
{
|
||||
const index_t M = gemm_descs[i].M;
|
||||
const index_t N = gemm_descs[i].N;
|
||||
const index_t K = gemm_descs[i].K;
|
||||
|
||||
if(M == 0 || N == 0 || K == 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
const index_t stride_a = gemm_descs[i].stride_A;
|
||||
const index_t stride_b = gemm_descs[i].stride_B;
|
||||
const index_t stride_c = gemm_descs[i].stride_C;
|
||||
|
||||
const auto dim3 = TilePartitioner::GridSize(M, N);
|
||||
const index_t grid_size_grp = dim3.x * 1 * 1;
|
||||
|
||||
const index_t block_start = grid_size;
|
||||
const index_t block_end = grid_size + grid_size_grp;
|
||||
|
||||
grid_size += grid_size_grp;
|
||||
|
||||
auto karg = GroupedGemmHostArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
|
||||
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
|
||||
type_convert<CDataType*>(gemm_descs[i].c_ptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c};
|
||||
|
||||
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
|
||||
}
|
||||
|
||||
return gemm_kernel_args_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void Run(const Hargs& kargs, const index_t block_start) const
|
||||
{
|
||||
const auto [i_m, i_n] = TilePartitioner{}(block_start, kargs.N);
|
||||
// options
|
||||
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
|
||||
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
|
||||
// Convert pointers to tensor views
|
||||
auto a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_start,
|
||||
make_tuple(kargs.M, kargs.K),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<GemmPipeline::VectorSizeA>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_start,
|
||||
make_tuple(kargs.M, kargs.K),
|
||||
make_tuple(1, kargs.stride_A),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto b_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_start,
|
||||
make_tuple(kargs.N, kargs.K),
|
||||
make_tuple(1, kargs.stride_B),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_start,
|
||||
make_tuple(kargs.N, kargs.K),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::VectorSizeB>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto a_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
// clang-format on
|
||||
|
||||
auto a_block_window = make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
|
||||
auto b_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadN, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto b_block_window = make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
|
||||
|
||||
// Run GEMM cooperatively by whole wokrgroup.
|
||||
auto c_block_tile =
|
||||
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
|
||||
|
||||
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
|
||||
auto c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
c_start,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<GemmPipeline::VectorSizeC>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
c_start,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto c_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
auto CBlockWindow_pad = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
int group_count) const
|
||||
{
|
||||
const index_t block_id = ck_tile::get_block_1d_id();
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
|
||||
cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
|
||||
index_t left = 0;
|
||||
index_t right = group_count;
|
||||
index_t group_id = index_t((left + right) / 2);
|
||||
|
||||
while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
|
||||
block_id < gemm_desc_ptr[group_id].block_end)) &&
|
||||
left <= right)
|
||||
{
|
||||
if(block_id < gemm_desc_ptr[group_id].block_start)
|
||||
{
|
||||
right = group_id;
|
||||
}
|
||||
else
|
||||
{
|
||||
left = group_id;
|
||||
}
|
||||
group_id = index_t((left + right) / 2);
|
||||
}
|
||||
|
||||
Run(gemm_desc_ptr[group_id].group_karg, gemm_desc_ptr[group_id].block_start);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,3 +1,4 @@
|
||||
add_subdirectory(image_to_column)
|
||||
add_subdirectory(gemm)
|
||||
add_subdirectory(batched_gemm)
|
||||
add_subdirectory(grouped_gemm)
|
||||
|
||||
4
test/ck_tile/grouped_gemm/CMakeLists.txt
Normal file
4
test/ck_tile/grouped_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm test_grouped_gemm.cpp)
|
||||
endif()
|
||||
29
test/ck_tile/grouped_gemm/test_grouped_gemm.cpp
Normal file
29
test/ck_tile/grouped_gemm/test_grouped_gemm.cpp
Normal file
@@ -0,0 +1,29 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
|
||||
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
|
||||
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemm, KernelTypes);
|
||||
|
||||
#include "test_grouped_gemm_ut_cases.inc"
|
||||
25
test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc
Normal file
25
test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc
Normal file
@@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileGroupedGemm, Basic)
|
||||
{
|
||||
const int group_count = 16;
|
||||
std::vector<int> Ms;
|
||||
std::vector<int> Ns;
|
||||
std::vector<int> Ks;
|
||||
std::vector<int> stride_As;
|
||||
std::vector<int> stride_Bs;
|
||||
std::vector<int> stride_Cs;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(128 + 128 * i);
|
||||
Ks.push_back(128 + 64 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
stride_Cs.push_back(Ns[i]);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, group_count);
|
||||
}
|
||||
282
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
Normal file
282
test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
Normal file
@@ -0,0 +1,282 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGroupedGemm : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
|
||||
struct GroupedGemKernelParam
|
||||
{
|
||||
static const bool kPadM = false;
|
||||
static const bool kPadN = false;
|
||||
static const bool kPadK = false;
|
||||
static const bool kTilePermute = false;
|
||||
|
||||
static const ck_tile::index_t kOutputRank = 2;
|
||||
|
||||
static const int kBlockPerCu = 1;
|
||||
static const ck_tile::index_t M_Tile = 128;
|
||||
static const ck_tile::index_t N_Tile = 128;
|
||||
static const ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static const ck_tile::index_t M_Warp = 2;
|
||||
static const ck_tile::index_t N_Warp = 2;
|
||||
static const ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 32;
|
||||
static const ck_tile::index_t N_Warp_Tile = 32;
|
||||
static const ck_tile::index_t K_Warp_Tile = 8;
|
||||
};
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
|
||||
GroupedGemKernelParam::N_Tile,
|
||||
GroupedGemKernelParam::K_Tile>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::K_Warp>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
template <typename CLayout>
|
||||
using GemmEpilogue =
|
||||
std::conditional_t<std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kTilePermute,
|
||||
GroupedGemKernelParam::kOutputRank,
|
||||
1,
|
||||
0,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock>>,
|
||||
ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN>>>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits<ALayout, BLayout, CLayout>>;
|
||||
|
||||
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using CodegenGemmPipeline =
|
||||
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>,
|
||||
CodegenGemmPolicy>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
|
||||
CodegenGemmPipeline<ALayout, BLayout, CLayout>,
|
||||
GemmEpilogue<CLayout>>;
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
|
||||
std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* p_workspace_)
|
||||
{
|
||||
using GroupedGemmKernel = Kernel<ALayout, BLayout, CLayout>;
|
||||
|
||||
auto arguments = GroupedGemmKernel::MakeKargs(gemm_descs);
|
||||
|
||||
const dim3 grids = GroupedGemmKernel::GridSize(gemm_descs);
|
||||
constexpr dim3 blocks = GroupedGemmKernel::BlockSize();
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpyWithStream(
|
||||
p_workspace_,
|
||||
arguments.data(),
|
||||
arguments.size() * sizeof(typename GroupedGemmKernel::GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GroupedGemKernelParam::kBlockPerCu>(
|
||||
GroupedGemmKernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
|
||||
gemm_descs.size()));
|
||||
}
|
||||
|
||||
public:
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
std::vector<int>& stride_As,
|
||||
std::vector<int>& stride_Bs,
|
||||
std::vector<int>& stride_Cs,
|
||||
const int group_count = 16)
|
||||
{
|
||||
using namespace ck_tile::literals;
|
||||
auto f_host_tensor_descriptor = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout),
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
|
||||
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
|
||||
|
||||
a_m_k_tensors.reserve(group_count);
|
||||
b_k_n_tensors.reserve(group_count);
|
||||
c_m_n_tensors.reserve(group_count);
|
||||
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
|
||||
|
||||
a_m_k_dev_buf.reserve(group_count);
|
||||
b_k_n_dev_buf.reserve(group_count);
|
||||
c_m_n_dev_buf.reserve(group_count);
|
||||
|
||||
std::vector<grouped_gemm_kargs> gemm_descs;
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
const ck_tile::index_t M = Ms[i];
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
|
||||
stride_As[i] = f_get_default_stride(M, N, stride_As[i], ALayout{});
|
||||
stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], BLayout{});
|
||||
stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
|
||||
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
f_host_tensor_descriptor(M, K, stride_As[i], ALayout{})));
|
||||
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
|
||||
f_host_tensor_descriptor(K, N, stride_Bs[i], BLayout{})));
|
||||
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
|
||||
f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
|
||||
|
||||
std::cout << "gemm[" << i << "]"
|
||||
<< " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_tensors[i]);
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
b_k_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
c_m_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
c_m_n_dev_buf[i]->SetZero();
|
||||
c_m_n_tensors[i].SetZero();
|
||||
|
||||
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
gemm_descs.push_back(
|
||||
{p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(GetWorkspaceSize(gemm_descs));
|
||||
|
||||
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
gemm_descs, ck_tile::stream_config{nullptr, false}, gemm_workspace.GetDeviceBuffer());
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
|
||||
}
|
||||
|
||||
bool pass{true};
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
|
||||
c_m_n_host_ref.SetZero();
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
pass &= ck_tile::check_err(c_m_n_tensors[i], c_m_n_host_ref);
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user