173 implement device grouped gemm fixed nk for rdna4 (#4299)

## Proposed changes

This PR adds an RDNA4 implementation of the device_grouped_gemm_fixed_nk
instance library using for WMMA.

The implementation is based on the existing
DeviceGroupedGemm_Xdl_Fixed_NK design and reuses the same high-level
structure, but replaces the XDL kernel with a WMMA-based one. It uses
the GridwiseGemm_wmma_cshuffle_v3 kernel.

At this stage, the focus is functional correctness and compatibility,
not performance tuning.

## Technical Details

- Device struct for grouped gemm fixed NK
- Example code for the WMMA version
- Unit tests for both new wmma implementation and the reference XDL code
(previously missing)
- Generic ck profiler interface with the purpose of calling unit tests.

## Checklist

Please put an into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [x] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [x] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [x] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run  on all changed files
- [x] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered



---
🔁 Imported from
[ROCm/composable_kernel#3668](https://github.com/ROCm/composable_kernel/pull/3668)
🧑‍💻 Originally authored by @bidlekm

---------

Co-authored-by: Marton Bidlek <marton.bidlek@streamhpc.com>
Co-authored-by: Erwin Terpstra <erwin.terpstra@streamhpc.com>
Co-authored-by: bidlekm <bidlekmarton@gmail.com>
Co-authored-by: assistant-librarian[bot] <assistant-librarian[bot]@users.noreply.github.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
assistant-librarian[bot]
2026-02-19 09:13:05 +01:00
committed by GitHub
parent 33f83b99ed
commit 9c8d3a39ac
32 changed files with 2819 additions and 163 deletions

View File

@@ -47,6 +47,10 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl
add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp)
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16)
add_example_executable(example_grouped_gemm_wmma_fixed_nk_fp16 grouped_gemm_wmma_fixed_nk_fp16.cpp)
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_fixed_nk_fp16)
list(APPEND gpu_list_tf32 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)

View File

@@ -0,0 +1,351 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
using ::ck::DeviceMem;
using ::ck::hip_check_error;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Fixed_Nk
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
// clang-format on
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
int k_batch = 1;
bool time_kernel = false;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<void*> p_Cs;
gemm_descs.reserve(group_count);
int sum_of_m = 0;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
std::vector<Tensor<EDataType>> c_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
sum_of_m += problem_size.Ms[i];
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
}
}
std::cout << "Sum of M: " << sum_of_m << std::endl;
using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>;
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(ADataType) * problem_size.Ms[i] * problem_size.Ks[i]));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i]));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i]));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(),
a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType));
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType));
c_tensors_device[i]->SetZero();
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
gemm_descs.push_back({sum_of_m,
problem_size.Ns[i],
problem_size.Ks[i],
1,
problem_size.stride_Bs[i],
1,
{}});
grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(),
b_tensors_device[i]->GetDeviceBuffer(),
{},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
{},
problem_size.stride_Cs[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
std::vector<const void*> p_As = {};
std::vector<const void*> p_Bs = {};
std::vector<std::array<const void*, 0>> p_Ds = {};
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
gemm.GetDeviceKernelArgSize(&argument),
hipMemcpyHostToDevice));
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer());
gemm.SetKBatch(argument, config.k_batch);
invoker.Run(argument, StreamConfig{nullptr, false});
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
c_device_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType));
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
}
// Copy device tensors back to host
for(std::size_t i = 0; i < c_device_tensors.size(); i++)
{
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
c_device_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType));
}
}
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
return pass;
}
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
problem_size.group_count = 16;
if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.k_batch = std::stoi(argv[4]);
}
else if(argc == 6)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.k_batch = std::stoi(argv[4]);
problem_size.group_count = std::stoi(argv[5]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4: k_batch (> 0)\n");
printf("arg5: group count (default=16)");
exit(0);
}
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(128 + i * 128);
problem_size.Ns.push_back(1024);
problem_size.Ks.push_back(1024);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
return !run_grouped_gemm(problem_size, config);
}

View File

@@ -291,6 +291,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
}
}
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
return pass;
}
@@ -329,7 +330,7 @@ int main(int argc, char* argv[])
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(128 + rand() % 128);
problem_size.Ms.push_back(128 + i * 128);
problem_size.Ns.push_back(1024);
problem_size.Ks.push_back(1024);

View File

@@ -272,6 +272,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
ComputeDataType>(c_device_tensors[i], c_host_tensors[i]);
#endif
}
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
}
if(config.time_kernel)

File diff suppressed because it is too large Load Diff

View File

@@ -775,8 +775,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
c_thread_buf,
block_work_idx[I0],
block_work_idx[I1],
block_work_idx[I2],
p_shared,
p_ds_grid,
p_e_grid,

View File

@@ -344,6 +344,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<LDSTypeA>, pk_i4_t>)
return 2;
@@ -627,8 +629,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const std::array<index_t, NumATensor>& StrideAs,
const index_t AK0)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding ||
constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding;
@@ -696,8 +697,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const std::array<index_t, NumBTensor>& StrideBs,
const index_t BK0)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
constexpr bool padN = GemmSpec == GemmSpecialization::NKPadding ||
constexpr bool padN = GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding;
@@ -794,7 +794,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
// TODO: Investigate why this path is not used in the original
// gridwise_gemm_xdl_cshuffle_v3.hpp
#if 0
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
@@ -1033,6 +1032,49 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
}
__host__ __device__ static constexpr bool
CheckValidity(const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const index_t KBatch)
{
ignore = StrideDs;
const auto M_padded = CalculateMPadded(M);
const auto N_padded = CalculateMPadded(N);
const auto K_padded = CalculateKPadded(K, KBatch);
const auto e_grid_desc_m_n =
MakeDEGridDescriptor_M_N<ELayout>(M, M_padded, N, N_padded, StrideE);
const index_t AK0 = CalculateAK0Padded(K, KBatch);
const index_t BK0 = CalculateBK0Padded(K, KBatch);
const auto a_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
M, M_padded, K, K_padded, std::array<index_t, 1>{StrideA}, AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
K, K_padded, N, N_padded, std::array<index_t, 1>{StrideB}, BK0);
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const auto& a_desc = a_grid_desc_ak0_m_ak1.At(I0);
const auto& b_desc = b_grid_desc_bk0_n_bk1.At(I0);
if(!(a_desc.GetElementSpaceSize() * sizeof(LDSTypeA) <= TwoGB &&
b_desc.GetElementSpaceSize() * sizeof(LDSTypeB) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Argument>
__host__ static constexpr bool CheckValidity(const Argument& karg,
@@ -1089,9 +1131,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! "
"K_Batch:"
<< karg.KBatch << " " << "K0PerBlock:" << KPerBlock << " "
<< "K1:" << AK1Number << " " << "K:" << karg.K << " " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}

View File

@@ -0,0 +1,268 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/utility/scheduler_enum.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = ck::f8_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using I8 = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AccDataType = F32;
using DsDataType = Empty_Tuple;
using DsLayout = Empty_Tuple;
using ELayout = Row;
static constexpr auto PipelineV1 = BlockGemmPipelineVersion::v1;
static constexpr auto PipelineV3 = BlockGemmPipelineVersion::v3;
static constexpr auto IntrawaveScheduler = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto InterwaveScheduler = BlockGemmPipelineScheduler::Interwave;
static constexpr auto GemmMNKPadding = device::GemmSpecialization::MNKPadding;
static constexpr auto GemmDefault = device::GemmSpecialization::Default;
// Instances for 2 byte datatypes in RRR layout with ADataType = BDataType = EDataType
template <typename T,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp,
enable_if_t<sizeof(T) == 2, bool> = false>
using device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_instances = std::tuple<
// clang-format off
//#############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//#############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
// clang-format on
>;
// Instances for 2 byte datatypes in RCR layout with ADataType = BDataType = EDataType
template <typename T,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp,
enable_if_t<sizeof(T) == 2, bool> = false>
using device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_instances = std::tuple<
// clang-format off
//#############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//#############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
// clang-format on
>;
template <typename TA,
typename TB,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
using device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_irregular_instances = std::tuple<
// clang-format off
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
// clang-format on
>;
template <typename TA,
typename TB,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
using device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_irregular_instances = std::tuple<
// clang-format off
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
// clang-format on
>;
// List of instance variants to add (pipeline/scheduler/padding combinations)
// Some are disabled now, can be re-enabled if needed
using InstanceVariant =
ck::Tuple<device::GemmSpecialization, BlockGemmPipelineScheduler, BlockGemmPipelineVersion>;
static constexpr InstanceVariant InstanceVariants[] = {
make_tuple(GemmDefault, IntrawaveScheduler, PipelineV1),
make_tuple(GemmDefault, IntrawaveScheduler, PipelineV3),
make_tuple(GemmMNKPadding, IntrawaveScheduler, PipelineV1)};
template <typename ALayout,
typename BLayout,
template <device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
typename LayoutInstances,
typename ADataType,
typename BDataType,
typename EDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
void add_device_grouped_gemm_wmma_fixed_nk_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) {
constexpr auto instance = InstanceVariants[i];
add_device_operation_instances(instances,
LayoutInstances<instance.At(Number<0>{}),
instance.At(Number<1>{}),
instance.At(Number<2>{}),
AElementOp,
BElementOp,
CDEElementOp>{});
});
}
template <typename T,
typename ALayout,
typename BLayout,
template <typename T2,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
typename LayoutInstances,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
void add_device_grouped_gemm_wmma_fixed_nk_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
T,
T,
DsDataType,
T,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) {
constexpr auto instance = InstanceVariants[i];
add_device_operation_instances(instances,
LayoutInstances<T,
instance.At(Number<0>{}),
instance.At(Number<1>{}),
instance.At(Number<2>{}),
AElementOp,
BElementOp,
CDEElementOp>{});
});
}
template <typename TA,
typename TB,
typename ALayout,
typename BLayout,
template <typename T2,
typename T3,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
typename LayoutInstances,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
void add_device_grouped_gemm_wmma_fixed_nk_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<ALayout,
BLayout,
DsLayout,
ELayout,
TA,
TB,
DsDataType,
TA,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) {
constexpr auto instance = InstanceVariants[i];
add_device_operation_instances(instances,
LayoutInstances<TA,
TB,
instance.At(Number<0>{}),
instance.At(Number<1>{}),
instance.At(Number<2>{}),
AElementOp,
BElementOp,
CDEElementOp>{});
});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -17,6 +17,7 @@ namespace device {
namespace instance {
// fp16_output
#if defined(CK_USE_XDL)
void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
@@ -96,9 +97,91 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if defined(CK_USE_WMMA)
void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Empty_Tuple,
Row,
F16,
F8,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Empty_Tuple,
Row,
F16,
F8,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Empty_Tuple,
Row,
F16,
I8,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Empty_Tuple,
Row,
F16,
I8,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// bf16_inputA i8_inputB
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
#if defined(CK_USE_XDL)
void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
@@ -125,9 +208,38 @@ void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(
PassThrough,
PassThrough>>>& instances);
#endif
#if defined(CK_USE_WMMA)
void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Empty_Tuple,
Row,
BF16,
I8,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Empty_Tuple,
Row,
BF16,
I8,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#endif
// bf16_inputA bf16_inputB
#if defined(CK_ENABLE_BF16)
#if defined(CK_USE_XDL)
void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
@@ -153,6 +265,34 @@ void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if defined(CK_USE_WMMA)
void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#endif // CK_ENABLE_BF16
template <typename ALayout,
@@ -190,6 +330,7 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_XDL)
// fp16_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
@@ -273,6 +414,93 @@ struct DeviceOperationInstanceFactory<
}
}
#endif // CK_ENABLE_BF16
#endif // CK_USE_XDL
#if defined(CK_USE_WMMA)
// fp16_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
}
// fp8_input
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instances(op_ptrs);
}
}
// i8_input
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instances(op_ptrs);
}
}
// bf16_i8_input
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(op_ptrs);
}
}
#endif
// bf16_inputA bf16_inputB
#if defined(CK_ENABLE_BF16)
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<EDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs);
}
}
#endif // CK_ENABLE_BF16
#endif // CK_USE_WMMA
return op_ptrs;
}

View File

@@ -1,18 +1,31 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
set(GROUPED_GEMM_FIXED_NK_INSTANCES)
list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp)
list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES
device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp
)
add_instance_library(device_grouped_gemm_fixed_nk_instance ${GROUPED_GEMM_FIXED_NK_INSTANCES})

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_fixed_nk_instances<
BF16,
Row,
Row,
device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_fixed_nk_instances<
BF16,
Row,
Col,
device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_fixed_nk_irregular_instances<
BF16,
I8,
Row,
Row,
device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_irregular_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
DsLayout,
Row,
BF16,
I8,
DsDataType,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_fixed_nk_irregular_instances<
BF16,
I8,
Row,
Col,
device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_irregular_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_fixed_nk_instances<
F16,
Row,
Row,
device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_fixed_nk_instances<
F16,
Row,
Col,
device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
DsLayout,
Row,
F16,
F8,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_fixed_nk_irregular_instances<
F16,
F8,
Row,
Row,
device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_irregular_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
DsLayout,
Row,
F16,
F8,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_fixed_nk_irregular_instances<
F16,
F8,
Row,
Col,
device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_irregular_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
DsLayout,
Row,
F16,
I8,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_fixed_nk_irregular_instances<
F16,
I8,
Row,
Row,
device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_irregular_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
DsLayout,
Row,
F16,
I8,
DsDataType,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_grouped_gemm_wmma_fixed_nk_irregular_instances<
F16,
I8,
Row,
Col,
device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_irregular_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -30,7 +30,8 @@ using DsLayout = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// static constexpr auto GemmMNKPadding =
// ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple<
// clang-format off
@@ -38,16 +39,16 @@ using device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_irregular_tile_insta
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// clang-format on
>;

View File

@@ -30,27 +30,28 @@ using DsLayout = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// static constexpr auto GemmMNKPadding =
// ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple<
// clang-format off
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;

View File

@@ -31,7 +31,8 @@ using DsLayout = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// static constexpr auto GemmMNKPadding =
// ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_irregular_tile_instances = std::tuple<
// clang-format off
@@ -39,16 +40,16 @@ using device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_irregular_tile_instan
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// clang-format on
>;

View File

@@ -31,27 +31,28 @@ using DsLayout = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// static constexpr auto GemmMNKPadding =
// ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_irregular_tile_instances = std::tuple<
// clang-format off
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;

View File

@@ -31,7 +31,8 @@ using DsLayout = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// static constexpr auto GemmMNKPadding =
// ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_irregular_tile_instances = std::tuple<
// clang-format off
@@ -39,16 +40,16 @@ using device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_irregular_tile_instan
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// clang-format on
>;

View File

@@ -31,27 +31,28 @@ using DsLayout = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// static constexpr auto GemmMNKPadding =
// ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_irregular_tile_instances = std::tuple<
// clang-format off
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
// DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F16, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;

View File

@@ -46,7 +46,8 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
int n_warmup = 1,
int n_iter = 10)
{
bool pass = true;
bool pass = true;
using ComputeDataType = ADataType;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
@@ -54,11 +55,11 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout);
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
return HostTensorDescriptor({row, col}, {1_uz, stride}, layout);
}
};
@@ -74,8 +75,8 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n_host_results;
std::vector<Tensor<CDataType>> c_m_n_device_results;
int sum_of_m = 0;
double max_abs_in_val = 0.f;
int sum_of_m = 0;
for(std::size_t i = 0; i < group_count; i++)
{
sum_of_m += Ms[i];
@@ -95,17 +96,18 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
}
std::size_t num_thread = 1;
switch(init_method)
{
case 0: break;
case 1:
a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k[i]);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n[i]);
max_abs_in_val = 5.f;
break;
default:
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
ck::utils::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k[i]);
ck::utils::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n[i]);
max_abs_in_val = 1.0f;
}
}
@@ -282,23 +284,18 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
bool instance_pass = true;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
auto atol = ck::utils::get_absolute_threshold<ComputeDataType, CDataType>(
max_abs_in_val, gemm_descs[i].K_);
auto rtol = ck::utils::get_relative_threshold<ComputeDataType, CDataType>(
gemm_descs[i].K_);
if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1)
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i],
"Error: Incorrect results!",
0.06);
}
else
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i]);
}
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i],
"Error: Incorrect results!",
rtol,
atol);
if(do_log)
{
@@ -315,7 +312,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
}
}
std::cout << "Instance: " << gemm_name << " verification "
std::cout << "Instance: " << gemm_name << "; KBatch: " << kbatch_curr << " "
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
pass = pass && instance_pass;
@@ -355,7 +352,8 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
}
else
{
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
std::cout << "Instance: " << gemm_name
<< ", does not support this GEMM problem (KBatch: " << kbatch_curr << ")"
<< std::endl;
}
}

View File

@@ -305,4 +305,4 @@ endforeach()
message(VERBOSE "ckProfiler libs: ${PROFILER_LIBS}")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE ${PROFILER_LIBS})
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)

View File

@@ -18,6 +18,12 @@ if (CK_USE_XDL OR CK_USE_WMMA)
target_link_libraries(test_grouped_gemm_fastgelu PRIVATE utility device_grouped_gemm_fastgelu_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_fastgelu)
endif()
add_gtest_executable(test_grouped_gemm_fixed_nk test_grouped_gemm_fixed_nk.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_fixed_nk PRIVATE utility device_grouped_gemm_fixed_nk_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_fixed_nk)
endif()
endif()
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)

View File

@@ -0,0 +1,84 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include <vector>
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
#include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp"
ck::index_t param_mask = 0xffffff;
ck::index_t instance_index = -1;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
class TestGroupedGemm
: public ck::test::TestGroupedGemm<Tuple, false, ck::test::FixedNKGroupedGemmProfiler>
{
public:
void SetUp() override
{
ck::test::TestGroupedGemm<Tuple, false, ck::test::FixedNKGroupedGemmProfiler>::SetUp();
#if defined(CK_USE_WMMA)
// The old XDL tests didn't fail if instances were not supported, so we want to keep that
// behaviour When compiling WMMA instances and WMMA is supported, then we'll fail if a
// specific case is not supported
this->fail_if_no_supported_instances_ =
ck::is_gfx11_supported() || ck::is_gfx12_supported();
#endif
}
};
// clang-format off
using KernelTypes = ::testing::Types<
#if CK_USE_OCP_FP8 || CK_USE_FNUZ_FP8 || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || \
defined(CK_USE_WMMA_FP8)
ck::Tuple<Row, Row, Row, F16, F8, F16>,
ck::Tuple<Row, Col, Row, F16, F8, F16>,
#endif
ck::Tuple<Row, Row, Row, F16, F16, F16>,
ck::Tuple<Row, Col, Row, F16, F16, F16>,
ck::Tuple<Row, Row, Row, BF16, BF16, BF16>,
ck::Tuple<Row, Col, Row, BF16, BF16, BF16>,
ck::Tuple<Row, Row, Row, BF16, I8, BF16>,
ck::Tuple<Row, Col, Row, BF16, I8, BF16>,
ck::Tuple<Row, Row, Row, F16, I8, F16>,
ck::Tuple<Row, Col, Row, F16, I8, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes);
#include "test_grouped_gemm_fixed_nk_cases.inc"
int main(int argc, char** argv)
{
testing::InitGoogleTest(&argc, argv);
if(argc == 1) {}
else if(argc == 3)
{
param_mask = strtol(argv[1], nullptr, 0);
instance_index = atoi(argv[2]);
}
else
{
std::cout << "Usage of " << argv[0] << std::endl;
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
}
return RUN_ALL_TESTS();
}

View File

@@ -0,0 +1,84 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
TYPED_TEST(TestGroupedGemm, TinyCases)
{
const std::vector<int> Ms{2, 1};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}
TYPED_TEST(TestGroupedGemm, SmallCases)
{
const std::vector<int> Ms{2, 1, 3, 4, 5};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}
TYPED_TEST(TestGroupedGemm, MidCases)
{
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}
TYPED_TEST(TestGroupedGemm, Regular)
{
const std::vector<int> Ms{64, 128, 256};
constexpr int N = 768;
constexpr int K = 320;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}
TYPED_TEST(TestGroupedGemm, MNKPadded)
{
const std::vector<int> Ms{127, 150, 188, 210};
constexpr int N = 136;
constexpr int K = 280;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}
TYPED_TEST(TestGroupedGemm, TestLargeKBatch)
{
// In some cases Split K is not supported. Running this test would fail since no instance will
// be supported, so we skip the test
if(!this->IsSplitKSupported())
GTEST_SKIP() << "Split-K not supported for for the current configuration (FP16/BF16 on "
"GFX11, or using CDE element-wise operation)";
const std::vector<int> Ms{188, 210};
constexpr int N = 768;
constexpr int K = 4096;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->k_batches_ = {32, 64};
this->Run(Ms, Ns, Ks);
}

View File

@@ -16,6 +16,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_fixed_nk_impl.hpp"
extern ck::index_t param_mask;
extern ck::index_t instance_index;
@@ -23,7 +24,124 @@ extern ck::index_t instance_index;
namespace ck {
namespace test {
template <typename Tuple, bool FailIfNoSupportedInstances = false>
struct DefaultGroupedGemmProfiler
{
template <typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename ELayout,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
static bool Run(bool verify,
int init_method,
bool log,
bool bench,
const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches,
int n_warmup,
int n_iter,
int instance_index,
bool fail_if_no_supported_instances)
{
return ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType,
EDataType,
AccDataType,
ALayout,
BLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp>(
verify,
init_method,
log,
bench,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup,
n_iter,
instance_index,
fail_if_no_supported_instances);
}
};
struct FixedNKGroupedGemmProfiler
{
template <typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout>
static bool Run(bool verify,
int init_method,
bool log,
bool bench,
const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches,
int n_warmup,
int n_iter,
int /*instance_index*/,
bool /*fail_if_no_supported_instances*/)
{
bool pass = true;
for(int kbatch : kbatches)
{
try
{
pass &= ck::profiler::profile_grouped_gemm_fixed_nk_impl<ADataType,
BDataType,
EDataType,
AccDataType,
ALayout,
BLayout,
CLayout>(verify,
init_method,
log,
bench,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
catch(const std::exception& e)
{
std::cerr << e.what() << std::endl;
}
}
return pass;
}
};
template <typename Tuple,
bool FailIfNoSupportedInstances = false,
typename Profiler = ck::test::DefaultGroupedGemmProfiler>
class TestGroupedGemm : public testing::Test
{
protected:
@@ -76,7 +194,7 @@ class TestGroupedGemm : public testing::Test
}
else
{
k_batches_ = {1, 2, 3, 5, 8};
k_batches_ = {1, 2, 3, 4, 8};
}
}
@@ -146,31 +264,61 @@ class TestGroupedGemm : public testing::Test
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches)
{
bool pass =
ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType,
EDataType,
float,
ALayout,
BLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup_,
n_iter_,
instance_index,
fail_if_no_supported_instances_);
bool pass = false;
using AccDataType = float;
if constexpr(std::is_same_v<Profiler, FixedNKGroupedGemmProfiler>)
{
pass = Profiler::template Run<ADataType,
BDataType,
EDataType,
AccDataType,
ALayout,
BLayout,
ELayout>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup_,
n_iter_,
instance_index,
fail_if_no_supported_instances_);
}
else
{
pass = Profiler::template Run<ADataType,
BDataType,
EDataType,
AccDataType,
ALayout,
BLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup_,
n_iter_,
instance_index,
fail_if_no_supported_instances_);
}
EXPECT_TRUE(pass);
}
};