Implement device_grouped_gemm_fixed_nk_bias for RDNA4 (#4340)

## Proposed changes

Summary:

- Modified the implementation of grouped_gemm_fixed_nk_bias (bias is applied
  via a new `SplitKAdd` element-wise operator; see the sketch below)
- Added FP16 WMMA examples
- Added WMMA instances
- Added profiler support for grouped_gemm_fixed_nk_bias
- Added WMMA instances to existing tests

**This PR depends on PR https://github.com/ROCm/rocm-libraries/pull/4299
and should be merged after it.
Only the last 6 commits are in the scope of this PR.**
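
The core of the change is a new `SplitKAdd` CDE element-wise operator that
applies the bias only on the final k-batch, so split-K accumulation does not
add the bias more than once. Below is a minimal, standalone sketch of the
intended semantics; the helper `split_k_gemm_elem` is hypothetical and for
illustration only (the actual operator is the `SplitKAdd` struct in the diff):

```cpp
#include <cassert>
#include <vector>

// Simplified model of the SplitKAdd epilogue: each k-batch accumulates its
// partial GEMM result into E, and only the final k-batch also adds the bias
// D, so the bias lands exactly once.
float split_k_gemm_elem(const std::vector<float>& partial_sums, float bias)
{
    float e = 0.f;
    const int k_batch = static_cast<int>(partial_sums.size());
    for(int kb = 0; kb < k_batch; ++kb)
    {
        e += partial_sums[kb]; // atomic add on the GPU side
        if(kb == k_batch - 1)
        {
            e += bias; // SplitKAdd: bias only on the last k-batch
        }
    }
    return e;
}

int main()
{
    // Splitting K into 4 batches must match the k_batch = 1 result.
    assert(split_k_gemm_elem({1.f, 2.f, 3.f, 4.f}, 0.5f) ==
           split_k_gemm_elem({10.f}, 0.5f));
    return 0;
}
```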

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [x] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [x] I have added the test to the REGRESSION_TESTS list defined at the top
of tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run.
- [x] I have added inline documentation which enables the maintainers to
understand the motivation
- [x] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Author: Yung-sheng Tu
Date: 2026-02-26 01:28:09 +01:00
Commit: 00853e2bd2 (parent: bea67cb520), committed by GitHub
11 changed files with 1514 additions and 40 deletions


@@ -50,6 +50,9 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple
add_example_executable(example_grouped_gemm_wmma_fixed_nk_fp16 grouped_gemm_wmma_fixed_nk_fp16.cpp)
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_fixed_nk_fp16)
add_example_executable(example_grouped_gemm_wmma_fixed_nk_bias_fp16 grouped_gemm_wmma_fixed_nk_bias_fp16.cpp)
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_fixed_nk_bias_fp16)
list(APPEND gpu_list_tf32 gfx942 gfx950)
set(target 0)


@@ -0,0 +1,406 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/ck.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/stream_config.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
#include <array>
#include <cstddef>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>
using ck::DeviceMem;
using ck::hip_check_error;
using ck::HostTensorDescriptor;
using ck::Tensor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using SplitKAdd = ck::tensor_operation::element_wise::SplitKAdd;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F16;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Row;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = SplitKAdd;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Fixed_Nk
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>;
// clang-format on
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<ck::index_t> stride_Ds;
std::vector<ck::index_t> stride_Es;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
int k_batch = 1;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
int group_count = problem_size.group_count;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a, p_b;
std::vector<std::array<const void*, 1>> p_ds = {};
std::vector<void*> p_e;
gemm_descs.reserve(group_count);
int sum_of_m = 0;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), Row>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<D0DataType>> d0_tensors;
std::vector<Tensor<EDataType>> e_host_tensors;
std::vector<Tensor<EDataType>> e_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
d0_tensors.reserve(group_count);
e_host_tensors.reserve(group_count);
e_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, d0_tensors_device,
e_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
d0_tensors_device.reserve(group_count);
e_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; ++i)
{
sum_of_m += problem_size.Ms[i];
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
d0_tensors.push_back(Tensor<D0DataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Ds[i], D0Layout{})));
e_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Es[i], ELayout{})));
e_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Es[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc
<< " e_m_n: " << e_device_tensors[i].mDesc << std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() +
sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d0_tensors[i].GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d0_tensors[i].GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
}
}
using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>;
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; ++i)
{
a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ADataType) * problem_size.Ms[i] * problem_size.Ks[i]));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i]));
d0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(D0DataType) * problem_size.Ms[i] * problem_size.Ns[i]));
e_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i]));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(),
a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType));
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType));
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data(),
d0_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(D0DataType));
e_tensors_device[i]->SetZero();
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_ds.push_back(std::array<const void*, 1>{d0_tensors_device[i]->GetDeviceBuffer()});
p_e.push_back(e_tensors_device[i]->GetDeviceBuffer());
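// Fixed-NK convention: the device-side descriptor carries the total M
// accumulated over all groups; the per-group M is passed separately via
// GroupedGemmKernelArgument below.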
gemm_descs.push_back({sum_of_m,
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
problem_size.stride_Es[i],
{problem_size.stride_Ds[i]}});
grouped_gemm_kernel_args_.push_back(
{a_tensors_device[i]->GetDeviceBuffer(),
b_tensors_device[i]->GetDeviceBuffer(),
std::array<const void*, 1>{d0_tensors_device[i]->GetDeviceBuffer()},
e_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
std::array<ck::index_t, 1>{problem_size.stride_Ds[i]},
problem_size.stride_Es[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = gemm.MakeArgument(
p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op);
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
gemm.GetDeviceKernelArgSize(&argument),
hipMemcpyHostToDevice));
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer());
gemm.SetKBatch(argument, config.k_batch);
invoker.Run(argument, StreamConfig{nullptr, false});
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
{
e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data(),
e_device_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType));
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
e_host_tensors[i],
a_element_op,
b_element_op,
PassThrough{});
ref_invoker.Run(ref_argument);
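// The reference GEMM above is bias-free; add the bias on the host exactly
// once, mirroring what SplitKAdd does on the final k-batch.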
for(int m = 0; m < problem_size.Ms[i]; ++m)
{
for(int n = 0; n < problem_size.Ns[i]; ++n)
{
cde_element_op(
e_host_tensors[i](m, n), e_host_tensors[i](m, n), d0_tensors[i](m, n));
}
}
pass &= ck::utils::check_err(e_device_tensors[i], e_host_tensors[i]);
}
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
problem_size.group_count = 16;
if(argc == 1)
{
// use default cases
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.k_batch = std::stoi(argv[4]);
}
else if(argc == 6)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.k_batch = std::stoi(argv[4]);
problem_size.group_count = std::stoi(argv[5]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4: k_batch (>0)\n");
printf("arg5: group count (default=16)");
exit(0);
}
// Lambda to get stride based on layout
auto get_stride = [](auto layout, auto row_dim, auto col_dim) {
if constexpr(std::is_same_v<decltype(layout), Row>)
{
return col_dim;
}
else
{
return row_dim;
}
};
for(int i = 0; i < problem_size.group_count; ++i)
{
problem_size.Ms.push_back(256 + 256 * i);
problem_size.Ns.push_back(512);
problem_size.Ks.push_back(512);
problem_size.stride_As.push_back(
get_stride(ALayout{}, problem_size.Ms[i], problem_size.Ks[i]));
problem_size.stride_Bs.push_back(
get_stride(BLayout{}, problem_size.Ks[i], problem_size.Ns[i]));
problem_size.stride_Ds.push_back(
get_stride(D0Layout{}, problem_size.Ms[i], problem_size.Ns[i]));
problem_size.stride_Es.push_back(
get_stride(ELayout{}, problem_size.Ms[i], problem_size.Ns[i]));
}
return !run_grouped_gemm(problem_size, config);
}


@@ -7,22 +7,55 @@
#include <sstream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/env.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/env.hpp"
#include "ck/utility/scheduler_enum.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
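// CDE element-wise operator for bias with split-K: non-final k-batches pass
// the partial GEMM result through unchanged, and only the final k-batch
// (kbatch_id == KBatch - 1) adds the bias operand, so the bias is applied
// exactly once regardless of the k-batch count.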
struct SplitKAdd
{
static constexpr const char* name = "SplitKAdd";
__host__ __device__ void set_kbatch(const index_t& id, const index_t& total)
{
kbatch_id = id;
KBatch = total;
}
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
{
if(kbatch_id == KBatch - 1)
{
add_op(y, x0, x1);
}
else
{
passthrough_op(y, x0);
}
}
private:
index_t kbatch_id = 0;
index_t KBatch = 1;
static constexpr auto add_op = Add{};
static constexpr auto passthrough_op = PassThrough{};
};
} // namespace element_wise
namespace device {
template <typename GridwiseGemm,
@@ -110,6 +143,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
while(id_local < local_grid_size)
{
const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_start, id_off);
const auto tile_index =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(tile_index[Number<0>{}]);
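// tile_index[0] identifies this workgroup's k-batch; stamp it into a
// per-block copy of the CDE op so SplitKAdd adds the bias only on the
// final k-batch.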
auto c_element_op_copy(c_element_op);
if constexpr(std::is_same_v<decltype(c_element_op_copy),
ck::tensor_operation::element_wise::SplitKAdd>)
{
c_element_op_copy.set_kbatch(kbatch_id, k_batch_);
}
KernelArgument kernel_arg{std::array<const void*, 1>{gemmTransKernelArg.p_a_grid},
std::array<const void*, 1>{gemmTransKernelArg.p_b_grid},
gemmTransKernelArg.p_ds_grid,
@@ -124,15 +172,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
k_batch_,
a_element_op,
b_element_op,
c_element_op_copy,
false};
const auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(kernel_arg, tile_index[Number<0>{}]);
@@ -813,7 +855,9 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK<ALayout
}
if constexpr(!std::is_same_v<CDEElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough> &&
!std::is_same_v<CDEElementwiseOperation,
ck::tensor_operation::element_wise::SplitKAdd>)
{
if(arg.k_batch_ > 1)
{


@@ -8,27 +8,28 @@
#include <ostream>
#endif
#include "ck/utility/env.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_direct_store.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles_preshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles_interleave.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/env.hpp"
namespace ck {


@@ -3,13 +3,14 @@
#pragma once
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/type.hpp"
#include <memory>
#include <vector>
namespace ck {
namespace tensor_operation {
@@ -70,20 +71,52 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(
PassThrough,
Add>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
std::vector<
std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
ck::tensor_operation::element_wise::SplitKAdd>>>&
instances);
void add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
std::vector<
std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
ck::tensor_operation::element_wise::SplitKAdd>>>&
instances);
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
@@ -91,11 +124,11 @@ struct DeviceOperationInstanceFactory<
{
using DeviceOp = DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
@@ -112,28 +145,96 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
#if defined(CK_USE_XDL)
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
#if defined(CK_USE_XDL)
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
#endif
}
}
// fp32_output
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, float>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
#ifdef CK_USE_XDL
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
#ifdef CK_USE_XDL
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(op_ptrs);
#endif
}
}
return op_ptrs;
}
};
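// Additional factory specialization for the new SplitKAdd-based bias
// instances; the specialization above continues to serve the Add-based
// XDL instances.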
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedGemmFixedNK<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
ck::tensor_operation::element_wise::SplitKAdd>>
{
using DeviceOp = DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
PassThrough,
PassThrough,
ck::tensor_operation::element_wise::SplitKAdd>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
// fp16_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
#if defined(CK_USE_WMMA)
add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
#if defined(CK_USE_WMMA)
add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
#endif
}
}
return op_ptrs;


@@ -1,11 +1,14 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_grouped_gemm_bias_instance
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp
)


@@ -0,0 +1,95 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
#include <memory>
#include <tuple>
#include <vector>
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F16;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Row;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using SplitKAdd = ck::tensor_operation::element_wise::SplitKAdd;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = SplitKAdd;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
//#############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//#############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 64, 32, 8, 2, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 64, 128, 32, 8, 2, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 64, 128, 32, 8, 8, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 128, 64, 32, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 128, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 64, 128, 32, 8, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>
// clang-format on
>;
void add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
add_device_operation_instances(
instances, device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck


@@ -0,0 +1,98 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
#include <memory>
#include <tuple>
#include <vector>
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F16;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using SplitKAdd = ck::tensor_operation::element_wise::SplitKAdd;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = SplitKAdd;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//#############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//#############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 64, 128, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 128, 128, 64, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 64, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 128, 32, 64, 8, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 256, 64, 8, 8, 16, 16, 2, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGroupedGemm_Wmma_Fixed_Nk<ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;
void add_device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
add_device_operation_instances(
instances, device_grouped_gemm_wmma_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck


@@ -0,0 +1,411 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/ck.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/stream_config.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/env.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/type.hpp"
#include <array>
#include <cstddef>
#include <iomanip>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout>
bool profile_grouped_gemm_fixed_nk_bias_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideDs,
const std::vector<int>& StrideEs,
const std::vector<int>& kbatches = {1},
int n_warmup = 1,
int n_iter = 10)
{
bool pass = true;
using ComputeDataType = ADataType;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout);
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride}, layout);
}
};
std::size_t group_count = Ms.size();
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
group_count == StrideBs.size() && group_count == StrideDs.size() &&
group_count == StrideEs.size()))
{
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
}
using D0DataType = remove_cvref_t<ck::tuple_element_t<Number<0>{}, DsDataType>>;
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<D0DataType>> d0_tensors;
std::vector<Tensor<EDataType>> e_host_tensors;
std::vector<Tensor<EDataType>> e_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
d0_tensors.reserve(group_count);
e_host_tensors.reserve(group_count);
e_device_tensors.reserve(group_count);
double max_abs_in_val = 0.f;
int sum_of_m = 0;
using D0Layout = remove_cvref_t<ck::tuple_element_t<Number<0>{}, DsLayout>>;
for(std::size_t i = 0; i < group_count; ++i)
{
sum_of_m += Ms[i];
a_tensors.push_back(
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
b_tensors.push_back(
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
d0_tensors.push_back(
Tensor<D0DataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideDs[i], D0Layout{})));
e_host_tensors.push_back(
Tensor<EDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{})));
e_device_tensors.push_back(
Tensor<EDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{})));
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_tensors[i].mDesc
<< ", b_k_n[" << i << "]:" << b_tensors[i].mDesc << ", d_m_n[" << i
<< "]:" << d0_tensors[i].mDesc << ", e_m_n_device_results[" << i
<< "]:" << e_device_tensors[i].mDesc << std::endl;
}
switch(init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_tensors[i]);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_tensors[i]);
max_abs_in_val = 10.f;
break;
default:
ck::utils::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_tensors[i]);
ck::utils::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_tensors[i]);
max_abs_in_val = 1.0f;
}
ck::utils::FillUniformDistribution<D0DataType>{-0.5, 0.5}(d0_tensors[i]);
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::SplitKAdd;
constexpr auto a_element_op = AElementOp{};
constexpr auto b_element_op = BElementOp{};
constexpr auto cde_element_op = CDEElementOp{};
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, d0_tensors_device,
e_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
d0_tensors_device.reserve(group_count);
e_tensors_device.reserve(group_count);
std::vector<const void*> p_a, p_b;
std::vector<std::array<const void*, 1>> p_ds;
std::vector<void*> p_e;
p_a.reserve(group_count);
p_b.reserve(group_count);
p_ds.reserve(group_count);
p_e.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
gemm_descs.reserve(group_count);
std::vector<ck::tensor_operation::device::GroupedGemmKernelArgument<1>>
grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(std::size_t i = 0; i < group_count; ++i)
{
a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize()));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize()));
d0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSpaceSize()));
e_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSpaceSize()));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data());
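// Fixed-NK: the descriptor takes the total M over all groups; the
// per-group Ms travel in the kernel arguments below.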
gemm_descs.push_back(
{sum_of_m, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {StrideDs[i]}});
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_ds.push_back(std::array<const void*, 1>{d0_tensors_device[i]->GetDeviceBuffer()});
p_e.push_back(e_tensors_device[i]->GetDeviceBuffer());
grouped_gemm_kernel_args_.push_back(
{a_tensors_device[i]->GetDeviceBuffer(),
b_tensors_device[i]->GetDeviceBuffer(),
std::array<const void*, 1>{d0_tensors_device[i]->GetDeviceBuffer()},
e_tensors_device[i]->GetDeviceBuffer(),
Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
std::array<ck::index_t, 1>{StrideDs[i]},
StrideEs[i]});
}
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
if(op_ptrs.size() <= 0)
{
std::cerr << "Skip! no device GEMM instance found" << std::endl;
return true;
}
std::string best_gemm_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
if(do_verification)
{
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<
ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
ck::tensor_operation::element_wise::PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
e_host_tensors[i],
a_element_op,
b_element_op,
ck::tensor_operation::element_wise::PassThrough{});
ref_invoker.Run(ref_argument);
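// Reference: plain GEMM, then a single host-side bias addition, matching
// SplitKAdd's add-bias-on-final-k-batch behavior.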
for(int m = 0; m < Ms[i]; ++m)
{
for(int n = 0; n < Ns[i]; ++n)
{
cde_element_op(
e_host_tensors[i](m, n), e_host_tensors[i](m, n), d0_tensors[i](m, n));
}
}
}
}
// profile device GEMM instances
for(auto& gemm_ptr : op_ptrs)
{
auto argument_ptr = gemm_ptr->MakeArgumentPointer(
p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
DeviceMem grouped_gemm_kernel_args_dev(
gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
hip_check_error(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
hipMemcpyHostToDevice));
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(),
grouped_gemm_kernel_args_dev.GetDeviceBuffer());
std::string gemm_name = gemm_ptr->GetTypeString();
for(std::size_t j = 0; j < kbatches.size(); ++j)
{
auto kbatch_curr = kbatches[j];
gemm_ptr->SetKBatch(argument_ptr.get(), kbatch_curr);
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
{
e_tensors_device[i]->SetZero();
}
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_verification)
{
bool instance_pass = true;
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
{
e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data());
auto atol = ck::utils::get_absolute_threshold<ComputeDataType, EDataType>(
max_abs_in_val, gemm_descs[i].K_);
auto rtol = ck::utils::get_relative_threshold<ComputeDataType, EDataType>(
gemm_descs[i].K_);
instance_pass =
instance_pass && ck::utils::check_err(e_device_tensors[i],
e_host_tensors[i],
"Error: Incorrect results!",
rtol,
atol);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_tensors[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_tensors[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "d0: ", d0_tensors[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "e_device: ", e_device_tensors[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "e_host : ", e_host_tensors[i].mData, ",")
<< std::endl;
}
}
std::cout << "Instance: " << gemm_name << "; KBatch: " << kbatch_curr << " "
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
pass = pass && instance_pass;
}
float ave_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
if(time_kernel)
{
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
num_btype +=
sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] +
sizeof(D0DataType) * Ms[i] * Ns[i] + sizeof(EDataType) * Ms[i] * Ns[i];
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch "
<< kbatch_curr << std::endl;
if(tflops > best_tflops)
{
best_gemm_name = gemm_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
}
else
{
std::cout << "Instance: " << gemm_name
<< ", does not support this GEMM problem (KBatch: " << kbatch_curr << ")"
<< std::endl;
}
}
}
if(time_kernel)
{
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch
<< std::endl;
}
return pass;
}
} // namespace profiler
} // namespace ck


@@ -25,6 +25,12 @@ if (CK_USE_XDL OR CK_USE_WMMA)
add_dependencies(test_grouped_gemm test_grouped_gemm_fixed_nk)
endif()
add_gtest_executable(test_grouped_gemm_fixed_nk_bias test_grouped_gemm_fixed_nk_bias.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_fixed_nk_bias PRIVATE utility device_grouped_gemm_bias_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_fixed_nk_bias)
endif()
add_gtest_executable(test_grouped_gemm_multi_abd_fixed_nk test_grouped_gemm_multi_abd_fixed_nk.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_multi_abd_fixed_nk PRIVATE utility device_grouped_gemm_fixed_nk_multi_abd_instance)


@@ -0,0 +1,306 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "profiler/profile_grouped_gemm_fixed_nk_bias_impl.hpp"
#include "gtest/gtest.h"
#include <iostream>
#include <tuple>
#include <type_traits>
#include <vector>
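// Test parameters, overridable from the command line in main():
// param_mask: bit i enables k_batches_[i] in Run();
// instance_index: which instance to run (-1 means all).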
static ck::index_t param_mask = 0xffffff;
static ck::index_t instance_index = -1;
using FP32 = float;
using FP16 = ck::half_t;
using BF16 = ck::bhalf_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using SplitKAdd = ck::tensor_operation::element_wise::SplitKAdd;
using CDEElementOp = SplitKAdd;
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple<FP16, FP16, ck::Tuple<FP16>, FP16, Row, Row, ck::Tuple<Row>, Row>,
std::tuple<FP16, FP16, ck::Tuple<FP16>, FP16, Row, Col, ck::Tuple<Row>, Row>
>;
// clang-format on
template <typename Tuple>
class TestGroupedGemmFixedNKBias : public testing::Test
{
protected:
using ADataType = std::tuple_element_t<0, Tuple>;
using BDataType = std::tuple_element_t<1, Tuple>;
using DsDataType = std::tuple_element_t<2, Tuple>;
using EDataType = std::tuple_element_t<3, Tuple>;
using ALayout = std::tuple_element_t<4, Tuple>;
using BLayout = std::tuple_element_t<5, Tuple>;
using DsLayout = std::tuple_element_t<6, Tuple>;
using ELayout = std::tuple_element_t<7, Tuple>;
using AccDataType = FP32;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // integer value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
static constexpr int n_warmup_ = 0;
static constexpr int n_iter_ = 1;
std::vector<int> k_batches_;
bool IsSplitKSupported()
{
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
// Technically, we could still use split-K for fp32, but we currently don't have
// instances for it so we disable it entirely
constexpr bool require_16bit_atomic_add =
std::is_same_v<EDataType, FP16> || std::is_same_v<EDataType, BF16>;
bool missing_atomic_add = require_16bit_atomic_add && ck::is_gfx11_supported();
// CDE element-wise operators other than PassThrough and SplitKAdd are not
// supported in combination with split-K
constexpr bool has_cde_element_operator =
!std::is_same_v<CDEElementOp, PassThrough> && !std::is_same_v<CDEElementOp, SplitKAdd>;
return !missing_atomic_add && !has_cde_element_operator;
}
void SetUp() override
{
if(!IsSplitKSupported())
{
k_batches_ = {1};
}
else
{
k_batches_ = {1, 2, 3, 4, 8};
}
}
private:
template <typename Layout>
void SetStrides(std::vector<int>& strides,
const std::vector<int>& rows,
const std::vector<int>& cols) const
{
if(std::is_same_v<Layout, Row>)
{
for(const auto c : cols)
{
strides.emplace_back(c);
}
}
else if(std::is_same_v<Layout, Col>)
{
for(const auto r : rows)
{
strides.emplace_back(r);
}
}
}
template <typename Layouts>
void SetTupleStrides(std::vector<int>& strides,
const std::vector<int>& rows,
const std::vector<int>& cols) const
{
if constexpr(Layouts::Size() > 0)
{
// As of now multi ABD implementation supports only tensors with matching layouts.
using Layout = ck::remove_cvref_t<ck::tuple_element_t<ck::Number<0>{}, Layouts>>;
SetStrides<Layout>(strides, rows, cols);
}
}
public:
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs = {},
const std::vector<int>& StrideBs = {},
const std::vector<int>& StrideDs = {},
const std::vector<int>& StrideEs = {})
{
std::vector<int> stride_as = StrideAs;
std::vector<int> stride_bs = StrideBs;
std::vector<int> stride_ds = StrideDs;
std::vector<int> stride_es = StrideEs;
if(stride_as.empty())
{
SetStrides<ALayout>(stride_as, Ms, Ks);
}
if(stride_bs.empty())
{
SetStrides<BLayout>(stride_bs, Ks, Ns);
}
if(stride_ds.empty())
{
SetTupleStrides<DsLayout>(stride_ds, Ms, Ns);
}
if(stride_es.empty())
{
SetStrides<ELayout>(stride_es, Ms, Ns);
}
std::vector<int> k_batches;
for(size_t i = 0; i < k_batches_.size(); i++)
{
if(param_mask & (1 << i))
{
k_batches.push_back(k_batches_[i]);
}
}
RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_ds, stride_es, k_batches);
}
void RunSingle(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideDs,
const std::vector<int>& StrideEs,
const std::vector<int>& kbatches)
{
bool pass = ck::profiler::profile_grouped_gemm_fixed_nk_bias_impl<ADataType,
BDataType,
DsDataType,
EDataType,
AccDataType,
ALayout,
BLayout,
DsLayout,
ELayout>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideDs,
StrideEs,
kbatches,
n_warmup_,
n_iter_);
EXPECT_TRUE(pass);
}
};
TYPED_TEST_SUITE(TestGroupedGemmFixedNKBias, KernelTypes);
TYPED_TEST(TestGroupedGemmFixedNKBias, TinyCases)
{
const std::vector<int> Ms{2, 1};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}
TYPED_TEST(TestGroupedGemmFixedNKBias, SmallCases)
{
const std::vector<int> Ms{2, 1, 3, 4, 5};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}
TYPED_TEST(TestGroupedGemmFixedNKBias, MidCases)
{
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}
TYPED_TEST(TestGroupedGemmFixedNKBias, Regular)
{
const std::vector<int> Ms{64, 128, 256};
constexpr int N = 768;
constexpr int K = 320;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}
TYPED_TEST(TestGroupedGemmFixedNKBias, MNKPadded)
{
const std::vector<int> Ms{127, 150, 188, 210};
constexpr int N = 136;
constexpr int K = 280;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}
TYPED_TEST(TestGroupedGemmFixedNKBias, TestLargeKBatch)
{
// In some cases Split K is not supported. Running this test would fail since no instance will
// be supported, so we skip the test
if(!this->IsSplitKSupported())
{
GTEST_SKIP() << "Split-K not supported for for the current configuration (FP16/BF16 on "
"GFX11, or using CDE element-wise operation)";
}
const std::vector<int> Ms{188, 210};
constexpr int N = 768;
constexpr int K = 4096;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->k_batches_ = {32, 64};
this->Run(Ms, Ns, Ks);
}
int main(int argc, char** argv)
{
testing::InitGoogleTest(&argc, argv);
if(argc == 3)
{
param_mask = strtol(argv[1], nullptr, 0);
instance_index = atoi(argv[2]);
}
else if(argc != 1)
{
std::cout << "Usage of " << argv[0] << std::endl;
std::cout << "Arg1,2: param_mask instance_index (-1 means all)" << std::endl;
}
return RUN_ALL_TESTS();
}