mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Implement grouped gemm tile loop for RDNA4 (#3304)
* feat: grouped gemm tile loop support for RDNA4
* fix: removed extra parameter from grouped gemm example instance
* fix: FP8 check incorrectly enabling FP8 on RDNA3
[ROCm/composable_kernel commit: eb041079a3]
This commit is contained in:
@@ -44,6 +44,9 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl
|
||||
add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16)
|
||||
|
||||
add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16)
|
||||
|
||||
list(APPEND gpu_list_tf32 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/tuple.hpp>
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::hip_check_error;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddAdd = ck::tensor_operation::element_wise::AddAdd;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = F16;
|
||||
using DsDataType = ck::Tuple<DDataType, DDataType>;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DLayout = Row;
|
||||
using DsLayout = ck::Tuple<DLayout, DLayout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AddAdd;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
static constexpr int NumDs = 2;
|
||||
|
||||
using DeviceGemmInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<4, 4, 4>>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_gemm_multiple_d_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
@@ -71,339 +71,6 @@ using DeviceGemmInstance =
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>;
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
#include "run_grouped_gemm_multiple_d_example.inc"
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<std::vector<ck::index_t>> stride_Ds;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
|
||||
using GemmDesc = ck::tensor_operation::device::GemmDesc;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<GemmDesc> gemm_descs;
|
||||
std::vector<KernelArguments> ggemm_kargs;
|
||||
std::vector<void*> p_Cs;
|
||||
std::vector<const void*> p_As;
|
||||
std::vector<const void*> p_Bs;
|
||||
std::vector<std::array<const void*, NumDs>> p_Ds = {};
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
ggemm_kargs.reserve(group_count);
|
||||
p_As.reserve(group_count);
|
||||
p_Bs.reserve(group_count);
|
||||
p_Ds.reserve(group_count);
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Tensor<ADataType>> a_tensors;
|
||||
std::vector<Tensor<BDataType>> b_tensors;
|
||||
std::vector<std::array<Tensor<DDataType>, NumDs>> d_tensors;
|
||||
std::vector<Tensor<EDataType>> c_host_tensors;
|
||||
std::vector<Tensor<EDataType>> c_device_result_tensors;
|
||||
|
||||
a_tensors.reserve(group_count);
|
||||
b_tensors.reserve(group_count);
|
||||
d_tensors.reserve(group_count);
|
||||
c_host_tensors.reserve(group_count);
|
||||
c_device_result_tensors.reserve(group_count);
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
|
||||
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
|
||||
|
||||
a_tensors_device.reserve(group_count);
|
||||
b_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
d_tensors_device.resize(group_count); // reserve and update vector size
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
|
||||
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
|
||||
|
||||
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
|
||||
std::array<Tensor<DDataType>, NumDs> d_tens = {d0_tensor, d1_tensor};
|
||||
d_tensors.push_back(d_tens);
|
||||
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
|
||||
|
||||
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
|
||||
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
|
||||
sizeof(BDataType) * b_tensors[i].GetElementSize() +
|
||||
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs +
|
||||
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
}
|
||||
break;
|
||||
case 2:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
|
||||
b_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
|
||||
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
|
||||
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
|
||||
}
|
||||
|
||||
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
|
||||
}
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Ds.push_back(
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
|
||||
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
|
||||
|
||||
// The device op does not have to know M problem size at lunch time.
|
||||
gemm_descs.push_back({0,
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
problem_size.stride_Cs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]}});
|
||||
ggemm_kargs.push_back(
|
||||
{a_tensors_device[i]->GetDeviceBuffer(),
|
||||
b_tensors_device[i]->GetDeviceBuffer(),
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
|
||||
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
|
||||
ggemm_kargs.data(),
|
||||
gemm.GetDeviceKernelArgSize(&argument),
|
||||
hipMemcpyHostToDevice));
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
auto karg = ggemm_kargs[i];
|
||||
auto dev_res_tensor =
|
||||
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{}));
|
||||
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data());
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
|
||||
b_tensors[i],
|
||||
d_tensors[i],
|
||||
c_host_tensors[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
|
||||
}
|
||||
|
||||
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::vector<int> argToIntArray(char* input)
|
||||
{
|
||||
std::vector<int> out;
|
||||
std::istringstream in(input);
|
||||
std::string item;
|
||||
|
||||
while(std::getline(in, item, ','))
|
||||
{
|
||||
out.push_back(std::stoi(item));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
if(argc < 10)
|
||||
{
|
||||
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
|
||||
problem_size.group_count = Ms.size();
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(Ms[i]);
|
||||
problem_size.Ns.push_back(252);
|
||||
problem_size.Ks.push_back(4608);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
|
||||
problem_size.stride_Ds.push_back({});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "Usage:\n"
|
||||
<< "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
|
||||
"64,64 64,64 128,128)\n"
|
||||
<< "... setting default values." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.Ms = argToIntArray(argv[4]);
|
||||
problem_size.Ns = argToIntArray(argv[5]);
|
||||
problem_size.Ks = argToIntArray(argv[6]);
|
||||
|
||||
problem_size.stride_As = argToIntArray(argv[7]);
|
||||
problem_size.stride_Bs = argToIntArray(argv[8]);
|
||||
problem_size.stride_Cs = argToIntArray(argv[9]);
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
|
||||
}
|
||||
|
||||
problem_size.group_count = problem_size.Ms.size();
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
|
||||
@@ -58,11 +58,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -57,11 +57,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -323,8 +323,8 @@ bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: async hargs (0=n0, 1=yes)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4: async hargs (0=no, 1=yes)\n");
|
||||
printf("arg5: group count (default=16)\n");
|
||||
#if defined(EXAMPLE_USE_SPLITK)
|
||||
printf("arg6: k-batch count (default=1)\n");
|
||||
|
||||
341
example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc
Normal file
341
example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc
Normal file
@@ -0,0 +1,341 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<std::vector<ck::index_t>> stride_Ds;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
|
||||
using GemmDesc = ck::tensor_operation::device::GemmDesc;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<GemmDesc> gemm_descs;
|
||||
std::vector<KernelArguments> ggemm_kargs;
|
||||
std::vector<void*> p_Cs;
|
||||
std::vector<const void*> p_As;
|
||||
std::vector<const void*> p_Bs;
|
||||
std::vector<std::array<const void*, NumDs>> p_Ds = {};
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
ggemm_kargs.reserve(group_count);
|
||||
p_As.reserve(group_count);
|
||||
p_Bs.reserve(group_count);
|
||||
p_Ds.reserve(group_count);
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Tensor<ADataType>> a_tensors;
|
||||
std::vector<Tensor<BDataType>> b_tensors;
|
||||
std::vector<std::array<Tensor<DDataType>, NumDs>> d_tensors;
|
||||
std::vector<Tensor<EDataType>> c_host_tensors;
|
||||
std::vector<Tensor<EDataType>> c_device_result_tensors;
|
||||
|
||||
a_tensors.reserve(group_count);
|
||||
b_tensors.reserve(group_count);
|
||||
d_tensors.reserve(group_count);
|
||||
c_host_tensors.reserve(group_count);
|
||||
c_device_result_tensors.reserve(group_count);
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
|
||||
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
|
||||
|
||||
a_tensors_device.reserve(group_count);
|
||||
b_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
d_tensors_device.resize(group_count); // reserve and update vector size
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
|
||||
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
|
||||
|
||||
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
|
||||
std::array<Tensor<DDataType>, NumDs> d_tens = {d0_tensor, d1_tensor};
|
||||
d_tensors.push_back(d_tens);
|
||||
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
|
||||
|
||||
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
|
||||
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
|
||||
sizeof(BDataType) * b_tensors[i].GetElementSize() +
|
||||
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs +
|
||||
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
}
|
||||
break;
|
||||
case 2:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
|
||||
b_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
|
||||
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
|
||||
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
|
||||
}
|
||||
|
||||
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
|
||||
}
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Ds.push_back(
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
|
||||
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
|
||||
|
||||
// The device op does not have to know M problem size at lunch time.
|
||||
gemm_descs.push_back({0,
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
problem_size.stride_Cs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]}});
|
||||
ggemm_kargs.push_back(
|
||||
{a_tensors_device[i]->GetDeviceBuffer(),
|
||||
b_tensors_device[i]->GetDeviceBuffer(),
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
|
||||
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
|
||||
ggemm_kargs.data(),
|
||||
gemm.GetDeviceKernelArgSize(&argument),
|
||||
hipMemcpyHostToDevice));
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
auto karg = ggemm_kargs[i];
|
||||
auto dev_res_tensor =
|
||||
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{}));
|
||||
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data());
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
|
||||
b_tensors[i],
|
||||
d_tensors[i],
|
||||
c_host_tensors[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
|
||||
}
|
||||
|
||||
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::vector<int> argToIntArray(char* input)
|
||||
{
|
||||
std::vector<int> out;
|
||||
std::istringstream in(input);
|
||||
std::string item;
|
||||
|
||||
while(std::getline(in, item, ','))
|
||||
{
|
||||
out.push_back(std::stoi(item));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
if(argc < 10)
|
||||
{
|
||||
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
|
||||
problem_size.group_count = Ms.size();
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(Ms[i]);
|
||||
problem_size.Ns.push_back(252);
|
||||
problem_size.Ks.push_back(4608);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
|
||||
problem_size.stride_Ds.push_back({});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "Usage:\n"
|
||||
<< "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
|
||||
"64,64 64,64 128,128)\n"
|
||||
<< "... setting default values." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.Ms = argToIntArray(argv[4]);
|
||||
problem_size.Ns = argToIntArray(argv[5]);
|
||||
problem_size.Ks = argToIntArray(argv[6]);
|
||||
|
||||
problem_size.stride_As = argToIntArray(argv[7]);
|
||||
problem_size.stride_Bs = argToIntArray(argv[8]);
|
||||
problem_size.stride_Cs = argToIntArray(argv[9]);
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
|
||||
}
|
||||
|
||||
problem_size.group_count = problem_size.Ms.size();
|
||||
}
|
||||
|
||||
return run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
@@ -151,7 +151,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
|
||||
static bool __host__ __device__ BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
static TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
@@ -707,7 +710,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
|
||||
__host__ __device__ static bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
static TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
|
||||
@@ -3,6 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/stream_utility.hpp"
|
||||
|
||||
#include "device_grouped_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -43,6 +48,59 @@ struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
|
||||
{
|
||||
};
|
||||
|
||||
template <ck::index_t BlockSize>
|
||||
struct TileLoopKernelConfig
|
||||
{
|
||||
// The oversubscription factor for the number of blocks that can simultaneously reside on
|
||||
// GPU.
|
||||
static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
|
||||
// static constexpr int BLOCK_WAVES = BlockSize / get_warp_size();
|
||||
static constexpr int CU_SIMDS = 4;
|
||||
// Assume we want to have at most 2 waves per SIMD
|
||||
// static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
|
||||
static int GetCuBlocks()
|
||||
{
|
||||
int BLOCK_WAVES = BlockSize / get_warp_size();
|
||||
return ck::math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
|
||||
}
|
||||
|
||||
template <typename KernelFunction>
|
||||
static int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
|
||||
const StreamConfig& stream_config)
|
||||
{
|
||||
// Calculate max number of workgroups that can simultaneously reside on the CU.
|
||||
int occ_num_blocks = GetKernelOccupancy(kernel);
|
||||
int cu_count = getAvailableComputeUnitCount(stream_config);
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks
|
||||
<< ", available CUs count: " << cu_count << ", occup. grid size: "
|
||||
<< ck::math::min(occ_num_blocks, GetCuBlocks()) * cu_count << std::endl;
|
||||
}
|
||||
|
||||
return cu_count * ck::math::min(occ_num_blocks, GetCuBlocks());
|
||||
}
|
||||
|
||||
template <typename KernelFunction>
|
||||
static int GetKernelOccupancy(const KernelFunction& kernel)
|
||||
{
|
||||
int occupancy = 0;
|
||||
ck::hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
|
||||
return occupancy;
|
||||
}
|
||||
|
||||
static int GetComputeUnitCount()
|
||||
{
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
ck::hip_check_error(hipGetDevice(&dev));
|
||||
ck::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -0,0 +1,689 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/host_utility/stream_utility.hpp"
|
||||
#include "ck/utility/loop_scheduler.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
///
|
||||
/// @brief Entry point kernel for device-wide Grouped GEMM operation.
|
||||
///
|
||||
/// @param[in] gemm_descs_const The pointer to the array of GEMM descriptor structures.
|
||||
/// @param[in] group_count The number of together processed GEMMs.
|
||||
///
|
||||
/// @tparam GridwiseGemm The specific GridwiseGEMM algorithm implementation.
|
||||
/// @tparam GemmDesc The structure holding all necessary descriptors and
|
||||
/// other data needed for grouped gemm calculation and work
|
||||
/// distribution.
|
||||
/// @tparam LocalBlock2ETileMap The structure providing mapping between workgroup ids,
|
||||
/// the data tiles to process and the output tiles.
|
||||
///
|
||||
template <typename GridwiseGemm,
|
||||
typename GemmDesc,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
index_t KPerBlock,
|
||||
typename OffsettedBlockToCTileMap,
|
||||
typename LocalBlock2ETileMap,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_grouped_gemm_multiple_d_wmma(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
const index_t group_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ uint8_t p_shared[LDS_size];
|
||||
|
||||
const auto gemm_desc_ptr =
|
||||
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
|
||||
constexpr auto NumDTensor = DsDataType::Size();
|
||||
index_t tile_id = get_block_1d_id();
|
||||
index_t tile_offset = 0;
|
||||
index_t group_id = -1;
|
||||
index_t group_offset = 0;
|
||||
index_t grid_size_grp = 0;
|
||||
|
||||
index_t gemm_tile_id_start = 0;
|
||||
index_t gemm_tile_id_end = 0;
|
||||
|
||||
index_t M = 0, N = 0, K = 0;
|
||||
|
||||
auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1);
|
||||
|
||||
do
|
||||
{
|
||||
// Find corresponding GEMM group for our tile
|
||||
while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) &&
|
||||
group_id < group_count)
|
||||
{
|
||||
group_offset += grid_size_grp;
|
||||
group_id++;
|
||||
|
||||
if(group_id >= group_count)
|
||||
return;
|
||||
|
||||
M = gemm_desc_ptr[group_id].M;
|
||||
N = gemm_desc_ptr[group_id].N;
|
||||
K = gemm_desc_ptr[group_id].K;
|
||||
|
||||
if(M == 0 || N == 0 || K == 0)
|
||||
{
|
||||
grid_size_grp = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
b2c_tile_map =
|
||||
OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset);
|
||||
grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
|
||||
|
||||
gemm_tile_id_start = group_offset;
|
||||
gemm_tile_id_end = group_offset + grid_size_grp;
|
||||
}
|
||||
|
||||
// Create A&B grid pointer containing their single tensors
|
||||
typename GridwiseGemm::AsGridPointer p_as_grid = Tuple<const ADataType*>(
|
||||
static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid));
|
||||
typename GridwiseGemm::BsGridPointer p_bs_grid = Tuple<const BDataType*>(
|
||||
static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid));
|
||||
|
||||
// Make a DsGridPointer instance containing all D tensors
|
||||
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
|
||||
DsGridPointer p_ds_grid;
|
||||
std::array<index_t, NumDTensor> stride_Ds;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
|
||||
stride_Ds[i] = gemm_desc_ptr[group_id].StrideDs[i];
|
||||
});
|
||||
|
||||
index_t K_split = ck::math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
// Update tile offset if we have moved within group
|
||||
b2c_tile_map.UpdateTileOffset(tile_offset);
|
||||
|
||||
using Problem = typename GridwiseGemm::Problem;
|
||||
auto problem = Problem(gemm_desc_ptr[group_id].M,
|
||||
gemm_desc_ptr[group_id].N,
|
||||
gemm_desc_ptr[group_id].K,
|
||||
std::array<index_t, 1>{gemm_desc_ptr[group_id].StrideA},
|
||||
std::array<index_t, 1>{gemm_desc_ptr[group_id].StrideB},
|
||||
stride_Ds,
|
||||
gemm_desc_ptr[group_id].StrideE,
|
||||
1);
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
constexpr TailNumber TailNum = TailNumber::Full;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
GridwiseGemm::template Run<true, InMemoryDataOperationEnum::Set, TailNum>(
|
||||
p_as_grid,
|
||||
p_bs_grid,
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
b2c_tile_map,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
GridwiseGemm::template Run<false, InMemoryDataOperationEnum::Set, TailNum>(
|
||||
p_as_grid,
|
||||
p_bs_grid,
|
||||
p_ds_grid,
|
||||
static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
|
||||
static_cast<void*>(p_shared),
|
||||
problem,
|
||||
b2c_tile_map,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
}
|
||||
|
||||
tile_id += get_grid_size();
|
||||
tile_offset += get_grid_size();
|
||||
|
||||
} while(group_id < group_count);
|
||||
#else
|
||||
ignore = gemm_descs_const;
|
||||
ignore = group_count;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
#endif // end of if (defined(__gfx11__) || defined(__gfx12__))
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
index_t ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = EDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
|
||||
struct DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
|
||||
: public DeviceGroupedGemmTileLoop<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
false, // PermuteA not supported by GridwiseOp.
|
||||
false>; // PermuteB not supported by DeviceGroupedGemmTileLoop base class.
|
||||
|
||||
using KernelConfig = TileLoopKernelConfig<BlockSize>;
|
||||
using KernelArguments = GroupedGemmKernelArgument<NumDTensor>;
|
||||
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(std::vector<const void*>& /* p_As */,
|
||||
std::vector<const void*>& /* p_Bs */,
|
||||
std::vector<std::array<const void*, NumDTensor>>& /* p_Ds */,
|
||||
std::vector<void*>& /* p_Es */,
|
||||
const std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
int occupancy_num_blocks,
|
||||
int gpu_cu_count)
|
||||
: group_count_{static_cast<index_t>(gemm_descs.size())},
|
||||
occupancy_num_blocks_{occupancy_num_blocks},
|
||||
gpu_cu_count_{gpu_cu_count},
|
||||
gemm_descs_{gemm_descs},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
tile_count_{0}
|
||||
{
|
||||
for(const auto& desc : gemm_descs)
|
||||
{
|
||||
const auto M = desc.M_;
|
||||
const auto N = desc.N_;
|
||||
const auto b2c_tile_map = Block2ETileMap(M, N);
|
||||
tile_count_ += b2c_tile_map.CalculateGridSize(M, N);
|
||||
}
|
||||
}
|
||||
|
||||
index_t group_count_;
|
||||
const void* p_dev_gemm_args_;
|
||||
int occupancy_num_blocks_;
|
||||
int gpu_cu_count_;
|
||||
const std::vector<GemmDesc>& gemm_descs_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
index_t tile_count_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
///
|
||||
/// @brief Launch Grouped Gemm kernel.
|
||||
///
|
||||
/// @note This function overload is using user provided device buffer for kernel
|
||||
/// arguments.
|
||||
///
|
||||
/// @param[in] arg The structure containing kernel arguments (in host
|
||||
/// memory).
|
||||
/// @param[in] dev_gemm_args The pointer to device memory with kernel arguments.
|
||||
/// @param[in] stream_config The device stream configuration.
|
||||
///
|
||||
/// @return The average kernel execution time (if time measurement is enabled.)
|
||||
///
|
||||
float Run(const Argument& arg,
|
||||
const void* dev_gemm_args,
|
||||
const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(dev_gemm_args == nullptr)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
const auto kernel = GetKernelFunction();
|
||||
|
||||
int grid_size = KernelConfig::CalculateMaxOccupancyGridSize(kernel, stream_config);
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
std::cout << "grid_size: " << grid_size << " tile_count: " << arg.tile_count_
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// run multiple kernels
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(dev_gemm_args),
|
||||
arg.group_count_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_);
|
||||
}
|
||||
|
||||
///
|
||||
/// @brief Launch Grouped Gemm kernel.
|
||||
///
|
||||
/// @note This function overload is using device buffers (for kernel arguments and
|
||||
/// for kernel auxiliary workspace) provided with an argument. The user should
|
||||
/// call @see GetDeviceKernelArgSize, and @see SetDeviceKernelArgs, on arg
|
||||
/// parameter to properly allocate those buffers.
|
||||
///
|
||||
/// @param[in] arg The structure containing kernel arguments (in host memory).
|
||||
/// @param[in] stream_config The device stream configuration.
|
||||
///
|
||||
/// @return The average kernel execution time (if time measurement is enabled.)
|
||||
///
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(arg.p_dev_gemm_args_ == nullptr)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
return Run(arg, arg.p_dev_gemm_args_, stream_config);
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static auto GetKernelFunction()
|
||||
{
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_wmma<GridwiseGemm,
|
||||
KernelArguments,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
return kernel;
|
||||
}
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
|
||||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool supported = true;
|
||||
for(index_t i = 0; i < arg.group_count_; ++i)
|
||||
{
|
||||
std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
|
||||
std::array<index_t, NumDTensor> stride_Ds;
|
||||
std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());
|
||||
|
||||
typename GridwiseGemm::Argument gridwise_arg(
|
||||
std::array<const void*, 1>{nullptr}, // p_a_grid,
|
||||
std::array<const void*, 1>{nullptr}, // p_b_grid,
|
||||
placeholder_p_ds_grid, // p_ds_grid,
|
||||
nullptr, // p_e_grid ,
|
||||
arg.gemm_descs_[i].M_,
|
||||
arg.gemm_descs_[i].N_,
|
||||
arg.gemm_descs_[i].K_,
|
||||
std::array<index_t, 1>{arg.gemm_descs_[i].stride_A_},
|
||||
std::array<index_t, 1>{arg.gemm_descs_[i].stride_B_},
|
||||
stride_Ds,
|
||||
arg.gemm_descs_[i].stride_C_,
|
||||
1, // KBatch
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
false);
|
||||
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(gridwise_arg);
|
||||
supported = supported && group_arg_valid;
|
||||
|
||||
if(!group_arg_valid)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " has invalid GridwiseGemm settings!" << std::endl;
|
||||
gridwise_arg.Print();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static int GetKernelOccupancy()
|
||||
{
|
||||
const auto kernel = GetKernelFunction();
|
||||
return KernelConfig::GetKernelOccupancy(kernel);
|
||||
}
|
||||
|
||||
static auto MakeArgument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_elementwise_op,
|
||||
BElementwiseOperation b_elementwise_op,
|
||||
CDEElementwiseOperation cde_elementwise_op)
|
||||
{
|
||||
int occupancy = GetKernelOccupancy();
|
||||
int num_cu = KernelConfig::GetComputeUnitCount();
|
||||
|
||||
return Argument{p_As,
|
||||
p_Bs,
|
||||
p_Ds,
|
||||
p_Es,
|
||||
gemm_descs,
|
||||
a_elementwise_op,
|
||||
b_elementwise_op,
|
||||
cde_elementwise_op,
|
||||
occupancy,
|
||||
num_cu};
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_elementwise_op,
|
||||
BElementwiseOperation b_elementwise_op,
|
||||
CDEElementwiseOperation cde_elementwise_op) override
|
||||
{
|
||||
int occupancy = GetKernelOccupancy();
|
||||
int num_cu = KernelConfig::GetComputeUnitCount();
|
||||
|
||||
return std::make_unique<Argument>(p_As,
|
||||
p_Bs,
|
||||
p_Ds,
|
||||
p_Es,
|
||||
gemm_descs,
|
||||
a_elementwise_op,
|
||||
b_elementwise_op,
|
||||
cde_elementwise_op,
|
||||
occupancy,
|
||||
num_cu);
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::ostringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3"
|
||||
<< "<"
|
||||
<< std::string(ALayout::name)[0] << ","
|
||||
<< std::string(BLayout::name)[0] << ","
|
||||
<< std::string(ELayout::name)[0] << ","
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< MPerWmma << ", "
|
||||
<< NPerWmma << ", "
|
||||
<< MRepeat << ", "
|
||||
<< NRepeat << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< CShuffleMRepeatPerShuffle << ", "
|
||||
<< CShuffleNRepeatPerShuffle << ", "
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
void SetDeviceKernelArgs(Argument& arg,
|
||||
void* p_dev_kernel_args,
|
||||
const void* p_host_kernel_args) const
|
||||
{
|
||||
arg.p_dev_gemm_args_ = p_dev_kernel_args;
|
||||
hip_check_error(hipMemcpyAsync(p_dev_kernel_args,
|
||||
p_host_kernel_args,
|
||||
GetDeviceKernelArgSize(&arg),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
|
||||
void* p_dev_kernel_args,
|
||||
const void* p_host_kernel_args) const override
|
||||
{
|
||||
return SetDeviceKernelArgs(
|
||||
*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args, p_host_kernel_args);
|
||||
}
|
||||
|
||||
void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
|
||||
{
|
||||
arg.p_dev_gemm_args_ = p_dev_kernel_args;
|
||||
}
|
||||
|
||||
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
|
||||
{
|
||||
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
|
||||
}
|
||||
|
||||
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(KernelArguments);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
|
||||
@@ -26,6 +27,18 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Dummy kernel to use as a fallback in the kernel selection logic
|
||||
// Is not used in practice, but only used in case of misconfigured parameters
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
__global__ void kernel_dummy(const void CK_CONSTANT_ADDRESS_SPACE*,
|
||||
const index_t,
|
||||
const AElementwiseOperation,
|
||||
const BElementwiseOperation,
|
||||
const CDEElementwiseOperation)
|
||||
{
|
||||
}
|
||||
///
|
||||
/// @brief Entry point kernel for device-wide Grouped GEMM operation.
|
||||
///
|
||||
@@ -528,6 +541,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
|
||||
using KernelConfig = TileLoopKernelConfig<BlockSize>;
|
||||
using KernelArguments = GroupedGemmKernelArgument<NumDTensor>;
|
||||
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
|
||||
@@ -574,22 +588,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
index_t tile_count_;
|
||||
};
|
||||
|
||||
struct KernelConfig
|
||||
{
|
||||
// The oversubscription factor for the number of blocks that can simultaneously reside on
|
||||
// GPU.
|
||||
static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
|
||||
// static constexpr int BLOCK_WAVES = BlockSize / get_warp_size();
|
||||
static constexpr int CU_SIMDS = 4;
|
||||
// Assume we want to have at most 2 waves per SIMD
|
||||
// static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
|
||||
static int GetCuBlocks()
|
||||
{
|
||||
int BLOCK_WAVES = BlockSize / get_warp_size();
|
||||
return math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
|
||||
}
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
@@ -666,58 +664,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
const void* dev_gemm_args,
|
||||
const StreamConfig& stream_config) const
|
||||
{
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
const auto kernel = GetKernelFunction<GridwiseGemm>();
|
||||
return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
|
||||
}
|
||||
|
||||
template <typename KernelFunction>
|
||||
int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
|
||||
const StreamConfig& stream_config) const
|
||||
{
|
||||
// Calculate max number of workgroups that can simultaneously reside on the CU.
|
||||
int occ_num_blocks = 0;
|
||||
size_t dyn_shared_mem_per_blk = 0;
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));
|
||||
|
||||
int cu_count = getAvailableComputeUnitCount(stream_config);
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks
|
||||
<< ", available CUs count: " << cu_count << ", occup. grid size: "
|
||||
<< ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks()) * cu_count
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return cu_count * ck::math::min(occ_num_blocks, KernelConfig::GetCuBlocks());
|
||||
}
|
||||
|
||||
template <typename KernelFunction>
|
||||
float LaunchKernel(const KernelFunction& kernel,
|
||||
const Argument& arg,
|
||||
const void* dev_gemm_args,
|
||||
const StreamConfig& stream_config) const
|
||||
{
|
||||
int grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config);
|
||||
int grid_size = KernelConfig::CalculateMaxOccupancyGridSize(kernel, stream_config);
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
@@ -835,65 +792,60 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static int GetKernelOccupancy()
|
||||
template <typename GridwiseGemm>
|
||||
static auto GetKernelFunction()
|
||||
{
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
return kernel;
|
||||
}
|
||||
|
||||
static auto GetKernelFunction()
|
||||
{
|
||||
int occupancy = 0;
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
{
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm64,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
|
||||
const auto kernel = GetKernelFunction<GridwiseGemm64>();
|
||||
return kernel;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
if constexpr(NXdlPerWave32 > 0)
|
||||
{
|
||||
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm32,
|
||||
KernelArguments,
|
||||
GemmSpec,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
KPerBlock,
|
||||
OffsettedLocalBlock2ETileMap,
|
||||
Block2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
|
||||
const auto kernel = GetKernelFunction<GridwiseGemm32>();
|
||||
return kernel;
|
||||
}
|
||||
}
|
||||
return occupancy;
|
||||
|
||||
// This is here to handle the case where MXdlPerWave/NxdPerWave is too small
|
||||
// This is caught by IsSupportedArgument(), but as GetKernelFunction is sometimes called
|
||||
// before we need a fallback kernel to return here.
|
||||
return kernel_dummy<AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation>;
|
||||
}
|
||||
|
||||
static int GetKernelOccupancy()
|
||||
{
|
||||
const auto kernel = GetKernelFunction();
|
||||
return KernelConfig::GetKernelOccupancy(kernel);
|
||||
}
|
||||
|
||||
static auto MakeArgument(std::vector<const void*>& p_As,
|
||||
@@ -906,13 +858,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
CDEElementwiseOperation cde_elementwise_op)
|
||||
{
|
||||
int occupancy = GetKernelOccupancy();
|
||||
int num_cu;
|
||||
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
hip_check_error(hipGetDevice(&dev));
|
||||
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
num_cu = dev_prop.multiProcessorCount;
|
||||
int num_cu = KernelConfig::GetComputeUnitCount();
|
||||
|
||||
return Argument{p_As,
|
||||
p_Bs,
|
||||
@@ -937,13 +883,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
CDEElementwiseOperation cde_elementwise_op) override
|
||||
{
|
||||
int occupancy = GetKernelOccupancy();
|
||||
int num_cu;
|
||||
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
hip_check_error(hipGetDevice(&dev));
|
||||
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
num_cu = dev_prop.multiProcessorCount;
|
||||
int num_cu = KernelConfig::GetComputeUnitCount();
|
||||
|
||||
return std::make_unique<Argument>(p_As,
|
||||
p_Bs,
|
||||
|
||||
@@ -126,7 +126,6 @@ template <typename ALayout,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
@@ -158,9 +157,7 @@ template <typename ALayout,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = EDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
@@ -231,8 +228,8 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
false, // PermuteA not supported by DeviceBatchedGemm base class.
|
||||
false>; // PermuteB not supported by DeviceBatchedGemm base class.
|
||||
false, // PermuteA not supported by GridwiseOp
|
||||
false>; // PermuteB not supported by DeviceGroupedGemm base class
|
||||
|
||||
using CGridDesc_M_N =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
@@ -779,7 +776,7 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayou
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedGemm_WmmaSplitK"
|
||||
str << "DeviceGroupedGemm_Wmma_CShuffleV3"
|
||||
<< "<"
|
||||
<< std::string(ALayout::name)[0] << ","
|
||||
<< std::string(BLayout::name)[0] << ","
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/quantization_operation.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -236,8 +237,9 @@ struct MultiplyAdd
|
||||
const half_t& d0,
|
||||
const half_t& d1) const
|
||||
{
|
||||
const half_t y = type_convert<half_t>(c) * d0 + d1;
|
||||
e = y;
|
||||
const half_t y =
|
||||
type_convert<half_t>(c * type_convert<float>(d0) + type_convert<float>(d1));
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
|
||||
@@ -245,8 +247,9 @@ struct MultiplyAdd
|
||||
const bhalf_t& d0,
|
||||
const bhalf_t& d1) const
|
||||
{
|
||||
const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
|
||||
e = y;
|
||||
const bhalf_t y =
|
||||
type_convert<bhalf_t>(c * type_convert<float>(d0) + type_convert<float>(d1));
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
|
||||
|
||||
@@ -334,14 +334,14 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
struct Problem
|
||||
{
|
||||
__host__ Problem() = default;
|
||||
__host__ Problem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
std::array<index_t, NumATensor> StrideAs_,
|
||||
std::array<index_t, NumBTensor> StrideBs_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t KBatch_)
|
||||
__host__ __device__ Problem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
std::array<index_t, NumATensor> StrideAs_,
|
||||
std::array<index_t, NumBTensor> StrideBs_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t KBatch_)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
|
||||
@@ -351,64 +351,65 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
// Calculate grid size taking into account splitk (KBatch)
|
||||
// 2D grid (x,z)
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
|
||||
__host__ __device__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
// Calculate grid size taking into account splitk (KBatch) and multiple groups (Batch)
|
||||
// 3D grid (x,y,z)
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
|
||||
__host__ __device__ static auto
|
||||
CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
|
||||
{
|
||||
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateMPadded(index_t M)
|
||||
__host__ __device__ static auto CalculateMPadded(index_t M)
|
||||
{
|
||||
return math::integer_least_multiple(M, MPerBlock);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateNPadded(index_t N)
|
||||
__host__ __device__ static auto CalculateNPadded(index_t N)
|
||||
{
|
||||
return math::integer_least_multiple(N, NPerBlock);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateKPadded(index_t K)
|
||||
__host__ __device__ static auto CalculateKPadded(index_t K)
|
||||
{
|
||||
return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
|
||||
}
|
||||
|
||||
__host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * KPerBlock;
|
||||
}
|
||||
|
||||
__host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
|
||||
auto K_t = K_Batch * KReadVec;
|
||||
return (K + K_t - 1) / K_t * KReadVec;
|
||||
}
|
||||
|
||||
__host__ static auto CalculateMBlock(index_t M)
|
||||
__host__ __device__ static auto CalculateMBlock(index_t M)
|
||||
{
|
||||
return math::integer_divide_ceil(M, MPerBlock);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateNBlock(index_t N)
|
||||
__host__ __device__ static auto CalculateNBlock(index_t N)
|
||||
{
|
||||
return math::integer_divide_ceil(N, NPerBlock);
|
||||
}
|
||||
@@ -963,14 +964,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / KPerBlock;
|
||||
|
||||
return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
|
||||
__host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / KPerBlock;
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
#include "ck/utility/enable_if.hpp"
|
||||
#include <tuple>
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -220,4 +221,49 @@ constexpr Tuple<Args&...> tie(Args&... args) noexcept
|
||||
return {args...};
|
||||
}
|
||||
|
||||
//
|
||||
// tuple_map: Map tuple with a different type
|
||||
// e.g. tuple_map<Wrapper, Tuple<T1, T2, T3>> becomes Tuple<Wrapper<T1>, Wrapper<T2>, Wrapper<T3>>
|
||||
//
|
||||
template <template <typename> class Wrapper, typename Tuple>
|
||||
struct tuple_map;
|
||||
|
||||
template <template <typename> class Wrapper, typename... Ts>
|
||||
struct tuple_map<Wrapper, Tuple<Ts...>>
|
||||
{
|
||||
using type = Tuple<Wrapper<Ts>...>;
|
||||
};
|
||||
|
||||
template <template <typename> class Wrapper, typename Tuple>
|
||||
using tuple_map_t = typename tuple_map<Wrapper, Tuple>::type;
|
||||
|
||||
//
|
||||
// tuple_element_or: helper to access type element of a tuple by index, with the option to default
|
||||
// to a type if the index is out of range of the tuple size
|
||||
//
|
||||
namespace detail {
|
||||
|
||||
// Base template (will be specialized on the boolean)
|
||||
template <ck::index_t N, typename Tuple, typename Default, bool InRange = (N < Tuple::Size())>
|
||||
struct tuple_element_or_impl;
|
||||
|
||||
// Specialization for the in-range case: use tuple_element_t
|
||||
template <ck::index_t N, typename Tuple, typename Default>
|
||||
struct tuple_element_or_impl<N, Tuple, Default, true>
|
||||
{
|
||||
using type = tuple_element_t<N, Tuple>;
|
||||
};
|
||||
|
||||
// Specialization for the out-of-range case: use Default
|
||||
template <ck::index_t N, typename Tuple, typename Default>
|
||||
struct tuple_element_or_impl<N, Tuple, Default, false>
|
||||
{
|
||||
using type = Default;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
// User-facing alias
|
||||
template <ck::index_t N, typename Tuple, typename Default>
|
||||
using tuple_element_or_t = typename detail::tuple_element_or_impl<N, Tuple, Default>::type;
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -69,7 +69,7 @@ void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif // CK_ENABLE_FP16
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(__gfx12__)
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(CK_USE_WMMA_FP8)
|
||||
void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
@@ -572,7 +572,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
}
|
||||
}
|
||||
#endif // CK_ENABLE_FP16
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(__gfx12__)
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(CK_USE_WMMA_FP8)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
{
|
||||
|
||||
@@ -55,16 +55,15 @@ template <typename T,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_wmma_universal_km_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
using device_grouped_gemm_wmma_universal_km_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang`-format on
|
||||
>;
|
||||
|
||||
@@ -79,15 +78,15 @@ template <typename T,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_wmma_universal_km_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Instances for 2 byte datatypes in RRR layout with ADataType = BDataType = EDataType
|
||||
template <typename T,
|
||||
@@ -98,18 +97,17 @@ template <typename T,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_wmma_universal_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
using device_grouped_gemm_wmma_universal_mk_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Instances for 2 byte datatypes in RCR layout with ADataType = BDataType = EDataType
|
||||
template <typename T,
|
||||
@@ -120,18 +118,17 @@ template <typename T,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_wmma_universal_mk_nk_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
using device_grouped_gemm_wmma_universal_mk_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// List of instance variants to add (pipeline/scheduler/padding combinations)
|
||||
// Some are disabled now, can be re-enabled if needed
|
||||
|
||||
@@ -19,6 +19,7 @@ namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
// fp16_output
|
||||
#ifdef CK_USE_XDL
|
||||
void add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
@@ -45,6 +46,34 @@ void add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
void add_device_grouped_gemm_wmma_tile_loop_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_tile_loop_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -89,12 +118,22 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
add_device_grouped_gemm_wmma_tile_loop_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#ifdef CK_USE_XDL
|
||||
add_device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
add_device_grouped_gemm_wmma_tile_loop_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <cstdlib>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/utility/loop_scheduler.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I8 = int8_t;
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
|
||||
using CShuffleDataType = F32;
|
||||
using AccDataType = F32;
|
||||
using ELayout = Row;
|
||||
|
||||
static constexpr auto PipelineV1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto PipelineV3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto IntrawaveScheduler = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto InterwaveScheduler = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto GemmKPadding = device::GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = device::GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = device::GemmSpecialization::MNKPadding;
|
||||
static constexpr auto GemmDefault = device::GemmSpecialization::Default;
|
||||
|
||||
// Instances for 2 byte * 1 byte datatypes in RRR layout, with EDataType = ADataType
|
||||
// HACK: CBlockTransfer_ScalarPerVector_NRepeat elements should depend on the amount and data types
|
||||
// in the D tensors. In practice, D tensors are 2 bytes and there's never more than two. So this
|
||||
// works, but isn't very robust.
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename DsLayout,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp,
|
||||
enable_if_t<sizeof(ADataType) == 2, bool> = false,
|
||||
enable_if_t<sizeof(BDataType) == 1, bool> = false>
|
||||
using device_grouped_gemm_tile_loop_multiply_wmma_mk_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#################################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//#################################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, ADataType, DsDataType, ADataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
// DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, ADataType, DsDataType, ADataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, ADataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
static constexpr device::GemmSpecialization GemmSpecVariants[] = {GemmDefault, GemmMNKPadding};
|
||||
|
||||
// Helper function to add a list of layout instances for instances with matching A/B/E data types
|
||||
// for all supported padding/scheduler/pipeline version combinations
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
template <typename ADataType_inner,
|
||||
typename BDataType_inner,
|
||||
typename DsDataTyper_inner,
|
||||
typename DsLayout_inner,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
typename LayoutInstances,
|
||||
typename AElementOp, // NOTE: element-wise op parameters as last so that they can be
|
||||
typename BElementOp, // inferred from the vector argument
|
||||
typename CDEElementOp>
|
||||
void add_device_grouped_gemm_tile_loop_multiply_wmma_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
ADataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
static_for<0, std::size(GemmSpecVariants), 1>{}([&](auto i) {
|
||||
constexpr auto GemmSpec = GemmSpecVariants[i];
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
DsLayout,
|
||||
GemmSpec,
|
||||
IntrawaveScheduler,
|
||||
PipelineV1,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>{});
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
DsLayout,
|
||||
GemmSpec,
|
||||
InterwaveScheduler,
|
||||
PipelineV1,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>{});
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
DsLayout,
|
||||
GemmSpec,
|
||||
IntrawaveScheduler,
|
||||
PipelineV3,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>{});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,215 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/utility/loop_scheduler.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
|
||||
using AccDataType = F32;
|
||||
using DsDataType = Empty_Tuple;
|
||||
|
||||
using DsLayout = Empty_Tuple;
|
||||
using ELayout = Row;
|
||||
|
||||
static constexpr auto PipelineV1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto PipelineV3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto IntrawaveScheduler = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto InterwaveScheduler = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto GemmMNKPadding = device::GemmSpecialization::MNKPadding;
|
||||
static constexpr auto GemmDefault = device::GemmSpecialization::Default;
|
||||
|
||||
// Instances for 2 byte datatypes in CRR layout with ADataType = BDataType = EDataType
|
||||
template <typename T,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_tile_loop_wmma_km_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#################################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//#################################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8>, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, S<8>, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8>, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang`-format on
|
||||
>;
|
||||
|
||||
// Instances for 2 byte datatypes in CCR layout with ADataType = BDataType = EDataType
|
||||
template <typename T,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_tile_loop_wmma_km_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#################################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//#################################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8>, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, S<8>, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8>, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Instances for 2 byte datatypes in RRR layout with ADataType = BDataType = EDataType
|
||||
template <typename T,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_tile_loop_wmma_mk_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#################################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//#################################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8>, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, S<8>, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8>, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Instances for 2 byte datatypes in RCR layout with ADataType = BDataType = EDataType
|
||||
template <typename T,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_tile_loop_wmma_mk_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#################################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//#################################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8>, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, S<8>, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8>, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Helper function to add a list of layout instances for instances with matching A/B/E data types
|
||||
// for all supported padding/scheduler/pipeline version combinations
|
||||
template <typename T,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
template <typename T2,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
typename LayoutInstances,
|
||||
typename AElementOp, // NOTE: element-wise op parameters as last so that they can be
|
||||
typename BElementOp, // inferred from the vector argument
|
||||
typename CDEElementOp>
|
||||
void add_device_grouped_gemm_tile_loop_wmma_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
T,
|
||||
T,
|
||||
DsDataType,
|
||||
T,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<T,
|
||||
GemmDefault,
|
||||
IntrawaveScheduler,
|
||||
PipelineV1,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>{});
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<T,
|
||||
GemmDefault,
|
||||
InterwaveScheduler,
|
||||
PipelineV1,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>{});
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<T,
|
||||
GemmDefault,
|
||||
IntrawaveScheduler,
|
||||
PipelineV3,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>{});
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<T,
|
||||
GemmMNKPadding,
|
||||
IntrawaveScheduler,
|
||||
PipelineV1,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>{});
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<T,
|
||||
GemmMNKPadding,
|
||||
InterwaveScheduler,
|
||||
PipelineV1,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>{});
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<T,
|
||||
GemmMNKPadding,
|
||||
IntrawaveScheduler,
|
||||
PipelineV3,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -17,6 +17,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
@@ -172,6 +173,21 @@ void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances);
|
||||
#endif // CK_USE_XDL
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_grouped_gemm_wmma_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Multiply>>>& instances);
|
||||
#endif
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -216,6 +232,7 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
@@ -240,12 +257,18 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
#endif // CK_USE_XDL
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_grouped_gemm_wmma_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif // CK_USE_WMMA
|
||||
}
|
||||
}
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
@@ -258,7 +281,21 @@ void add_device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyFastGelu>>>& instances);
|
||||
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_grouped_gemm_wmma_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyFastGelu>>>& instances);
|
||||
#endif
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
@@ -302,14 +339,21 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_grouped_gemm_wmma_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
@@ -322,6 +366,21 @@ void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_i
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>& instances);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_grouped_gemm_wmma_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>& instances);
|
||||
#endif
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -368,14 +427,20 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_grouped_gemm_wmma_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
@@ -388,6 +453,21 @@ void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_m
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAddFastGelu>>>& instances);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_grouped_gemm_wmma_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
BF16_BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAddFastGelu>>>& instances);
|
||||
#endif
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -434,8 +514,14 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
#if defined(CK_USE_XDL)
|
||||
add_device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
add_device_grouped_gemm_wmma_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return op_ptrs;
|
||||
|
||||
@@ -36,7 +36,7 @@ add_instance_library(device_grouped_gemm_instance
|
||||
device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_bf16_bf16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_nk_mn_instance.cpp
|
||||
|
||||
|
||||
device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp
|
||||
|
||||
|
||||
@@ -21,18 +21,17 @@ template <device::GemmSpecialization GemmSpec,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
using device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
using device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
|
||||
@@ -21,18 +21,17 @@ template <device::GemmSpecialization GemmSpec,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
using device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
using device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(GROUPED_GEMM_TILE_LOOP_INSTANCES)
|
||||
|
||||
|
||||
list(APPEND GROUPED_GEMM_TILE_LOOP_INSTANCES
|
||||
device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_tile_loop_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_tile_loop_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_default_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instance.cpp
|
||||
@@ -24,6 +27,11 @@ list(APPEND GROUPED_GEMM_TILE_LOOP_INSTANCES
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
|
||||
device_grouped_gemm_wmma_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_gemm_tile_loop_instance ${GROUPED_GEMM_TILE_LOOP_INSTANCES})
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_tile_loop_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
|
||||
void add_device_grouped_gemm_wmma_tile_loop_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_tile_loop_wmma_instances<
|
||||
F16,
|
||||
Row,
|
||||
Row,
|
||||
device_grouped_gemm_tile_loop_wmma_mk_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_tile_loop_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
|
||||
void add_device_grouped_gemm_wmma_tile_loop_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Col,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_tile_loop_wmma_instances<
|
||||
F16,
|
||||
Row,
|
||||
Col,
|
||||
device_grouped_gemm_tile_loop_wmma_mk_nk_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_tile_loop_multiply_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using DsDataType = ck::Tuple<BF16>;
|
||||
using DsLayout = ck::Tuple<Row>;
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = Multiply;
|
||||
|
||||
void add_device_grouped_gemm_wmma_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_tile_loop_multiply_wmma_instances<
|
||||
BF16,
|
||||
I8,
|
||||
DsDataType,
|
||||
Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
device_grouped_gemm_tile_loop_multiply_wmma_mk_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_tile_loop_multiply_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using DsDataType = ck::Tuple<BF16, BF16>;
|
||||
using DsLayout = ck::Tuple<Row, Row>;
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MultiplyAdd;
|
||||
|
||||
void add_device_grouped_gemm_wmma_tile_loop_multiply_bias_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_tile_loop_multiply_wmma_instances<
|
||||
BF16,
|
||||
I8,
|
||||
DsDataType,
|
||||
Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
device_grouped_gemm_tile_loop_multiply_wmma_mk_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_tile_loop_multiply_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using DsDataType = ck::Tuple<BF16, BF16>;
|
||||
using DsLayout = ck::Tuple<Row, Row>;
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MultiplyAddFastGelu;
|
||||
|
||||
void add_device_grouped_gemm_wmma_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_tile_loop_multiply_wmma_instances<
|
||||
BF16,
|
||||
I8,
|
||||
DsDataType,
|
||||
Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
device_grouped_gemm_tile_loop_multiply_wmma_mk_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_tile_loop_multiply_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using DsDataType = ck::Tuple<BF16>;
|
||||
using DsLayout = ck::Tuple<Row>;
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MultiplyFastGelu;
|
||||
|
||||
void add_device_grouped_gemm_wmma_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
BF16,
|
||||
I8,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_tile_loop_multiply_wmma_instances<
|
||||
BF16,
|
||||
I8,
|
||||
DsDataType,
|
||||
Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
device_grouped_gemm_tile_loop_multiply_wmma_mk_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -6,20 +6,9 @@
|
||||
#include <iomanip>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "profile_grouped_gemm_tile_loop_generic_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -47,300 +36,36 @@ bool profile_grouped_gemm_multiply_tile_loop_impl(int do_verification,
|
||||
int n_warmup = 10,
|
||||
int n_iter = 50)
|
||||
{
|
||||
using CDataType = EDataType;
|
||||
bool pass = true;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::size_t group_count = Ms.size();
|
||||
|
||||
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
|
||||
group_count == StrideBs.size() && group_count == StrideEs.size()))
|
||||
std::vector<std::array<int, 1>> stride_ds;
|
||||
for(size_t i = 0; i < StrideDs.size(); ++i)
|
||||
{
|
||||
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
|
||||
stride_ds.emplace_back(std::array<int, 1>{StrideDs[i]});
|
||||
}
|
||||
|
||||
std::vector<Tensor<ADataType>> a_m_k;
|
||||
std::vector<Tensor<BDataType>> b_k_n;
|
||||
std::vector<Tensor<DDataType>> d_m_n;
|
||||
std::vector<Tensor<CDataType>> e_m_n_host_results;
|
||||
std::vector<Tensor<CDataType>> e_m_n_device_results;
|
||||
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
{
|
||||
a_m_k.push_back(
|
||||
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
|
||||
b_k_n.push_back(
|
||||
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
|
||||
d_m_n.push_back(
|
||||
Tensor<DDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideDs[i], DLayout{})));
|
||||
e_m_n_device_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{})));
|
||||
e_m_n_host_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{})));
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
|
||||
<< i << "]:" << b_k_n[i].mDesc << ", e_m_n_device_results[" << i
|
||||
<< "]:" << e_m_n_device_results[i].mDesc << std::endl;
|
||||
}
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k[i]);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n[i]);
|
||||
ck::utils::FillUniformDistributionIntegerValue<DDataType>{-5, 5}(d_m_n[i]);
|
||||
break;
|
||||
case 2:
|
||||
ck::utils::FillUniformDistribution<ADataType>{.0, 1.}(a_m_k[i]);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-0.5, 0.5}(b_k_n[i]);
|
||||
ck::utils::FillUniformDistribution<DDataType>{-0.5, 0.5}(d_m_n[i]);
|
||||
break;
|
||||
default:
|
||||
ck::utils::FillConstant<ADataType>{1}(a_m_k[i]);
|
||||
ck::utils::FillConstant<BDataType>{1}(b_k_n[i]);
|
||||
ck::utils::FillConstant<DDataType>{1}(d_m_n[i]);
|
||||
}
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementOp = ck::tensor_operation::element_wise::Multiply;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto c_element_op = CElementOp{};
|
||||
const auto cde_element_op = CDEElementOp{};
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, d_device_buf, e_device_buf;
|
||||
|
||||
a_device_buf.reserve(group_count);
|
||||
b_device_buf.reserve(group_count);
|
||||
d_device_buf.reserve(group_count);
|
||||
e_device_buf.reserve(group_count);
|
||||
|
||||
std::vector<const void*> p_a, p_b, p_d;
|
||||
constexpr ck::index_t NumDTensor = 1;
|
||||
auto p_ds = std::vector<std::array<const void*, NumDTensor>>{};
|
||||
std::vector<void*> p_e;
|
||||
|
||||
p_a.reserve(group_count);
|
||||
p_b.reserve(group_count);
|
||||
p_ds.reserve(group_count);
|
||||
p_e.reserve(group_count);
|
||||
|
||||
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDTensor>;
|
||||
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
std::vector<KernelArguments> gemm_kargs;
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
gemm_kargs.reserve(group_count);
|
||||
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
{
|
||||
a_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
|
||||
b_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
|
||||
d_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(DDataType) * d_m_n[i].mDesc.GetElementSpaceSize()));
|
||||
e_device_buf.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(CDataType) * e_m_n_device_results[i].mDesc.GetElementSpaceSize()));
|
||||
|
||||
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
|
||||
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
|
||||
d_device_buf[i]->ToDevice(d_m_n[i].mData.data());
|
||||
e_device_buf[i]->SetZero();
|
||||
|
||||
p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
|
||||
p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
|
||||
p_ds.push_back({d_device_buf[i]->GetDeviceBuffer()});
|
||||
p_e.push_back(e_device_buf[i]->GetDeviceBuffer());
|
||||
|
||||
gemm_descs.push_back(
|
||||
{0, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {StrideDs[i]}});
|
||||
gemm_kargs.push_back({a_device_buf[i]->GetDeviceBuffer(),
|
||||
b_device_buf[i]->GetDeviceBuffer(),
|
||||
{d_device_buf[i]->GetDeviceBuffer()},
|
||||
e_device_buf[i]->GetDeviceBuffer(),
|
||||
Ms[i],
|
||||
Ns[i],
|
||||
Ks[i],
|
||||
StrideAs[i],
|
||||
StrideBs[i],
|
||||
{StrideDs[i]},
|
||||
StrideEs[i]});
|
||||
}
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmTileLoop<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
if(op_ptrs.size() <= 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! no device GEMM instance found");
|
||||
}
|
||||
|
||||
std::string best_gemm_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
Tensor<CDataType> c_m_n({Ms[i], Ns[i]});
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k[i], b_k_n[i], c_m_n, a_element_op, b_element_op, c_element_op);
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < Ms[i]; ++m)
|
||||
{
|
||||
for(int n = 0; n < Ns[i]; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_results[i](m, n), c_m_n(m, n), d_m_n[i](m, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// profile device GEMM instances
|
||||
for(auto& gemm_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
gemm_ptr->MakeArgumentPointer(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
gemm_descs,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
cde_element_op);
|
||||
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
|
||||
std::string gemm_name = gemm_ptr->GetTypeString();
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
|
||||
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
|
||||
gemm_kargs.data(),
|
||||
gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
|
||||
hipMemcpyHostToDevice));
|
||||
gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
|
||||
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter});
|
||||
if(do_verification)
|
||||
{
|
||||
bool instance_pass = true;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
e_device_buf[i]->FromDevice(e_m_n_device_results[i].mData.data());
|
||||
instance_pass = instance_pass && ck::utils::check_err(e_m_n_device_results[i],
|
||||
e_m_n_host_results[i]);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "e_device: ", e_m_n_device_results[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "e_host : ", e_m_n_host_results[i].mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Instance: " << gemm_name << " verification "
|
||||
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
|
||||
|
||||
pass = pass && instance_pass;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
float ave_time = invoker_ptr->Run(
|
||||
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
|
||||
|
||||
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] +
|
||||
sizeof(BDataType) * Ks[i] * Ns[i] +
|
||||
sizeof(EDataType) * Ms[i] * Ns[i] + // D matrix
|
||||
sizeof(EDataType) * Ms[i] * Ns[i];
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_gemm_name = gemm_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
return profile_grouped_gemm_tile_loop_generic_impl<
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<DDataType>,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<DLayout>,
|
||||
ELayout,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ck::tensor_operation::element_wise::Multiply>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
stride_ds,
|
||||
StrideEs,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
|
||||
@@ -0,0 +1,436 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iomanip>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop_multiply.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp"
|
||||
#include "ck/utility/integral_constant.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/tuple_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
template <class F, std::size_t... I>
|
||||
constexpr auto make_array_from_fn_impl(F&& f, std::index_sequence<I...>)
|
||||
{
|
||||
using T = std::decay_t<decltype(f(std::integral_constant<std::size_t, 0>{}))>;
|
||||
return std::array<T, sizeof...(I)>{f(std::integral_constant<std::size_t, I>{})...};
|
||||
}
|
||||
|
||||
template <std::size_t N, class F>
|
||||
constexpr auto make_array_from_fn(F&& f)
|
||||
{
|
||||
return make_array_from_fn_impl(std::forward<F>(f), std::make_index_sequence<N>{});
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AElementOp = PassThrough,
|
||||
typename BElementOp = PassThrough,
|
||||
typename CDEElementOp = PassThrough>
|
||||
bool profile_grouped_gemm_tile_loop_generic_impl(
|
||||
int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<std::array<int, DsDataType::Size()>>& StrideDs,
|
||||
const std::vector<int>& StrideEs,
|
||||
int n_warmup = 10,
|
||||
int n_iter = 50)
|
||||
{
|
||||
using AccDataType = float;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
|
||||
|
||||
bool pass = true;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::size_t group_count = Ms.size();
|
||||
|
||||
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
|
||||
group_count == StrideBs.size() &&
|
||||
((StrideDs.size() == 0 && NumDTensor == 0) || group_count == StrideDs.size()) &&
|
||||
group_count == StrideEs.size()))
|
||||
{
|
||||
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/D/Es size\n");
|
||||
}
|
||||
|
||||
std::vector<Tensor<ADataType>> a_m_k;
|
||||
std::vector<Tensor<BDataType>> b_k_n;
|
||||
std::vector<tuple_map_t<Tensor, DsDataType>> d_m_n;
|
||||
std::vector<Tensor<EDataType>> e_m_n_host_results;
|
||||
std::vector<Tensor<EDataType>> e_m_n_device_results;
|
||||
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
{
|
||||
a_m_k.push_back(
|
||||
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
|
||||
b_k_n.push_back(
|
||||
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
|
||||
|
||||
auto d_tensors = ck::generate_tuple(
|
||||
[&](auto j) {
|
||||
using DDataType = tuple_element_t<j, DsDataType>;
|
||||
|
||||
return Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
Ms[i], Ns[i], StrideDs[i][j], tuple_element_t<j, DsLayout>{}));
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
d_m_n.emplace_back(d_tensors);
|
||||
|
||||
e_m_n_device_results.push_back(
|
||||
Tensor<EDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{})));
|
||||
e_m_n_host_results.push_back(
|
||||
Tensor<EDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideEs[i], ELayout{})));
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
|
||||
<< i << "]:" << b_k_n[i].mDesc << ", e_m_n_device_results[" << i
|
||||
<< "]:" << e_m_n_device_results[i].mDesc << std::endl;
|
||||
}
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) -> void {
|
||||
d_m_n[i](j).GenerateTensorValue(
|
||||
GeneratorTensor_2<tuple_element_t<j, DsDataType>>{-5, 5});
|
||||
});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) -> void {
|
||||
d_m_n[i](j).GenerateTensorValue(
|
||||
GeneratorTensor_3<tuple_element_t<j, DsDataType>>{-0.5, 0.5});
|
||||
});
|
||||
break;
|
||||
default:
|
||||
ck::utils::FillConstant<ADataType>{1}(a_m_k[i]);
|
||||
ck::utils::FillConstant<BDataType>{1}(b_k_n[i]);
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) -> void {
|
||||
ck::utils::FillConstant<tuple_element_t<j, DsDataType>>{1}(d_m_n[i](j));
|
||||
});
|
||||
}
|
||||
}
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto cde_element_op = CDEElementOp{};
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, e_device_buf;
|
||||
std::vector<std::array<DeviceMemPtr, NumDTensor>> d_device_bufs;
|
||||
|
||||
a_device_buf.reserve(group_count);
|
||||
b_device_buf.reserve(group_count);
|
||||
d_device_bufs.reserve(group_count);
|
||||
e_device_buf.reserve(group_count);
|
||||
|
||||
std::vector<const void*> p_a, p_b;
|
||||
std::vector<std::array<const void*, NumDTensor>> p_ds;
|
||||
std::vector<void*> p_e;
|
||||
|
||||
p_a.reserve(group_count);
|
||||
p_b.reserve(group_count);
|
||||
p_ds.reserve(group_count);
|
||||
p_e.reserve(group_count);
|
||||
|
||||
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDTensor>;
|
||||
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
std::vector<KernelArguments> gemm_kargs;
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
gemm_kargs.reserve(group_count);
|
||||
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
{
|
||||
a_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
|
||||
b_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
|
||||
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
d_device_bufs.emplace_back(make_array_from_fn<NumDTensor>([&](auto j) {
|
||||
return std::make_unique<DeviceMem>(
|
||||
sizeof(tuple_element_t<j, DsDataType>) *
|
||||
d_m_n[i][ck::integral_constant<index_t, j>{}].mDesc.GetElementSpaceSize());
|
||||
}));
|
||||
}
|
||||
|
||||
e_device_buf.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(EDataType) * e_m_n_device_results[i].mDesc.GetElementSpaceSize()));
|
||||
|
||||
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
|
||||
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
|
||||
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto j) -> void { d_device_bufs[i][j]->ToDevice(d_m_n[i][j].mData.data()); });
|
||||
|
||||
e_device_buf[i]->SetZero();
|
||||
|
||||
p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
|
||||
p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
|
||||
|
||||
std::array<const void*, NumDTensor> p_d;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto j) -> void { p_d[j] = d_device_bufs[i][j]->GetDeviceBuffer(); });
|
||||
|
||||
p_ds.push_back(p_d);
|
||||
|
||||
p_e.push_back(e_device_buf[i]->GetDeviceBuffer());
|
||||
|
||||
gemm_descs.push_back({Ms[i],
|
||||
Ns[i],
|
||||
Ks[i],
|
||||
StrideAs[i],
|
||||
StrideBs[i],
|
||||
StrideEs[i],
|
||||
std::vector<int>(StrideDs[i].begin(), StrideDs[i].end())});
|
||||
gemm_kargs.push_back({a_device_buf[i]->GetDeviceBuffer(),
|
||||
b_device_buf[i]->GetDeviceBuffer(),
|
||||
p_d,
|
||||
e_device_buf[i]->GetDeviceBuffer(),
|
||||
Ms[i],
|
||||
Ns[i],
|
||||
Ks[i],
|
||||
StrideAs[i],
|
||||
StrideBs[i],
|
||||
StrideDs[i],
|
||||
StrideEs[i]});
|
||||
}
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmTileLoop<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
if(op_ptrs.size() <= 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! no device GEMM instance found");
|
||||
}
|
||||
|
||||
std::string best_gemm_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
// HACK: reference GEMM expects D tensors as std::array
|
||||
// This limits D tensors to all have the same data type
|
||||
using DDataType = tuple_element_t<0, DsDataType>;
|
||||
std::array<Tensor<DDataType>, NumDTensor> d_tensors =
|
||||
make_array_from_fn<NumDTensor>(
|
||||
[&](auto j) { return d_m_n[i][ck::integral_constant<index_t, j>{}]; });
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
|
||||
b_k_n[i],
|
||||
d_tensors,
|
||||
e_m_n_host_results[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
|
||||
b_k_n[i],
|
||||
e_m_n_host_results[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// profile device GEMM instances
|
||||
for(auto& gemm_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr = gemm_ptr->MakeArgumentPointer(
|
||||
p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
|
||||
std::string gemm_name = gemm_ptr->GetTypeString();
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
|
||||
ck::hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
|
||||
gemm_kargs.data(),
|
||||
gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
|
||||
hipMemcpyHostToDevice));
|
||||
gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
|
||||
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter});
|
||||
if(do_verification)
|
||||
{
|
||||
bool instance_pass = true;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
e_device_buf[i]->FromDevice(e_m_n_device_results[i].mData.data());
|
||||
instance_pass = instance_pass && ck::utils::check_err(e_m_n_device_results[i],
|
||||
e_m_n_host_results[i]);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "e_device: ", e_m_n_device_results[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "e_host : ", e_m_n_host_results[i].mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Instance: " << gemm_name << " verification "
|
||||
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
|
||||
|
||||
pass = pass && instance_pass;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
float ave_time = invoker_ptr->Run(
|
||||
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
|
||||
|
||||
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] +
|
||||
sizeof(BDataType) * Ks[i] * Ns[i] +
|
||||
sizeof(EDataType) * Ms[i] * Ns[i];
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) -> void {
|
||||
num_btype +=
|
||||
sizeof(tuple_element_t<j, DsDataType>) * Ms[i] * Ns[i]; // D matrix
|
||||
});
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_gemm_name = gemm_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
@@ -6,20 +6,9 @@
|
||||
#include <iomanip>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_tile_loop.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "profile_grouped_gemm_tile_loop_generic_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -44,277 +33,30 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification,
|
||||
int n_warmup = 10,
|
||||
int n_iter = 50)
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::size_t group_count = Ms.size();
|
||||
|
||||
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
|
||||
group_count == StrideBs.size() && group_count == StrideCs.size()))
|
||||
{
|
||||
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
|
||||
}
|
||||
|
||||
std::vector<Tensor<ADataType>> a_m_k;
|
||||
std::vector<Tensor<BDataType>> b_k_n;
|
||||
std::vector<Tensor<CDataType>> c_m_n_host_results;
|
||||
std::vector<Tensor<CDataType>> c_m_n_device_results;
|
||||
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
{
|
||||
a_m_k.push_back(
|
||||
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
|
||||
b_k_n.push_back(
|
||||
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
|
||||
c_m_n_device_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
c_m_n_host_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
|
||||
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
|
||||
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
|
||||
}
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k[i]);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n[i]);
|
||||
break;
|
||||
case 2:
|
||||
ck::utils::FillUniformDistribution<ADataType>{.0, 1.}(a_m_k[i]);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-0.5, 0.5}(b_k_n[i]);
|
||||
break;
|
||||
default:
|
||||
ck::utils::FillConstant<ADataType>{1}(a_m_k[i]);
|
||||
ck::utils::FillConstant<BDataType>{1}(b_k_n[i]);
|
||||
}
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto c_element_op = CElementOp{};
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
|
||||
|
||||
a_device_buf.reserve(group_count);
|
||||
b_device_buf.reserve(group_count);
|
||||
c_device_buf.reserve(group_count);
|
||||
|
||||
std::vector<const void*> p_a, p_b;
|
||||
std::vector<void*> p_c;
|
||||
|
||||
p_a.reserve(group_count);
|
||||
p_b.reserve(group_count);
|
||||
p_c.reserve(group_count);
|
||||
|
||||
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<>;
|
||||
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
std::vector<KernelArguments> gemm_kargs;
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
gemm_kargs.reserve(group_count);
|
||||
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
{
|
||||
a_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
|
||||
b_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
|
||||
c_device_buf.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
|
||||
|
||||
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
|
||||
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
|
||||
c_device_buf[i]->SetZero();
|
||||
|
||||
p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
|
||||
p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
|
||||
p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
|
||||
|
||||
gemm_descs.push_back({0, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
|
||||
gemm_kargs.push_back({a_device_buf[i]->GetDeviceBuffer(),
|
||||
b_device_buf[i]->GetDeviceBuffer(),
|
||||
{},
|
||||
c_device_buf[i]->GetDeviceBuffer(),
|
||||
Ms[i],
|
||||
Ns[i],
|
||||
Ks[i],
|
||||
StrideAs[i],
|
||||
StrideBs[i],
|
||||
{},
|
||||
StrideCs[i]});
|
||||
}
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmTileLoop<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<>,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
if(op_ptrs.size() <= 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! no device GEMM instance found");
|
||||
}
|
||||
|
||||
std::string best_gemm_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
auto p_ds = std::vector<std::array<const void*, 0>>{};
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
|
||||
b_k_n[i],
|
||||
c_m_n_host_results[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
}
|
||||
|
||||
// profile device GEMM instances
|
||||
for(auto& gemm_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
gemm_ptr->MakeArgumentPointer(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_c,
|
||||
gemm_descs,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
ck::tensor_operation::element_wise::PassThrough{});
|
||||
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
|
||||
std::string gemm_name = gemm_ptr->GetTypeString();
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
|
||||
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
|
||||
gemm_kargs.data(),
|
||||
gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
|
||||
hipMemcpyHostToDevice));
|
||||
gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
|
||||
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter});
|
||||
if(do_verification)
|
||||
{
|
||||
bool instance_pass = true;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
|
||||
instance_pass = instance_pass && ck::utils::check_err(c_m_n_device_results[i],
|
||||
c_m_n_host_results[i]);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_host : ", c_m_n_host_results[i].mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Instance: " << gemm_name << " verification "
|
||||
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
|
||||
|
||||
pass = pass && instance_pass;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
float ave_time = invoker_ptr->Run(
|
||||
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
|
||||
|
||||
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] +
|
||||
sizeof(BDataType) * Ks[i] * Ns[i] +
|
||||
sizeof(CDataType) * Ms[i] * Ns[i];
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_gemm_name = gemm_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
return profile_grouped_gemm_tile_loop_generic_impl<ADataType,
|
||||
BDataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>,
|
||||
CLayout,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
std::vector<std::array<int, 0>>{},
|
||||
StrideCs,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
|
||||
@@ -278,6 +278,7 @@ add_subdirectory(batched_gemm_softmax_gemm)
|
||||
add_subdirectory(batched_gemm_softmax_gemm_permute)
|
||||
add_subdirectory(batched_gemm_b_scale)
|
||||
add_subdirectory(grouped_gemm)
|
||||
add_subdirectory(grouped_gemm_tile_loop)
|
||||
add_subdirectory(reduce)
|
||||
add_subdirectory(convnd_fwd)
|
||||
add_subdirectory(convnd_bwd_data)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
@@ -34,10 +34,10 @@ class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple, true>
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>
|
||||
ck::Tuple< Row, Row, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
|
||||
ck::Tuple< Row, Col, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
|
||||
ck::Tuple< Col, Row, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
|
||||
ck::Tuple< Col, Col, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
@@ -31,7 +32,7 @@ class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple>
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
// The old XDL tests didn't fail if instances were not supported, so we want to keep that
|
||||
// behaviour When compiling WMMA instances and WMMA is supported, then we'll fail if a
|
||||
// behaviour. When compiling WMMA instances and WMMA is supported, then we'll fail if a
|
||||
// specific case is not supported
|
||||
this->fail_if_no_supported_instances_ =
|
||||
ck::is_gfx11_supported() || ck::is_gfx12_supported();
|
||||
@@ -44,28 +45,31 @@ using KernelTypes = ::testing::Types<
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
// WWMA only. No reason to not have it for XDL, but the instance was not defined and it was not in the original test.
|
||||
std::tuple< Col, Col, Row, BF16, BF16, BF16>,
|
||||
ck::Tuple< Col, Col, Row, BF16, BF16, BF16>,
|
||||
#endif
|
||||
|
||||
#if defined(CK_USE_XDL) && defined(__gfx9__)
|
||||
#if defined(CK_USE_XDL) && !defined(CK_USE_WMMA)
|
||||
// XDL only at the moment, instances for WMMA not defined
|
||||
std::tuple< Row, Row, Row, BF16, I8, BF16>,
|
||||
std::tuple< Row, Col, Row, BF16, I8, BF16>,
|
||||
// (And XDL instances don't run on gfx11/12, so we conditionally keep them out)
|
||||
ck::Tuple< Row, Row, Row, BF16, I8, BF16>,
|
||||
ck::Tuple< Row, Col, Row, BF16, I8, BF16>,
|
||||
#endif
|
||||
|
||||
#if (defined(CK_USE_XDL) && (defined(__gfx9__) || defined(__gfx12__))) || (defined(CK_USE_WMMA) && defined(__gfx12__))
|
||||
std::tuple< Row, Row, Row, F8, F16, F16>,
|
||||
std::tuple< Row, Row, Row, F16, F8, F16>,
|
||||
#if CK_USE_OCP_FP8 || CK_USE_FNUZ_FP8 || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_WMMA_FP8)
|
||||
// FP8 instances. Unfortunately CK_ENABLE_FP8 is always defined when not explicitly disabled, even if FP8 is
|
||||
// not supported for any included architecture.
|
||||
ck::Tuple< Row, Row, Row, F8, F16, F16>,
|
||||
ck::Tuple< Row, Row, Row, F16, F8, F16>,
|
||||
#endif
|
||||
|
||||
std::tuple< Row, Row, Row, F16, F16, F16>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F16>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F16>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F16>,
|
||||
ck::Tuple< Row, Row, Row, F16, F16, F16>,
|
||||
ck::Tuple< Row, Col, Row, F16, F16, F16>,
|
||||
ck::Tuple< Col, Row, Row, F16, F16, F16>,
|
||||
ck::Tuple< Col, Col, Row, F16, F16, F16>,
|
||||
|
||||
std::tuple< Row, Row, Row, BF16, BF16, BF16>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, BF16>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, BF16>
|
||||
ck::Tuple< Row, Row, Row, BF16, BF16, BF16>,
|
||||
ck::Tuple< Row, Col, Row, BF16, BF16, BF16>,
|
||||
ck::Tuple< Col, Row, Row, BF16, BF16, BF16>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -23,55 +23,18 @@ extern ck::index_t instance_index;
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
template <typename Range>
|
||||
std::string serialize_range(const Range& range)
|
||||
{
|
||||
std::stringstream ss;
|
||||
for(auto& r : range)
|
||||
{
|
||||
ss << r << ", ";
|
||||
}
|
||||
std::string str = ss.str();
|
||||
return std::string(str.begin(), str.end() - 2);
|
||||
}
|
||||
|
||||
// Helper primary template (will be specialized on the boolean)
|
||||
template <std::size_t N,
|
||||
typename Tuple,
|
||||
typename Default,
|
||||
bool InRange = (N < std::tuple_size_v<std::remove_reference_t<Tuple>>)>
|
||||
struct tuple_element_or_impl;
|
||||
|
||||
// Specialization for the in-range case: use std::tuple_element_t
|
||||
template <std::size_t N, typename Tuple, typename Default>
|
||||
struct tuple_element_or_impl<N, Tuple, Default, true>
|
||||
{
|
||||
using type = std::tuple_element_t<N, std::remove_reference_t<Tuple>>;
|
||||
};
|
||||
|
||||
// Specialization for the out-of-range case: use Default
|
||||
template <std::size_t N, typename Tuple, typename Default>
|
||||
struct tuple_element_or_impl<N, Tuple, Default, false>
|
||||
{
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
// User-facing alias
|
||||
template <std::size_t N, typename Tuple, typename Default>
|
||||
using tuple_element_or_t = typename tuple_element_or_impl<N, Tuple, Default>::type;
|
||||
|
||||
template <typename Tuple, bool FailIfNoSupportedInstances = false>
|
||||
class TestGroupedGemm : public testing::Test
|
||||
{
|
||||
protected:
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using ELayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using EDataType = std::tuple_element_t<5, Tuple>;
|
||||
using ALayout = tuple_element_t<0, Tuple>;
|
||||
using BLayout = tuple_element_t<1, Tuple>;
|
||||
using ELayout = tuple_element_t<2, Tuple>;
|
||||
using ADataType = tuple_element_t<3, Tuple>;
|
||||
using BDataType = tuple_element_t<4, Tuple>;
|
||||
using EDataType = tuple_element_t<5, Tuple>;
|
||||
using AElementOp = tuple_element_or_t<6, Tuple, PassThrough>;
|
||||
using BElementOp = tuple_element_or_t<7, Tuple, PassThrough>;
|
||||
using CDEElementOp = tuple_element_or_t<8, Tuple, PassThrough>;
|
||||
|
||||
18
test/grouped_gemm_tile_loop/CMakeLists.txt
Normal file
18
test/grouped_gemm_tile_loop/CMakeLists.txt
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_custom_target(test_grouped_gemm_tile_loop)
|
||||
|
||||
if (CK_USE_XDL OR CK_USE_WMMA)
|
||||
add_gtest_executable(test_grouped_gemm_tile_loop_vanilla test_grouped_gemm_tile_loop.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_gemm_tile_loop_vanilla PRIVATE utility device_grouped_gemm_tile_loop_instance)
|
||||
add_dependencies(test_grouped_gemm_tile_loop test_grouped_gemm_tile_loop_vanilla)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_grouped_gemm_tile_loop_multiply test_grouped_gemm_tile_loop_multiply.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_gemm_tile_loop_multiply PRIVATE utility device_grouped_gemm_tile_loop_instance)
|
||||
add_dependencies(test_grouped_gemm_tile_loop test_grouped_gemm_tile_loop_multiply)
|
||||
endif()
|
||||
endif()
|
||||
52
test/grouped_gemm_tile_loop/test_grouped_gemm_tile_loop.cpp
Normal file
52
test/grouped_gemm_tile_loop/test_grouped_gemm_tile_loop.cpp
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_grouped_gemm_tile_loop_util.hpp"
|
||||
|
||||
ck::index_t param_mask = 0xffffff;
|
||||
ck::index_t instance_index = -1;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using I8 = int8_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedGemmTileLoop : public ck::test::TestGroupedGemmTileLoop<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
ck::Tuple<Row, Row, ck::Tuple<>, Row, F16, F16, ck::Tuple<>, F16>,
|
||||
ck::Tuple<Row, Col, ck::Tuple<>, Row, F16, F16, ck::Tuple<>, F16>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedGemmTileLoop, KernelTypes);
|
||||
|
||||
#include "test_grouped_gemm_tile_loop_ut_cases.inc"
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if(argc == 1) {}
|
||||
else if(argc == 3)
|
||||
{
|
||||
param_mask = strtol(argv[1], nullptr, 0);
|
||||
instance_index = atoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "example/68_gemm_add/common.hpp"
|
||||
#include "test_grouped_gemm_tile_loop_util.hpp"
|
||||
|
||||
ck::index_t param_mask = 0xffffff;
|
||||
ck::index_t instance_index = -1;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using I8 = int8_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
|
||||
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
||||
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedGemmTileLoop : public ck::test::TestGroupedGemmTileLoop<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
ck::Tuple<Row, Row, ck::Tuple<Row>, Row, BF16, I8, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Multiply>,
|
||||
ck::Tuple<Row, Row, ck::Tuple<Row, Row>, Row, BF16, I8, ck::Tuple<BF16, BF16>, BF16, PassThrough, PassThrough, MultiplyAdd>,
|
||||
ck::Tuple<Row, Row, ck::Tuple<Row, Row>, Row, BF16, I8, ck::Tuple<BF16, BF16>, BF16, PassThrough, PassThrough, MultiplyAddFastGelu>,
|
||||
ck::Tuple<Row, Row, ck::Tuple<Row>, Row, BF16, I8, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, MultiplyFastGelu>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedGemmTileLoop, KernelTypes);
|
||||
|
||||
#include "test_grouped_gemm_tile_loop_ut_cases.inc"
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if(argc == 1) {}
|
||||
else if(argc == 3)
|
||||
{
|
||||
param_mask = strtol(argv[1], nullptr, 0);
|
||||
instance_index = atoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestGroupedGemmTileLoop, TinyCases)
|
||||
{
|
||||
const std::vector<int> Ms{2, 1};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedGemmTileLoop, SmallCases)
|
||||
{
|
||||
const std::vector<int> Ms{2, 1, 3, 4, 5};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedGemmTileLoop, MidCases)
|
||||
{
|
||||
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedGemmTileLoop, Regular)
|
||||
{
|
||||
const std::vector<int> Ms{64, 128, 256};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 320;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedGemmTileLoop, MNKPadded)
|
||||
{
|
||||
const std::vector<int> Ms{127, 150, 188, 210};
|
||||
constexpr int N = 136;
|
||||
constexpr int K = 280;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
173
test/grouped_gemm_tile_loop/test_grouped_gemm_tile_loop_util.hpp
Normal file
173
test/grouped_gemm_tile_loop/test_grouped_gemm_tile_loop_util.hpp
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp"
|
||||
#include "profiler/profile_grouped_gemm_tile_loop_generic_impl.hpp"
|
||||
|
||||
extern ck::index_t param_mask;
|
||||
extern ck::index_t instance_index;
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
template <typename Tuple, bool FailIfNoSupportedInstances = false>
|
||||
class TestGroupedGemmTileLoop : public testing::Test
|
||||
{
|
||||
protected:
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ALayout = tuple_element_t<0, Tuple>;
|
||||
using BLayout = tuple_element_t<1, Tuple>;
|
||||
using DsLayout = tuple_element_t<2, Tuple>;
|
||||
using ELayout = tuple_element_t<3, Tuple>;
|
||||
using ADataType = tuple_element_t<4, Tuple>;
|
||||
using BDataType = tuple_element_t<5, Tuple>;
|
||||
using DsDataType = tuple_element_t<6, Tuple>;
|
||||
using EDataType = tuple_element_t<7, Tuple>;
|
||||
using AElementOp = tuple_element_or_t<8, Tuple, PassThrough>;
|
||||
using BElementOp = tuple_element_or_t<9, Tuple, PassThrough>;
|
||||
using CDEElementOp = tuple_element_or_t<10, Tuple, PassThrough>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
static constexpr auto NumDTensor = DsLayout::Size();
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1; // integer value initialization
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
static constexpr int n_warmup_ = 0;
|
||||
static constexpr int n_iter_ = 1;
|
||||
|
||||
bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances;
|
||||
|
||||
private:
|
||||
template <typename Layout>
|
||||
void SetStrides(std::vector<int>& strides,
|
||||
const std::vector<int>& rows,
|
||||
const std::vector<int>& cols) const
|
||||
{
|
||||
if(std::is_same_v<Layout, Row>)
|
||||
{
|
||||
for(const auto c : cols)
|
||||
{
|
||||
strides.emplace_back(c);
|
||||
}
|
||||
}
|
||||
else if(std::is_same_v<Layout, Col>)
|
||||
{
|
||||
for(const auto r : rows)
|
||||
{
|
||||
strides.emplace_back(r);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs = {},
|
||||
const std::vector<int>& StrideBs = {},
|
||||
const std::vector<std::array<int, NumDTensor>>& StrideDs = {},
|
||||
const std::vector<int>& StrideEs = {})
|
||||
{
|
||||
std::vector<int> stride_as = StrideAs;
|
||||
std::vector<int> stride_bs = StrideBs;
|
||||
std::vector<std::array<int, NumDTensor>> stride_ds = StrideDs;
|
||||
std::vector<int> stride_es = StrideEs;
|
||||
|
||||
if(stride_as.empty())
|
||||
{
|
||||
SetStrides<ALayout>(stride_as, Ms, Ks);
|
||||
}
|
||||
if(stride_bs.empty())
|
||||
{
|
||||
SetStrides<BLayout>(stride_bs, Ks, Ns);
|
||||
}
|
||||
|
||||
if(stride_ds.empty())
|
||||
{
|
||||
for(size_t group = 0; group < Ms.size(); ++group)
|
||||
{
|
||||
std::array<int, NumDTensor> d_strides;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = tuple_element_t<i, DsLayout>;
|
||||
|
||||
if(std::is_same_v<DLayout, Row>)
|
||||
{
|
||||
d_strides[i] = Ns[group];
|
||||
}
|
||||
else if(std::is_same_v<DLayout, Col>)
|
||||
{
|
||||
d_strides[i] = Ms[group];
|
||||
}
|
||||
});
|
||||
|
||||
stride_ds.emplace_back(d_strides);
|
||||
}
|
||||
}
|
||||
|
||||
if(stride_es.empty())
|
||||
{
|
||||
SetStrides<ELayout>(stride_es, Ms, Ns);
|
||||
}
|
||||
|
||||
RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_ds, stride_es);
|
||||
}
|
||||
|
||||
void RunSingle(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<std::array<int, NumDTensor>>& StrideDs,
|
||||
const std::vector<int>& StrideEs)
|
||||
{
|
||||
bool pass =
|
||||
ck::profiler::profile_grouped_gemm_tile_loop_generic_impl<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideDs,
|
||||
StrideEs,
|
||||
n_warmup_,
|
||||
n_iter_);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user