mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
This reverts commit1a8bd3d34b. [ROCm/composable_kernel commit:569640dc70]
This commit is contained in:
@@ -15,8 +15,6 @@
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp"
|
||||
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
using ::ck::hip_check_error;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
|
||||
@@ -8,11 +8,3 @@ add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm
|
||||
|
||||
add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8)
|
||||
|
||||
add_custom_target(example_grouped_gemm_wmma_multi_abd)
|
||||
|
||||
add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16 grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16)
|
||||
|
||||
add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8)
|
||||
@@ -1,400 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::hip_check_error;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Bypass = ck::tensor_layout::BypassLayoutVerification;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using A0DataType = BF16;
|
||||
using AsDataType = ck::Tuple<A0DataType>;
|
||||
using B0DataType = I8;
|
||||
using B1DataType = BF16;
|
||||
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = BF16;
|
||||
using D0DataType = BF16;
|
||||
using DsDataType = ck::Tuple<D0DataType>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using AsLayout = ck::Tuple<A0Layout>;
|
||||
using B0Layout = Col;
|
||||
using B1Layout = B0Layout;
|
||||
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
|
||||
using DsLayout = ck::Tuple<Row>;
|
||||
using ELayout = Row;
|
||||
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = Multiply;
|
||||
using CDEElementOp = AddFastGelu;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK
|
||||
// clang-format off
|
||||
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
int k_batch = 1;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<ck::tensor_operation::device::GemmMultiABDDesc> gemm_descs;
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
int sum_of_m = 0;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Tensor<A0DataType>> a0_tensors;
|
||||
std::vector<Tensor<B1DataType>> b_tensors;
|
||||
std::vector<Tensor<B0DataType>> b0_tensors;
|
||||
std::vector<Tensor<B1DataType>> b1_tensors;
|
||||
std::vector<Tensor<D0DataType>> d0_tensors;
|
||||
std::vector<Tensor<EDataType>> c_host_tensors;
|
||||
std::vector<Tensor<EDataType>> c_device_tensors;
|
||||
|
||||
a0_tensors.reserve(group_count);
|
||||
b_tensors.reserve(group_count);
|
||||
b0_tensors.reserve(group_count);
|
||||
b1_tensors.reserve(group_count);
|
||||
d0_tensors.reserve(group_count);
|
||||
c_host_tensors.reserve(group_count);
|
||||
c_device_tensors.reserve(group_count);
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a0_tensors_device, b0_tensors_device, b1_tensors_device,
|
||||
d0_tensors_device, c_tensors_device;
|
||||
|
||||
a0_tensors_device.reserve(group_count);
|
||||
b0_tensors_device.reserve(group_count);
|
||||
b1_tensors_device.reserve(group_count);
|
||||
d0_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
sum_of_m += problem_size.Ms[i];
|
||||
|
||||
a0_tensors.push_back(Tensor<A0DataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{})));
|
||||
|
||||
b_tensors.push_back(Tensor<B1DataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{})));
|
||||
b0_tensors.push_back(Tensor<B0DataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{})));
|
||||
b1_tensors.push_back(Tensor<B1DataType>(
|
||||
f_host_tensor_descriptor(problem_size.Ks[i], problem_size.Ns[i], 0, B1Layout{})));
|
||||
|
||||
d0_tensors.push_back(Tensor<D0DataType>(
|
||||
f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{})));
|
||||
|
||||
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
|
||||
std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc
|
||||
<< " b_k_n: " << b0_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_device_tensors[i].mDesc << std::endl;
|
||||
|
||||
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
|
||||
num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() +
|
||||
sizeof(B0DataType) * b0_tensors[i].mDesc.GetElementSize() +
|
||||
sizeof(B1DataType) * b1_tensors[i].mDesc.GetElementSize() +
|
||||
sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() +
|
||||
sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_tensors[i].GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
|
||||
b0_tensors[i].GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
|
||||
b1_tensors[i].GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
|
||||
break;
|
||||
case 2:
|
||||
a0_tensors[i].GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_tensors[i].GenerateTensorValue(GeneratorTensor_3<B0DataType>{-5, 5});
|
||||
b1_tensors[i].GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A0DataType, 0>{});
|
||||
b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B1DataType, 1>{});
|
||||
}
|
||||
|
||||
d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
constexpr ck::index_t NumATensor = 1;
|
||||
constexpr ck::index_t NumBTensor = 2;
|
||||
constexpr ck::index_t NumDTensor = 1;
|
||||
|
||||
using GroupedGemmKernelArgument = ck::tensor_operation::device::
|
||||
GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>;
|
||||
|
||||
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
|
||||
grouped_gemm_kernel_args_.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i]));
|
||||
|
||||
b0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
|
||||
|
||||
b1_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(B1DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
|
||||
|
||||
d0_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(D0DataType) * problem_size.Ns[i]));
|
||||
|
||||
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i]));
|
||||
|
||||
a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data());
|
||||
b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data());
|
||||
b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data());
|
||||
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data());
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
gemm_descs.push_back(
|
||||
{sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {0}, 1});
|
||||
|
||||
grouped_gemm_kernel_args_.push_back(
|
||||
{std::array<const void*, NumATensor>{a0_tensors_device[i]->GetDeviceBuffer()},
|
||||
std::array<const void*, NumBTensor>{b0_tensors_device[i]->GetDeviceBuffer(),
|
||||
b1_tensors_device[i]->GetDeviceBuffer()},
|
||||
std::array<const void*, NumDTensor>{d0_tensors_device[i]->GetDeviceBuffer()},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
std::array<ck::index_t, NumATensor>{problem_size.stride_As[i]},
|
||||
std::array<ck::index_t, NumBTensor>{problem_size.stride_Bs[i], 0},
|
||||
std::array<ck::index_t, NumDTensor>{0},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
std::vector<std::array<const void*, NumATensor>> p_As = {};
|
||||
std::vector<std::array<const void*, NumBTensor>> p_Bs = {};
|
||||
std::vector<std::array<const void*, NumDTensor>> p_Ds = {};
|
||||
std::vector<void*> p_Cs = {};
|
||||
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
|
||||
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
|
||||
|
||||
DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument));
|
||||
hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
|
||||
grouped_gemm_kernel_args_.data(),
|
||||
gemm.GetDeviceKernelArgSize(&argument),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer());
|
||||
gemm.SetKBatch(argument, config.k_batch);
|
||||
|
||||
gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op);
|
||||
|
||||
invoker.Run(&argument, StreamConfig{nullptr, false});
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(&argument, StreamConfig{nullptr, config.time_kernel});
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
|
||||
B1DataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
for(int n = 0; n < problem_size.Ns[i]; ++n)
|
||||
{
|
||||
for(int k = 0; k < problem_size.Ks[i]; ++k)
|
||||
{
|
||||
b_element_op(b_tensors[i](k, n), b0_tensors[i](k, n), b1_tensors[i](k, n));
|
||||
}
|
||||
}
|
||||
|
||||
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
|
||||
c_device_tensors[i].mDesc.GetElementSize() *
|
||||
sizeof(EDataType));
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i],
|
||||
b_tensors[i],
|
||||
c_host_tensors[i],
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < problem_size.Ms[i]; ++m)
|
||||
{
|
||||
for(int n = 0; n < problem_size.Ns[i]; ++n)
|
||||
{
|
||||
cde_element_op(
|
||||
c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n));
|
||||
}
|
||||
}
|
||||
|
||||
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
|
||||
}
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(32 + rand() % 32);
|
||||
problem_size.Ns.push_back(1024);
|
||||
problem_size.Ks.push_back(512);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
}
|
||||
|
||||
if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.k_batch = std::stoi(argv[4]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4: k_batch (>0)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
@@ -1,396 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/utility/scheduler_enum.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::hip_check_error;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Bypass = ck::tensor_layout::BypassLayoutVerification;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
using AddScale = ck::tensor_operation::element_wise::BinaryWithUnaryCombinedOp<Add, Scale, Scale>;
|
||||
|
||||
using A0DataType = F16;
|
||||
using A1DataType = F32;
|
||||
using AsDataType = ck::Tuple<A0DataType, A1DataType>;
|
||||
using B0DataType = F16;
|
||||
using BsDataType = ck::Tuple<B0DataType>;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = F16;
|
||||
using DsDataType = ck::Tuple<D0DataType>;
|
||||
using EDataType = F16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using A1Layout = Row;
|
||||
using AsLayout = ck::Tuple<A0Layout, A1Layout>;
|
||||
using B0Layout = Col;
|
||||
using BsLayout = ck::Tuple<B0Layout>;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<D0Layout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = AddScale;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = Add;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK
|
||||
// clang-format off
|
||||
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 128, 32, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>;
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
int k_batch = 1;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<ck::tensor_operation::device::GemmMultiABDDesc> gemm_descs;
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
int sum_of_m = 0;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Tensor<A0DataType>> a0_tensors;
|
||||
std::vector<Tensor<A1DataType>> a1_tensors;
|
||||
std::vector<Tensor<B0DataType>> b_tensors;
|
||||
std::vector<Tensor<D0DataType>> d0_tensors;
|
||||
std::vector<Tensor<EDataType>> e_host_tensors;
|
||||
std::vector<Tensor<EDataType>> e_device_tensors;
|
||||
|
||||
a0_tensors.reserve(group_count);
|
||||
a1_tensors.reserve(group_count);
|
||||
b_tensors.reserve(group_count);
|
||||
d0_tensors.reserve(group_count);
|
||||
e_host_tensors.reserve(group_count);
|
||||
e_device_tensors.reserve(group_count);
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a0_tensors_device, a1_tensors_device, b_tensors_device,
|
||||
d0_tensors_device, c_tensors_device;
|
||||
|
||||
a0_tensors_device.reserve(group_count);
|
||||
a1_tensors_device.reserve(group_count);
|
||||
b_tensors_device.reserve(group_count);
|
||||
d0_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
sum_of_m += problem_size.Ms[i];
|
||||
a0_tensors.push_back(Tensor<A0DataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{})));
|
||||
a1_tensors.push_back(Tensor<A1DataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A1Layout{})));
|
||||
b_tensors.push_back(Tensor<B0DataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{})));
|
||||
d0_tensors.push_back(Tensor<D0DataType>(
|
||||
f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{})));
|
||||
e_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
e_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc
|
||||
<< " c_m_n: " << e_device_tensors[i].mDesc << std::endl;
|
||||
|
||||
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
|
||||
num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() +
|
||||
sizeof(A1DataType) * a1_tensors[i].mDesc.GetElementSize() +
|
||||
sizeof(B0DataType) * b_tensors[i].mDesc.GetElementSize() +
|
||||
sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() +
|
||||
sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize();
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_tensors[i].GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
|
||||
a1_tensors[i].GenerateTensorValue(GeneratorTensor_2<A1DataType>{-5, 5});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
|
||||
break;
|
||||
case 2:
|
||||
a0_tensors[i].GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
a1_tensors[i].GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A0DataType, 0>{});
|
||||
a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<A1DataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
|
||||
}
|
||||
|
||||
d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
constexpr ck::index_t NumATensor = 2;
|
||||
constexpr ck::index_t NumBTensor = 1;
|
||||
constexpr ck::index_t NumDTensor = 1;
|
||||
|
||||
using GroupedGemmKernelArgument = ck::tensor_operation::device::
|
||||
GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>;
|
||||
|
||||
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
|
||||
grouped_gemm_kernel_args_.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i]));
|
||||
|
||||
a1_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(A1DataType) * problem_size.Ms[i] * problem_size.Ks[i]));
|
||||
|
||||
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
|
||||
|
||||
d0_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(D0DataType) * problem_size.Ns[i]));
|
||||
|
||||
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i]));
|
||||
|
||||
a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data());
|
||||
a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data());
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
|
||||
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data());
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
gemm_descs.push_back({sum_of_m,
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
{1, 1},
|
||||
{problem_size.stride_Bs[i]},
|
||||
{0},
|
||||
1});
|
||||
|
||||
grouped_gemm_kernel_args_.push_back(
|
||||
{std::array<const void*, NumATensor>{a0_tensors_device[i]->GetDeviceBuffer(),
|
||||
a1_tensors_device[i]->GetDeviceBuffer()},
|
||||
std::array<const void*, NumBTensor>{b_tensors_device[i]->GetDeviceBuffer()},
|
||||
std::array<const void*, NumDTensor>{d0_tensors_device[i]->GetDeviceBuffer()},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
std::array<ck::index_t, NumATensor>{problem_size.stride_As[i],
|
||||
problem_size.stride_As[i]},
|
||||
std::array<ck::index_t, NumBTensor>{problem_size.stride_Bs[i]},
|
||||
std::array<ck::index_t, NumDTensor>{0},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
|
||||
constexpr float scale = 1.f;
|
||||
auto a_element_op = AElementOp{Add{}, Scale{scale}, Scale{scale}};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
std::vector<std::array<const void*, NumATensor>> p_As = {};
|
||||
std::vector<std::array<const void*, NumBTensor>> p_Bs = {};
|
||||
std::vector<std::array<const void*, NumDTensor>> p_Ds = {};
|
||||
std::vector<void*> p_Cs = {};
|
||||
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
|
||||
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
|
||||
|
||||
DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument));
|
||||
hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
|
||||
grouped_gemm_kernel_args_.data(),
|
||||
gemm.GetDeviceKernelArgSize(&argument),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer());
|
||||
gemm.SetKBatch(argument, config.k_batch);
|
||||
|
||||
gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op);
|
||||
|
||||
invoker.Run(&argument, StreamConfig{nullptr, false});
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(&argument, StreamConfig{nullptr, config.time_kernel});
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
|
||||
B0DataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
for(int m = 0; m < problem_size.Ms[i]; ++m)
|
||||
{
|
||||
for(int k = 0; k < problem_size.Ks[i]; ++k)
|
||||
{
|
||||
a_element_op(a0_tensors[i](m, k), a0_tensors[i](m, k), a1_tensors[i](m, k));
|
||||
}
|
||||
}
|
||||
|
||||
c_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data(),
|
||||
e_device_tensors[i].mDesc.GetElementSize() *
|
||||
sizeof(EDataType));
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i],
|
||||
b_tensors[i],
|
||||
e_host_tensors[i],
|
||||
PassThrough{},
|
||||
b_element_op,
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < problem_size.Ms[i]; ++m)
|
||||
{
|
||||
for(int n = 0; n < problem_size.Ns[i]; ++n)
|
||||
{
|
||||
cde_element_op(
|
||||
e_host_tensors[i](m, n), e_host_tensors[i](m, n), d0_tensors[i](m, n));
|
||||
}
|
||||
}
|
||||
|
||||
pass &= ck::utils::check_err(e_device_tensors[i], e_host_tensors[i]);
|
||||
}
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(32 + rand() % 32);
|
||||
problem_size.Ns.push_back(64);
|
||||
problem_size.Ks.push_back(64);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
}
|
||||
|
||||
if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.k_batch = std::stoi(argv[4]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4: k_batch (>0)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
@@ -20,8 +20,6 @@
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::hip_check_error;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
@@ -222,8 +220,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i]));
|
||||
a0_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
|
||||
|
||||
b0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
|
||||
@@ -234,12 +232,21 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
d0_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(D0DataType) * problem_size.Ns[i]));
|
||||
|
||||
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i]));
|
||||
c_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
|
||||
|
||||
a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(),
|
||||
a0_tensors[i].mDesc.GetElementSpaceSize() *
|
||||
sizeof(A0DataType));
|
||||
|
||||
b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data(),
|
||||
b0_tensors[i].mDesc.GetElementSpaceSize() *
|
||||
sizeof(B0DataType));
|
||||
|
||||
b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data(),
|
||||
b1_tensors[i].mDesc.GetElementSpaceSize() *
|
||||
sizeof(B1DataType));
|
||||
|
||||
a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data());
|
||||
b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data());
|
||||
b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data());
|
||||
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data());
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
@@ -391,7 +398,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: k_batch (>0)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
@@ -20,8 +20,6 @@
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::hip_check_error;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
@@ -49,9 +47,9 @@ using B0DataType = F16;
|
||||
using BsDataType = ck::Tuple<B0DataType>;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = F16;
|
||||
using D0DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType>;
|
||||
using EDataType = F16;
|
||||
using EDataType = F32;
|
||||
|
||||
using A0Layout = Row;
|
||||
using A1Layout = Row;
|
||||
@@ -212,11 +210,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(A0DataType) * problem_size.Ms[i] * problem_size.Ks[i]));
|
||||
a0_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
|
||||
|
||||
a1_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(A1DataType) * problem_size.Ms[i] * problem_size.Ks[i]));
|
||||
a1_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(A1DataType) * sum_of_m * problem_size.Ks[i]));
|
||||
|
||||
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
|
||||
@@ -224,12 +222,19 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
d0_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(D0DataType) * problem_size.Ns[i]));
|
||||
|
||||
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i]));
|
||||
c_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
|
||||
|
||||
a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data());
|
||||
a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data());
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
|
||||
a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(),
|
||||
a0_tensors[i].mDesc.GetElementSpaceSize() *
|
||||
sizeof(A0DataType));
|
||||
|
||||
a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data(),
|
||||
a1_tensors[i].mDesc.GetElementSpaceSize() *
|
||||
sizeof(A1DataType));
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
|
||||
b_tensors[i].mDesc.GetElementSpaceSize() *
|
||||
sizeof(B0DataType));
|
||||
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data());
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
@@ -389,7 +394,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: k_batch (>0)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
@@ -1,899 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename GemmDesc,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename Block2ETileMap,
|
||||
typename GroupedGemmBlock2ETileMap,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_grouped_gemm_wmma_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
const index_t group_count,
|
||||
const index_t grid_size_grp,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
__shared__ char p_shared[GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>()];
|
||||
|
||||
const index_t KBatch = 1;
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
|
||||
const auto gemm_desc_ptr =
|
||||
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
|
||||
const index_t group_id = block_id / grid_size_grp;
|
||||
|
||||
if(group_id >= group_count)
|
||||
return;
|
||||
|
||||
auto karg = gemm_desc_ptr[group_id];
|
||||
|
||||
if(karg.M == 0 || karg.N == 0 || karg.K == 0)
|
||||
return;
|
||||
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<typename GridwiseGemm::EDataType_, ck::half_t> ||
|
||||
std::is_same_v<typename GridwiseGemm::EDataType_, ck::bhalf_t>)))
|
||||
#endif
|
||||
{
|
||||
|
||||
typename GridwiseGemm::Problem problem(karg.M,
|
||||
karg.N,
|
||||
karg.K,
|
||||
karg.StrideAs,
|
||||
karg.StrideBs,
|
||||
karg.StrideDs,
|
||||
karg.StrideE,
|
||||
KBatch);
|
||||
|
||||
const auto e_grid_desc_m_n = GridwiseGemm::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
|
||||
|
||||
const index_t BlockStart = group_id * grid_size_grp;
|
||||
|
||||
const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch};
|
||||
|
||||
const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n);
|
||||
|
||||
constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size();
|
||||
constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size();
|
||||
constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size();
|
||||
|
||||
typename GridwiseGemm::AsGridPointer p_as_grid_;
|
||||
typename GridwiseGemm::BsGridPointer p_bs_grid_;
|
||||
typename GridwiseGemm::DsGridPointer p_ds_grid_;
|
||||
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType = remove_cvref_t<decltype(p_as_grid_(i))>;
|
||||
p_as_grid_(i) = static_cast<ADataType>(karg.p_as_grid[i]);
|
||||
});
|
||||
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType = remove_cvref_t<decltype(p_bs_grid_(i))>;
|
||||
p_bs_grid_(i) = static_cast<BDataType>(karg.p_bs_grid[i]);
|
||||
});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<decltype(p_ds_grid_(i))>;
|
||||
p_ds_grid_(i) = static_cast<DDataType>(karg.p_ds_grid[i]);
|
||||
});
|
||||
|
||||
index_t id_off = 0;
|
||||
index_t id_local = get_block_1d_id() - BlockStart;
|
||||
|
||||
while(id_local < local_grid_size)
|
||||
{
|
||||
const auto block_2_etile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
decltype(block_2_etile_map),
|
||||
decltype(epilogue_args),
|
||||
1,
|
||||
2>(
|
||||
p_as_grid_,
|
||||
p_bs_grid_,
|
||||
p_ds_grid_,
|
||||
static_cast<typename GridwiseGemm::EDataType_*>(karg.p_e_grid),
|
||||
p_shared,
|
||||
problem,
|
||||
block_2_etile_map,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
epilogue_args);
|
||||
|
||||
id_off += grid_size_grp;
|
||||
id_local += grid_size_grp;
|
||||
}
|
||||
}
|
||||
#else
|
||||
ignore = gemm_descs_const;
|
||||
ignore = group_count;
|
||||
ignore = grid_size_grp;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = EDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK
|
||||
: public DeviceGroupedGemmMultiABDFixedNK<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK;
|
||||
|
||||
static constexpr index_t NumATensor = AsDataType::Size();
|
||||
static constexpr index_t NumBTensor = BsDataType::Size();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
// Note: Pass multiple layout but then using only the first one
|
||||
// This is to replicate xdl functionality but it should be extended
|
||||
using ALayout = remove_cvref_t<tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename uniform_sequence_gen<NumDTensor + 1,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock>::type,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
false,
|
||||
false>;
|
||||
|
||||
// TODO: Block to tile mappings could potentially moved out to avoid code duplications between
|
||||
// different device implementations.
|
||||
|
||||
template <typename UnderlyingBlockToCTileMap>
|
||||
struct OffsettedBlockToCTileMapMLoops
|
||||
{
|
||||
using underlying_type = UnderlyingBlockToCTileMap;
|
||||
|
||||
__host__ __device__ OffsettedBlockToCTileMapMLoops(
|
||||
UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
|
||||
{
|
||||
block_to_ctile_map_ = block_to_ctile_map;
|
||||
block_start_ = block_start;
|
||||
id_off_ = id_off;
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
|
||||
make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));
|
||||
|
||||
return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]);
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
|
||||
const CTileDim& c_tile_dim) const
|
||||
{
|
||||
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
UnderlyingBlockToCTileMap block_to_ctile_map_;
|
||||
index_t block_start_;
|
||||
index_t id_off_;
|
||||
};
|
||||
|
||||
template <index_t MPerBlock_, index_t NPerBlock_>
|
||||
struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
|
||||
{
|
||||
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
|
||||
|
||||
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
|
||||
const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
|
||||
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
|
||||
BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
|
||||
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
|
||||
operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
|
||||
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
|
||||
operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
|
||||
|
||||
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M,
|
||||
index_t N,
|
||||
index_t KBatch,
|
||||
index_t M01 = 8)
|
||||
: M_(M), N_(N), KBatch_(KBatch), M01_(M01)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
|
||||
const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
|
||||
: BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
|
||||
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
return M0 * N0 * KBatch_;
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ constexpr index_t
|
||||
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
auto block_1d_id = idx_top[I0];
|
||||
|
||||
const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
|
||||
const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
|
||||
|
||||
block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
|
||||
|
||||
const index_t idx_ksplit = block_1d_id / (M0 * N0);
|
||||
block_1d_id = block_1d_id % (M0 * N0);
|
||||
|
||||
index_t idx_N0 = block_1d_id % N0;
|
||||
index_t idx_M0 = block_1d_id / N0;
|
||||
|
||||
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
|
||||
|
||||
index_t idx_M00 = idx_M0 / M01_;
|
||||
index_t idx_M01 = idx_M0 % M01_;
|
||||
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
|
||||
|
||||
return make_tuple(idx_ksplit,
|
||||
idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
|
||||
idx_N0_M01_local / M01_adapt);
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
|
||||
const CTileDim& /* c_tile_dim */) const
|
||||
{
|
||||
return true; // always valid provided that user gets grid size from CalculateGridSize()
|
||||
}
|
||||
|
||||
private:
|
||||
index_t M_;
|
||||
index_t N_;
|
||||
index_t KBatch_;
|
||||
index_t M01_;
|
||||
};
|
||||
|
||||
using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
|
||||
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
|
||||
|
||||
static constexpr index_t DefaultKBatch = 1; // implementation only supports KBatch == 1
|
||||
using KernelArgument = typename GridwiseGemm::Argument;
|
||||
|
||||
using GemmTransKernelArg =
|
||||
GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>;
|
||||
|
||||
static constexpr bool CalculateHasMainKBlockLoop(const GemmTransKernelArg& karg,
|
||||
index_t k_batch)
|
||||
{
|
||||
index_t k_grain = k_batch * KPerBlock;
|
||||
index_t K_split = (karg.K + k_grain - 1) / k_batch;
|
||||
return GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
}
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
|
||||
Argument(std::vector<std::array<const void*, NumATensor>>& p_As,
|
||||
std::vector<std::array<const void*, NumBTensor>>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmMultiABDDesc>& gemm_descs,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op)
|
||||
: Argument(p_As,
|
||||
p_Bs,
|
||||
p_Ds,
|
||||
p_Es,
|
||||
gemm_descs,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
DefaultKBatch)
|
||||
{
|
||||
// TODO: use occupancy api to calculate appropriate batch size.
|
||||
}
|
||||
|
||||
// Client is expected to manually copy the kernel arguments to the device therefore there is
|
||||
// no point in setting tensor device pointers for the argument structure.
|
||||
Argument(std::vector<std::array<const void*, NumATensor>>&,
|
||||
std::vector<std::array<const void*, NumBTensor>>&,
|
||||
std::vector<std::array<const void*, NumDTensor>>&,
|
||||
std::vector<void*>&,
|
||||
std::vector<GemmMultiABDDesc>& gemm_descs,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op,
|
||||
index_t kbatch)
|
||||
: group_count_{ck::type_convert<ck::index_t>(gemm_descs.size())},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op},
|
||||
grouped_gemm_kernel_args_dev{nullptr},
|
||||
gemm_kernel_host_args_{nullptr},
|
||||
grid_size_{0},
|
||||
k_batch_{kbatch}
|
||||
{
|
||||
gemm_desc_kernel_arg_.reserve(group_count_);
|
||||
|
||||
index_t group_id = 0;
|
||||
|
||||
sum_of_m = gemm_descs[0].M_;
|
||||
const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_);
|
||||
const index_t fixed_N = gemm_descs[0].N_;
|
||||
const index_t fixed_K = gemm_descs[0].K_;
|
||||
|
||||
for(std::size_t g = 0; g < gemm_descs.size(); g++)
|
||||
{
|
||||
const index_t M = gemm_descs[g].M_;
|
||||
const index_t N = gemm_descs[g].N_;
|
||||
const index_t K = gemm_descs[g].K_;
|
||||
|
||||
if(M != sum_of_m || N != fixed_N || K != fixed_K)
|
||||
{
|
||||
throw std::runtime_error("wrong! M/N/K is not identical");
|
||||
}
|
||||
|
||||
a_mtx_mraw_kraw_.emplace_back(sum_of_m, K);
|
||||
b_mtx_nraw_kraw_.emplace_back(N, K);
|
||||
|
||||
// pointer
|
||||
std::array<const void*, NumATensor> p_as_grid;
|
||||
std::array<const void*, NumBTensor> p_bs_grid;
|
||||
std::array<const void*, NumDTensor> p_ds_grid;
|
||||
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) { p_as_grid[i] = nullptr; });
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) { p_bs_grid[i] = nullptr; });
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) { p_ds_grid[i] = nullptr; });
|
||||
|
||||
std::array<index_t, NumATensor> StrideAs;
|
||||
std::array<index_t, NumBTensor> StrideBs;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
|
||||
const index_t StrideE = gemm_descs[g].stride_C_;
|
||||
|
||||
if(gemm_descs[g].stride_As_.size() != NumATensor)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! gemm_descs[i].stride_As_.size() does not match NumATensor");
|
||||
}
|
||||
|
||||
static_for<0, NumATensor, 1>{}(
|
||||
[&](auto j) { StrideAs[j] = gemm_descs[g].stride_As_[j]; });
|
||||
|
||||
if(gemm_descs[g].stride_Bs_.size() != NumBTensor)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor");
|
||||
}
|
||||
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto j) { StrideBs[j] = gemm_descs[g].stride_Bs_[j]; });
|
||||
|
||||
if(gemm_descs[g].stride_Ds_.size() != NumDTensor)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
|
||||
}
|
||||
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto j) { StrideDs[j] = gemm_descs[g].stride_Ds_[j]; });
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
GridwiseGemm::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
AverM, AverM, N, N, StrideE);
|
||||
|
||||
// block-to-e-tile map
|
||||
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};
|
||||
|
||||
grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
|
||||
|
||||
if(group_id * grid_size_grp_ != grid_size_)
|
||||
{
|
||||
throw std::runtime_error("wrong! grid_size_grp_ is not identical!");
|
||||
}
|
||||
|
||||
const index_t block_start = grid_size_;
|
||||
|
||||
grid_size_ += grid_size_grp_;
|
||||
|
||||
if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n))
|
||||
{
|
||||
throw std::runtime_error("wrong! block_2_etile_map validation failed");
|
||||
}
|
||||
|
||||
auto grouped_block_2_ctile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
|
||||
|
||||
auto karg = GemmTransKernelArg({p_as_grid,
|
||||
p_bs_grid,
|
||||
p_ds_grid,
|
||||
nullptr,
|
||||
AverM,
|
||||
N,
|
||||
K,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideDs,
|
||||
StrideE});
|
||||
|
||||
gemm_desc_kernel_arg_.emplace_back(std::move(karg));
|
||||
|
||||
group_id++;
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateKBatch(index_t) {}
|
||||
|
||||
index_t group_count_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation c_element_op_;
|
||||
|
||||
std::vector<GemmTransKernelArg> gemm_desc_kernel_arg_;
|
||||
std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
|
||||
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
|
||||
|
||||
const void* grouped_gemm_kernel_args_dev;
|
||||
void* gemm_kernel_host_args_;
|
||||
index_t grid_size_;
|
||||
index_t grid_size_grp_;
|
||||
index_t sum_of_m;
|
||||
|
||||
index_t k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(arg.grouped_gemm_kernel_args_dev == nullptr)
|
||||
{
|
||||
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr");
|
||||
}
|
||||
|
||||
if(arg.k_batch_ != 1)
|
||||
{
|
||||
throw std::runtime_error("Split K functionality is not supported for wmma multi "
|
||||
"abd fixed nk implementation.");
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
auto launch_kernel = [&](auto e_global_memory_operation_) {
|
||||
const auto kernel = kernel_grouped_gemm_wmma_fixed_nk<GridwiseGemm,
|
||||
GemmTransKernelArg,
|
||||
true, // has_main_k_block_loop
|
||||
e_global_memory_operation_,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
Block2ETileMap,
|
||||
GroupedGemmBlock2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.grid_size_grp_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
};
|
||||
|
||||
constexpr auto Set = InMemoryDataOperationEnum::Set;
|
||||
ave_time = launch_kernel(integral_constant<InMemoryDataOperationEnum, Set>{});
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return RunImp(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
bool supported = true;
|
||||
|
||||
// If we use padding we do not support vector loads for dimensions not divisible by
|
||||
// vector load size.
|
||||
if constexpr(GemmSpec != GemmSpecialization::Default)
|
||||
{
|
||||
// [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
|
||||
// thus we have to adapt it to the {M,K} or {N,K} layout.
|
||||
const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
|
||||
const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
|
||||
|
||||
for(index_t i = 0; i < arg.group_count_; ++i)
|
||||
{
|
||||
const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
|
||||
const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
|
||||
|
||||
supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
|
||||
supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
|
||||
}
|
||||
}
|
||||
|
||||
for(index_t i = 0; i < arg.group_count_; i++)
|
||||
{
|
||||
if(CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i], arg.k_batch_) != true)
|
||||
{
|
||||
supported = false;
|
||||
}
|
||||
}
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(std::vector<std::array<const void*, NumATensor>>& p_As,
|
||||
std::vector<std::array<const void*, NumBTensor>>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmMultiABDDesc> gemm_descs,
|
||||
AElementwiseOperation a_element_op = AElementwiseOperation{},
|
||||
BElementwiseOperation b_element_op = BElementwiseOperation{},
|
||||
CDEElementwiseOperation c_element_op = CDEElementwiseOperation{})
|
||||
{
|
||||
return Argument{
|
||||
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::vector<std::array<const void*, NumATensor>>& p_As,
|
||||
std::vector<std::array<const void*, NumBTensor>>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmMultiABDDesc>& gemm_descs,
|
||||
AElementwiseOperation a_element_op = AElementwiseOperation{},
|
||||
BElementwiseOperation b_element_op = BElementwiseOperation{},
|
||||
CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) override
|
||||
{
|
||||
return std::make_unique<Argument>(
|
||||
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedGemm_Wmma_Fixed_Nk"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< MPerWmma << ", "
|
||||
<< NPerWmma << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< CShuffleMRepeatPerShuffle << ", "
|
||||
<< CShuffleNRepeatPerShuffle << ", "
|
||||
<< getGemmSpecializationString(GemmSpec)
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
static void SetElementwiseOps(Argument& arg,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op)
|
||||
{
|
||||
arg.a_element_op_ = a_element_op;
|
||||
arg.b_element_op_ = b_element_op;
|
||||
arg.c_element_op_ = c_element_op;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
void SetElementwiseOps(BaseArgument* p_arg,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op) const override
|
||||
{
|
||||
|
||||
SetElementwiseOps(
|
||||
*dynamic_cast<Argument*>(p_arg), a_element_op, b_element_op, c_element_op);
|
||||
}
|
||||
|
||||
static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
|
||||
{
|
||||
arg.grouped_gemm_kernel_args_dev = kernel_args;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override
|
||||
{
|
||||
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), kernel_args);
|
||||
}
|
||||
|
||||
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = *dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
return arg.group_count_ *
|
||||
sizeof(GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>);
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
|
||||
if(p_arg_)
|
||||
{
|
||||
return p_arg_->gemm_desc_kernel_arg_.size() * sizeof(GemmTransKernelArg);
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK::Argument structure!");
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* p_arg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& stream_config = StreamConfig{}) const override
|
||||
{
|
||||
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
|
||||
p_arg_->p_workspace_ = p_workspace;
|
||||
|
||||
hip_check_error(
|
||||
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_));
|
||||
}
|
||||
|
||||
static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
|
||||
|
||||
// polymorphic
|
||||
void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
|
||||
{
|
||||
return SetKBatch(*dynamic_cast<Argument*>(p_arg), k_batch);
|
||||
}
|
||||
|
||||
void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(!pArg_)
|
||||
{
|
||||
throw std::runtime_error("Failed to cast argument pointer!");
|
||||
}
|
||||
|
||||
pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
|
||||
std::copy(pArg_->gemm_desc_kernel_arg_.begin(),
|
||||
pArg_->gemm_desc_kernel_arg_.end(),
|
||||
static_cast<GemmTransKernelArg*>(pArg_->gemm_kernel_host_args_));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -605,7 +605,7 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
|
||||
|
||||
if(arg.grouped_gemm_kernel_args_dev == nullptr)
|
||||
{
|
||||
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr");
|
||||
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
@@ -688,11 +688,6 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_wmma_supported<ComputeType, ComputeType, MPerXDL, NPerXDL>())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Split-K autodeduction is not supported
|
||||
if(arg.k_batch_ < 1)
|
||||
{
|
||||
@@ -725,26 +720,6 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
|
||||
}
|
||||
}
|
||||
|
||||
for(index_t i = 0; i < arg.group_count_; i++)
|
||||
{
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if(GridwiseGemm64::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) !=
|
||||
true)
|
||||
{
|
||||
supported = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm32::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) !=
|
||||
true)
|
||||
{
|
||||
supported = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
|
||||
@@ -696,7 +696,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
|
||||
if(arg.grouped_gemm_kernel_args_dev == nullptr)
|
||||
{
|
||||
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr");
|
||||
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
@@ -333,7 +333,6 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
using typename Base::DsGridPointer;
|
||||
using AsDataType_ = AsDataType;
|
||||
using BsDataType_ = BsDataType;
|
||||
using EDataType_ = EDataType;
|
||||
|
||||
struct Problem
|
||||
{
|
||||
|
||||
@@ -48,15 +48,6 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
|
||||
ty);
|
||||
}
|
||||
|
||||
template <typename... X, typename... Y>
|
||||
auto concat_tuple_of_reference(ck::Tuple<X&...>& tx, ck::Tuple<Y&...>& ty)
|
||||
{
|
||||
return ck::unpack2(
|
||||
[&](auto&&... zs) { return ck::Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
|
||||
tx,
|
||||
ty);
|
||||
}
|
||||
|
||||
template <typename... X, typename... Y>
|
||||
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
|
||||
{
|
||||
|
||||
@@ -1,194 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/utility/functional4.hpp"
|
||||
#include "ck/utility/tuple_helper.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
template <typename AsTensorTuple,
|
||||
typename BsTensorTuple,
|
||||
typename DsTensorTuple,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AComputeType,
|
||||
typename BComputeType>
|
||||
struct ReferenceGemmMultiABD : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const AsTensorTuple& as_m_k,
|
||||
const BsTensorTuple& bs_k_n,
|
||||
const DsTensorTuple& ds_m_n,
|
||||
Tensor<EDataType>& e_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
: as_m_k_{as_m_k},
|
||||
bs_k_n_{bs_k_n},
|
||||
ds_m_n_{ds_m_n},
|
||||
e_m_n_{e_m_n},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const AsTensorTuple& as_m_k_;
|
||||
const BsTensorTuple& bs_k_n_;
|
||||
const DsTensorTuple& ds_m_n_;
|
||||
Tensor<EDataType>& e_m_n_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation cde_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceGemmMultiABD::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
static constexpr index_t NumATensor = AsTensorTuple::Size();
|
||||
static constexpr index_t NumBTensor = BsTensorTuple::Size();
|
||||
static constexpr index_t NumDTensor = DsTensorTuple::Size();
|
||||
|
||||
const int M = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[0];
|
||||
const int K = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[1];
|
||||
const int N = arg.bs_k_n_[Number<0>{}].mDesc.GetLengths()[1];
|
||||
|
||||
Tensor<AComputeType> a_m_k({M, K});
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
// result
|
||||
auto data_refs1 = ck::tie(a_m_k(m, k));
|
||||
// inputs
|
||||
auto data_refs2 = generate_tie(
|
||||
[&](auto i) -> auto& { return arg.as_m_k_[Number<i>{}](m, k); },
|
||||
Number<NumATensor>{});
|
||||
auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2);
|
||||
unpack(arg.a_element_op_, data_refs);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor<BComputeType> b_k_n({K, N});
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
// result
|
||||
auto data_refs1 = ck::tie(b_k_n(k, n));
|
||||
// inputs
|
||||
auto data_refs2 = generate_tie(
|
||||
[&](auto i) -> auto& { return arg.bs_k_n_[Number<i>{}](k, n); },
|
||||
Number<NumBTensor>{});
|
||||
auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2);
|
||||
unpack(arg.b_element_op_, data_refs);
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
Tensor<AccDataType> c_m_n({M, N});
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<AComputeType,
|
||||
BComputeType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
// compulsory
|
||||
auto data_refs1 = ck::tie(arg.e_m_n_(m, n), c_m_n(m, n));
|
||||
// optional (if multiple Ds)
|
||||
auto data_refs2 = generate_tie(
|
||||
[&](auto i) -> auto& { return arg.ds_m_n_[Number<i>{}](m, n); },
|
||||
Number<NumDTensor>{});
|
||||
auto data_refs = concat_tuple_of_reference(data_refs1, data_refs2);
|
||||
unpack(arg.cde_element_op_, data_refs);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
|
||||
static auto MakeArgument(const AsTensorTuple& as_m_k,
|
||||
const BsTensorTuple& bs_k_n,
|
||||
const DsTensorTuple& ds_m_n,
|
||||
Tensor<EDataType>& e_m_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
return Argument{as_m_k, bs_k_n, ds_m_n, e_m_n, a_element_op, b_element_op, cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceGemmMultiABD"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -20,7 +21,6 @@ using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
// RRR
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
@@ -179,167 +179,6 @@ void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instan
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
AddFastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// RCR
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
AddFastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
|
||||
ck::Tuple<Col, Col>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
// CRR
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
AddFastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
FastGelu>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
|
||||
ck::Tuple<Row, Row>,
|
||||
ck::Tuple<>,
|
||||
Row,
|
||||
ck::Tuple<BF16>,
|
||||
ck::Tuple<I8, BF16>,
|
||||
ck::Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>>>& instances);
|
||||
#endif // CK_USE
|
||||
|
||||
// GEMM + Add + Gelu
|
||||
template <typename AsLayout,
|
||||
@@ -379,7 +218,6 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
|
||||
@@ -408,38 +246,6 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
@@ -483,7 +289,6 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
|
||||
@@ -512,38 +317,6 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
@@ -587,7 +360,6 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
|
||||
@@ -616,38 +388,6 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
@@ -691,7 +431,6 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
|
||||
@@ -720,38 +459,6 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
|
||||
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
|
||||
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
|
||||
{
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
|
||||
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
|
||||
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
# ONLY XDL_KERNELS
|
||||
set(GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES)
|
||||
|
||||
list(APPEND GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES
|
||||
device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp
|
||||
|
||||
device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_gemm_fixed_nk_multi_abd_instance ${GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES})
|
||||
|
||||
@@ -1,144 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using Multiply = element_wise::Multiply;
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
using AddFastGelu = element_wise::AddFastGelu;
|
||||
using Add = element_wise::Add;
|
||||
using FastGelu = element_wise::FastGelu;
|
||||
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
template <typename DsLayout,
|
||||
typename DsDataType,
|
||||
typename CDEElementOp,
|
||||
GemmSpecialization GemmSpec>
|
||||
using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector|
|
||||
//#######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | _NWaveNPerXdl|
|
||||
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Col>, Tuple<Row, Row>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Col>, Tuple<Row, Row>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Col>, Tuple<Row, Row>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Col>,
|
||||
Tuple<Row, Row>,
|
||||
Tuple<Row>,
|
||||
Row,
|
||||
Tuple<BF16>,
|
||||
Tuple<I8, BF16>,
|
||||
Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
AddFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
|
||||
Tuple<Row>,
|
||||
Tuple<BF16>,
|
||||
AddFastGelu,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Col>,
|
||||
Tuple<Row, Row>,
|
||||
Tuple<Row>,
|
||||
Row,
|
||||
Tuple<BF16>,
|
||||
Tuple<I8, BF16>,
|
||||
Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
Add>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
|
||||
Tuple<Row>,
|
||||
Tuple<BF16>,
|
||||
Add,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Col>,
|
||||
Tuple<Row, Row>,
|
||||
Tuple<>,
|
||||
Row,
|
||||
Tuple<BF16>,
|
||||
Tuple<I8, BF16>,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
|
||||
Tuple<>,
|
||||
Tuple<>,
|
||||
PassThrough,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Col>,
|
||||
Tuple<Row, Row>,
|
||||
Tuple<>,
|
||||
Row,
|
||||
Tuple<BF16>,
|
||||
Tuple<I8, BF16>,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
|
||||
Tuple<>,
|
||||
Tuple<>,
|
||||
FastGelu,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,144 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using Multiply = element_wise::Multiply;
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
using AddFastGelu = element_wise::AddFastGelu;
|
||||
using Add = element_wise::Add;
|
||||
using FastGelu = element_wise::FastGelu;
|
||||
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
template <typename DsLayout,
|
||||
typename DsDataType,
|
||||
typename CDEElementOp,
|
||||
GemmSpecialization GemmSpec>
|
||||
using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BsData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector|
|
||||
//#######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | |
|
||||
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Row>, Tuple<Row, Row>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Row>, Tuple<Row, Row>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Row>, Tuple<Row, Row>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
|
||||
Tuple<Row, Row>,
|
||||
Tuple<Row>,
|
||||
Row,
|
||||
Tuple<BF16>,
|
||||
Tuple<I8, BF16>,
|
||||
Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
AddFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
|
||||
Tuple<Row>,
|
||||
Tuple<BF16>,
|
||||
AddFastGelu,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
|
||||
Tuple<Row, Row>,
|
||||
Tuple<Row>,
|
||||
Row,
|
||||
Tuple<BF16>,
|
||||
Tuple<I8, BF16>,
|
||||
Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
Add>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
|
||||
Tuple<Row>,
|
||||
Tuple<BF16>,
|
||||
Add,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
|
||||
Tuple<Row, Row>,
|
||||
Tuple<>,
|
||||
Row,
|
||||
Tuple<BF16>,
|
||||
Tuple<I8, BF16>,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
|
||||
Tuple<>,
|
||||
Tuple<>,
|
||||
PassThrough,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
|
||||
Tuple<Row, Row>,
|
||||
Tuple<>,
|
||||
Row,
|
||||
Tuple<BF16>,
|
||||
Tuple<I8, BF16>,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
|
||||
Tuple<>,
|
||||
Tuple<>,
|
||||
FastGelu,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,144 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using Multiply = element_wise::Multiply;
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
using AddFastGelu = element_wise::AddFastGelu;
|
||||
using Add = element_wise::Add;
|
||||
using FastGelu = element_wise::FastGelu;
|
||||
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
template <typename DsLayout,
|
||||
typename DsDataType,
|
||||
typename CDEElementOp,
|
||||
GemmSpecialization GemmSpec>
|
||||
using device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//######################################| AsLayout| BsLayout| DsLayout| ELayout| AsData| BsData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector|
|
||||
//######################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | |
|
||||
//######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Row>, Tuple<Col, Col>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Row>, Tuple<Col, Col>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK< Tuple<Row>, Tuple<Col, Col>, DsLayout, Row, Tuple<BF16>, Tuple<I8, BF16>, F32, BF16, DsDataType, BF16, PassThrough, Multiply, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
|
||||
Tuple<Col, Col>,
|
||||
Tuple<Row>,
|
||||
Row,
|
||||
Tuple<BF16>,
|
||||
Tuple<I8, BF16>,
|
||||
Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
AddFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
|
||||
Tuple<Row>,
|
||||
Tuple<BF16>,
|
||||
AddFastGelu,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
|
||||
Tuple<Col, Col>,
|
||||
Tuple<Row>,
|
||||
Row,
|
||||
Tuple<BF16>,
|
||||
Tuple<I8, BF16>,
|
||||
Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
Add>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
|
||||
Tuple<Row>,
|
||||
Tuple<BF16>,
|
||||
Add,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
|
||||
Tuple<Col, Col>,
|
||||
Tuple<>,
|
||||
Row,
|
||||
Tuple<BF16>,
|
||||
Tuple<I8, BF16>,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
|
||||
Tuple<>,
|
||||
Tuple<>,
|
||||
PassThrough,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
void add_device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<Tuple<Row>,
|
||||
Tuple<Col, Col>,
|
||||
Tuple<>,
|
||||
Row,
|
||||
Tuple<BF16>,
|
||||
Tuple<I8, BF16>,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_wmma_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
|
||||
Tuple<>,
|
||||
Tuple<>,
|
||||
FastGelu,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -61,8 +61,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial
|
||||
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that
|
||||
// portion of the instances are failing. As a workaround these have been commented out.
|
||||
template <typename DsLayout,
|
||||
typename DsDataType,
|
||||
typename CDEElementOp,
|
||||
@@ -74,14 +72,14 @@ using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances
|
||||
//######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
@@ -61,8 +61,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial
|
||||
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that
|
||||
// portion of the instances are failing. As a workaround these have been commented out.
|
||||
template <typename DsLayout,
|
||||
typename DsDataType,
|
||||
typename CDEElementOp,
|
||||
@@ -74,14 +72,14 @@ using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances
|
||||
//######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
@@ -61,8 +61,6 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecial
|
||||
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// NOTE: After adding unit tests for DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK it tuned out that
|
||||
// portion of the instances are failing. As a workaround these have been commented out.
|
||||
template <typename DsLayout,
|
||||
typename DsDataType,
|
||||
typename CDEElementOp,
|
||||
|
||||
@@ -17,11 +17,22 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
// this function is also defined in CK but because of the way we use it in
|
||||
// profile_gemm_multi_impl, it requires the arguments to not be const
|
||||
template <typename... X, typename... Y>
|
||||
auto concat_tuple_of_refs(ck::Tuple<X&...>& tx, ck::Tuple<Y&...>& ty)
|
||||
{
|
||||
return ck::unpack2(
|
||||
[&](auto&&... zs) { return ck::Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
|
||||
tx,
|
||||
ty);
|
||||
}
|
||||
|
||||
template <typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename AccDataType,
|
||||
@@ -169,35 +180,80 @@ bool profile_gemm_multi_abd_impl(int do_verification,
|
||||
// run reference
|
||||
if(do_verification)
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
Tensor<AccDataType> c_m_n({M, N});
|
||||
|
||||
using AComputeType =
|
||||
typename std::conditional<(NumATensor > 1),
|
||||
EDataType,
|
||||
remove_cvref_t<tuple_element_t<0, AsDataType>>>::type;
|
||||
|
||||
Tensor<AComputeType> a_m_k({M, K});
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
// result
|
||||
auto data_refs1 = ck::tie(a_m_k(m, k));
|
||||
// inputs
|
||||
auto data_refs2 =
|
||||
generate_tie([&](auto i) -> auto& { return as_m_k(Number<i>{})(m, k); },
|
||||
Number<NumATensor>{});
|
||||
auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
|
||||
unpack(a_element_op, data_refs);
|
||||
}
|
||||
}
|
||||
|
||||
using BComputeType =
|
||||
typename std::conditional<(NumBTensor > 1),
|
||||
EDataType,
|
||||
remove_cvref_t<tuple_element_t<0, BsDataType>>>::type;
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemmMultiABD<decltype(as_m_k),
|
||||
decltype(bs_k_n),
|
||||
decltype(ds_m_n),
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp,
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
Tensor<BComputeType> b_k_n({K, N});
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
// result
|
||||
auto data_refs1 = ck::tie(b_k_n(k, n));
|
||||
// inputs
|
||||
auto data_refs2 =
|
||||
generate_tie([&](auto i) -> auto& { return bs_k_n(Number<i>{})(k, n); },
|
||||
Number<NumBTensor>{});
|
||||
auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
|
||||
unpack(b_element_op, data_refs);
|
||||
}
|
||||
}
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<AComputeType,
|
||||
BComputeType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
as_m_k, bs_k_n, ds_m_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op);
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
// compulsory
|
||||
auto data_refs1 = ck::tie(e_m_n_host_result(m, n), c_m_n(m, n));
|
||||
// optional (if multiple Ds)
|
||||
auto data_refs2 =
|
||||
generate_tie([&](auto i) -> auto& { return ds_m_n(Number<i>{})(m, n); },
|
||||
Number<NumDTensor>{});
|
||||
auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
|
||||
unpack(cde_element_op, data_refs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::array<DeviceMem*, NumATensor> as_device_buf;
|
||||
|
||||
@@ -1,534 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iomanip>
|
||||
#include <array>
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multi_abd.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename T>
|
||||
auto reserveVector(std::size_t size)
|
||||
{
|
||||
std::vector<T> vec;
|
||||
vec.reserve(size);
|
||||
return vec;
|
||||
}
|
||||
|
||||
template <typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AccDataType,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AElementOp = ck::tensor_operation::element_wise::PassThrough,
|
||||
typename BElementOp = ck::tensor_operation::element_wise::Multiply,
|
||||
typename CDEElementOp = ck::tensor_operation::element_wise::PassThrough>
|
||||
bool profile_grouped_gemm_multi_abd_fixed_nk_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideDs,
|
||||
const std::vector<int>& StrideE,
|
||||
const std::vector<int>& kbatch_list = {1},
|
||||
int n_warmup = 1,
|
||||
int n_iter = 10)
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
const std::size_t group_count = Ms.size();
|
||||
const int sum_of_m = std::accumulate(Ms.begin(), Ms.end(), 0);
|
||||
|
||||
static constexpr index_t NumATensor = AsDataType::Size();
|
||||
static constexpr index_t NumBTensor = BsDataType::Size();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
if(group_count != Ns.size() || group_count != Ks.size() || group_count != StrideAs.size() ||
|
||||
group_count != StrideBs.size() || (NumDTensor > 0 && group_count != StrideDs.size()))
|
||||
{
|
||||
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideAs/Bs/Ds/E size\n");
|
||||
}
|
||||
|
||||
auto generateInputTupleA = [&](std::size_t g) {
|
||||
if constexpr(NumATensor == 0)
|
||||
{
|
||||
static_assert("Gemm problem should have at least 1 A tensor.");
|
||||
}
|
||||
else
|
||||
{
|
||||
using ALayout = remove_cvref_t<tuple_element_t<Number<0>{}, AsLayout>>;
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
return Tensor<ADataType>(
|
||||
f_host_tensor_descriptor(Ms[g], Ks[g], StrideAs[g], ALayout{}));
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
}
|
||||
};
|
||||
auto generateInputTupleB = [&](std::size_t g) {
|
||||
if constexpr(NumBTensor == 0)
|
||||
{
|
||||
static_assert("Gemm problem should have at least 1 B tensor.");
|
||||
}
|
||||
else
|
||||
{
|
||||
using BLayout = remove_cvref_t<tuple_element_t<Number<0>{}, BsLayout>>;
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
return Tensor<BDataType>(
|
||||
f_host_tensor_descriptor(Ks[g], Ns[g], StrideBs[g], BLayout{}));
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
}
|
||||
};
|
||||
auto generateInputTupleD = [&](std::size_t g) {
|
||||
if constexpr(NumDTensor == 0)
|
||||
{
|
||||
return ck::Tuple<>();
|
||||
}
|
||||
else
|
||||
{
|
||||
using DLayout = remove_cvref_t<tuple_element_t<Number<0>{}, DsLayout>>;
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
return Tensor<DDataType>(
|
||||
f_host_tensor_descriptor(Ms[g], Ns[g], StrideDs[g], DLayout{}));
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
};
|
||||
|
||||
using AsTensorTuple = decltype(generateInputTupleA(0));
|
||||
using BsTensorTuple = decltype(generateInputTupleB(0));
|
||||
using DsTensorTuple = decltype(generateInputTupleD(0));
|
||||
|
||||
auto g_as_m_k = reserveVector<AsTensorTuple>(group_count);
|
||||
auto g_bs_k_n = reserveVector<BsTensorTuple>(group_count);
|
||||
auto g_ds_m_n = reserveVector<DsTensorTuple>(group_count);
|
||||
auto g_e_m_n_host_results = reserveVector<Tensor<EDataType>>(group_count);
|
||||
auto g_e_m_n_device_results = reserveVector<Tensor<EDataType>>(group_count);
|
||||
|
||||
for(std::size_t g = 0; g < group_count; g++)
|
||||
{
|
||||
auto& as_m_k = g_as_m_k.emplace_back(generateInputTupleA(g));
|
||||
auto& bs_k_n = g_bs_k_n.emplace_back(generateInputTupleB(g));
|
||||
auto& ds_m_n = g_ds_m_n.emplace_back(generateInputTupleD(g));
|
||||
|
||||
g_e_m_n_host_results.push_back(
|
||||
Tensor<EDataType>(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE[g], ELayout{})));
|
||||
g_e_m_n_device_results.push_back(
|
||||
Tensor<EDataType>(f_host_tensor_descriptor(Ms[g], Ns[g], StrideE[g], ELayout{})));
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << g << std::endl;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
std::cout << "a" << i.value << "_m_k: " << as_m_k(i).mDesc << std::endl;
|
||||
});
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
std::cout << "b" << i.value << "_k_n: " << bs_k_n(i).mDesc << std::endl;
|
||||
});
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
std::cout << "d" << i.value << "_m_n: " << ds_m_n(i).mDesc << std::endl;
|
||||
});
|
||||
std::cout << "e_m_n: " << g_e_m_n_device_results[g].mDesc << std::endl;
|
||||
}
|
||||
|
||||
std::size_t num_thread = 1;
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
as_m_k(i).GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
|
||||
});
|
||||
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
bs_k_n(i).GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
|
||||
});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
ds_m_n(i).GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5}, num_thread);
|
||||
});
|
||||
|
||||
break;
|
||||
default:
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
as_m_k(i).GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
|
||||
});
|
||||
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
bs_k_n(i).GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
|
||||
});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
ds_m_n(i).GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0}, num_thread);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto cde_element_op = CDEElementOp{};
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
std::vector<std::array<DeviceMemPtr, NumATensor>> g_as_device_buf(group_count);
|
||||
std::vector<std::array<DeviceMemPtr, NumBTensor>> g_bs_device_buf(group_count);
|
||||
std::vector<std::array<DeviceMemPtr, NumDTensor>> g_ds_device_buf(group_count);
|
||||
std::vector<DeviceMemPtr> g_e_device_buf(group_count);
|
||||
|
||||
std::vector<std::array<const void*, NumATensor>> g_as_device_view(group_count);
|
||||
std::vector<std::array<const void*, NumBTensor>> g_bs_device_view(group_count);
|
||||
std::vector<std::array<const void*, NumDTensor>> g_ds_device_view(group_count);
|
||||
std::vector<void*> g_e_device_view(group_count);
|
||||
|
||||
auto g_gemm_descs = reserveVector<tensor_operation::device::GemmMultiABDDesc>(group_count);
|
||||
|
||||
auto grouped_gemm_kernel_args_host =
|
||||
reserveVector<tensor_operation::device::
|
||||
GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>>(
|
||||
group_count);
|
||||
|
||||
for(std::size_t g = 0; g < group_count; g++)
|
||||
{
|
||||
std::array<ck::index_t, NumATensor> as_stride;
|
||||
std::array<ck::index_t, NumBTensor> bs_stride;
|
||||
std::array<ck::index_t, NumDTensor> ds_stride;
|
||||
|
||||
auto& as_m_k = g_as_m_k[g];
|
||||
auto& as_device_buf = g_as_device_buf[g];
|
||||
auto& as_device_view = g_as_device_view[g];
|
||||
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
as_device_buf[i] = std::make_unique<DeviceMem>(sizeof(ADataType) * Ms[g] * Ks[g]);
|
||||
as_device_buf[i]->ToDevice(as_m_k[i].mData.data());
|
||||
as_device_view[i] = as_device_buf[i]->GetDeviceBuffer();
|
||||
as_stride[i] = StrideAs[g];
|
||||
});
|
||||
|
||||
auto& bs_k_n = g_bs_k_n[g];
|
||||
auto& bs_device_buf = g_bs_device_buf[g];
|
||||
auto& bs_device_view = g_bs_device_view[g];
|
||||
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
bs_device_buf[i] = std::make_unique<DeviceMem>(sizeof(BDataType) * Ks[g] * Ns[g]);
|
||||
bs_device_buf[i]->ToDevice(bs_k_n[i].mData.data());
|
||||
bs_device_view[i] = bs_device_buf[i]->GetDeviceBuffer();
|
||||
bs_stride[i] = StrideBs[g];
|
||||
});
|
||||
|
||||
auto& ds_m_n = g_ds_m_n[g];
|
||||
auto& ds_device_buf = g_ds_device_buf[g];
|
||||
auto& ds_device_view = g_ds_device_view[g];
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
ds_device_buf[i] = std::make_unique<DeviceMem>(sizeof(DDataType) * Ms[g] * Ns[g]);
|
||||
ds_device_buf[i]->ToDevice(ds_m_n[i].mData.data());
|
||||
ds_device_view[i] = ds_device_buf[i]->GetDeviceBuffer();
|
||||
ds_stride[i] = StrideDs[g];
|
||||
});
|
||||
|
||||
g_e_device_buf[g] = std::make_unique<DeviceMem>(sizeof(EDataType) * Ms[g] * Ns[g]);
|
||||
g_e_device_view[g] = g_e_device_buf[g]->GetDeviceBuffer();
|
||||
|
||||
g_gemm_descs.push_back(tensor_operation::device::GemmMultiABDDesc{
|
||||
sum_of_m,
|
||||
Ns[g],
|
||||
Ks[g],
|
||||
std::vector<ck::index_t>(as_stride.begin(), as_stride.end()),
|
||||
std::vector<ck::index_t>(bs_stride.begin(), bs_stride.end()),
|
||||
std::vector<ck::index_t>(ds_stride.begin(), ds_stride.end()),
|
||||
StrideE[g]});
|
||||
|
||||
tensor_operation::device::
|
||||
GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>
|
||||
kernelArg{as_device_view,
|
||||
bs_device_view,
|
||||
ds_device_view,
|
||||
g_e_device_view[g],
|
||||
Ms[g],
|
||||
Ns[g],
|
||||
Ks[g],
|
||||
as_stride,
|
||||
bs_stride,
|
||||
ds_stride,
|
||||
StrideE[g]};
|
||||
|
||||
grouped_gemm_kernel_args_host.push_back(std::move(kernelArg));
|
||||
}
|
||||
|
||||
using DeviceOp = tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
if(op_ptrs.size() <= 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! no device GEMM instance found");
|
||||
}
|
||||
|
||||
std::string best_gemm_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
float best_kbatch = 0;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
using AComputeType =
|
||||
typename std::conditional<(NumATensor > 1),
|
||||
EDataType,
|
||||
remove_cvref_t<tuple_element_t<0, AsDataType>>>::type;
|
||||
|
||||
using BComputeType =
|
||||
typename std::conditional<(NumBTensor > 1),
|
||||
EDataType,
|
||||
remove_cvref_t<tuple_element_t<0, BsDataType>>>::type;
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemmMultiABD<AsTensorTuple,
|
||||
BsTensorTuple,
|
||||
DsTensorTuple,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp,
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
for(std::size_t i = 0; i < group_count; i++)
|
||||
{
|
||||
auto ref_argument = ref_gemm.MakeArgument(g_as_m_k[i],
|
||||
g_bs_k_n[i],
|
||||
g_ds_m_n[i],
|
||||
g_e_m_n_host_results[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
}
|
||||
|
||||
// profile device GEMM instances
|
||||
for(auto& gemm_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr = gemm_ptr->MakeArgumentPointer(
|
||||
g_as_device_view, g_bs_device_view, g_ds_device_view, g_e_device_view, g_gemm_descs);
|
||||
|
||||
if(!gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Gemm incompatible with runtime set parameters. Skipping..."
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
DeviceMem gemm_workspace_dev(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
|
||||
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_workspace_dev.GetDeviceBuffer());
|
||||
|
||||
DeviceMem grouped_gemm_kernel_args_dev(
|
||||
gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
|
||||
hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(),
|
||||
grouped_gemm_kernel_args_host.data(),
|
||||
gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(),
|
||||
grouped_gemm_kernel_args_dev.GetDeviceBuffer());
|
||||
gemm_ptr->SetElementwiseOps(argument_ptr.get(), a_element_op, b_element_op, cde_element_op);
|
||||
|
||||
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
|
||||
|
||||
std::string gemm_name = gemm_ptr->GetTypeString();
|
||||
|
||||
for(const auto kbatch_curr : kbatch_list)
|
||||
{
|
||||
gemm_ptr->SetKBatch(argument_ptr.get(), kbatch_curr);
|
||||
|
||||
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
for(std::size_t g = 0; g < group_count; g++)
|
||||
{
|
||||
g_e_device_buf[g]->SetZero();
|
||||
}
|
||||
|
||||
float ave_time = invoker_ptr->Run(
|
||||
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
bool instance_pass = true;
|
||||
for(std::size_t g = 0; g < group_count; g++)
|
||||
{
|
||||
g_e_device_buf[g]->FromDevice(
|
||||
g_e_m_n_device_results[g].mData.data(),
|
||||
g_e_m_n_device_results[g].mDesc.GetElementSize() * sizeof(EDataType));
|
||||
|
||||
instance_pass =
|
||||
instance_pass && ck::utils::check_err(g_e_m_n_device_results[g],
|
||||
g_e_m_n_host_results[g]);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "a[" << g << "]: ", g_as_m_k[g](i).mData, ",")
|
||||
<< std::endl;
|
||||
});
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "b[" << g << "]: ", g_bs_k_n[g](i).mData, ",")
|
||||
<< std::endl;
|
||||
});
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "d[" << g << "]: ", g_ds_m_n[g](i).mData, ",")
|
||||
<< std::endl;
|
||||
});
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "e_device: ", g_e_m_n_device_results[g].mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "e_host : ", g_e_m_n_host_results[g].mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Instance: " << gemm_name << " verification "
|
||||
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
|
||||
|
||||
pass = pass && instance_pass;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(std::size_t g = 0; g < group_count; g++)
|
||||
{
|
||||
flop += std::size_t(2) * Ms[g] * Ns[g] * Ks[g];
|
||||
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
num_btype += sizeof(ADataType) * Ms[g] * Ks[g];
|
||||
});
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
num_btype += sizeof(BDataType) * Ks[g] * Ns[g];
|
||||
});
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
num_btype += sizeof(DDataType) * Ms[g] * Ns[g];
|
||||
});
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch "
|
||||
<< kbatch_curr << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_gemm_name = gemm_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
best_kbatch = kbatch_curr;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch
|
||||
<< std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
@@ -18,12 +18,6 @@ if (CK_USE_XDL OR CK_USE_WMMA)
|
||||
target_link_libraries(test_grouped_gemm_fastgelu PRIVATE utility device_grouped_gemm_fastgelu_instance)
|
||||
add_dependencies(test_grouped_gemm test_grouped_gemm_fastgelu)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_grouped_gemm_multi_abd_fixed_nk test_grouped_gemm_multi_abd_fixed_nk.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_gemm_multi_abd_fixed_nk PRIVATE utility device_grouped_gemm_fixed_nk_multi_abd_instance)
|
||||
add_dependencies(test_grouped_gemm test_grouped_gemm_multi_abd_fixed_nk)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)
|
||||
|
||||
@@ -1,256 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
static ck::index_t param_mask = 0xffffff;
|
||||
static ck::index_t instance_index = -1;
|
||||
|
||||
using FP32 = float;
|
||||
using FP16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>, BF16, ck::Tuple<Row>, ck::Tuple<Col, Col>, ck::Tuple<Row>, Row, AddFastGelu>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>, BF16, ck::Tuple<Col>, ck::Tuple<Row, Row>, ck::Tuple<Row>, Row, AddFastGelu>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>, BF16, ck::Tuple<Row>, ck::Tuple<Row, Row>, ck::Tuple<Row>, Row, AddFastGelu>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>, BF16, ck::Tuple<Row>, ck::Tuple<Col, Col>, ck::Tuple<Row>, Row, Add>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>, BF16, ck::Tuple<Col>, ck::Tuple<Row, Row>, ck::Tuple<Row>, Row, Add>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>, BF16, ck::Tuple<Row>, ck::Tuple<Row, Row>, ck::Tuple<Row>, Row, Add>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>, BF16, ck::Tuple<Row>, ck::Tuple<Col, Col>, ck::Tuple<>, Row, PassThrough>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>, BF16, ck::Tuple<Col>, ck::Tuple<Row, Row>, ck::Tuple<>, Row, PassThrough>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>, BF16, ck::Tuple<Row>, ck::Tuple<Row, Row>, ck::Tuple<>, Row, PassThrough>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>, BF16, ck::Tuple<Row>, ck::Tuple<Col, Col>, ck::Tuple<>, Row, FastGelu>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>, BF16, ck::Tuple<Col>, ck::Tuple<Row, Row>, ck::Tuple<>, Row, FastGelu>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>, BF16, ck::Tuple<Row>, ck::Tuple<Row, Row>, ck::Tuple<>, Row, FastGelu>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedGemmMultiABDFixedNK : public testing::Test
|
||||
{
|
||||
protected:
|
||||
using AsDataType = std::tuple_element_t<0, Tuple>;
|
||||
using BsDataType = std::tuple_element_t<1, Tuple>;
|
||||
using DsDataType = std::tuple_element_t<2, Tuple>;
|
||||
using EDataType = std::tuple_element_t<3, Tuple>;
|
||||
using AccDataType = float;
|
||||
using AsLayout = std::tuple_element_t<4, Tuple>;
|
||||
using BsLayout = std::tuple_element_t<5, Tuple>;
|
||||
using DsLayout = std::tuple_element_t<6, Tuple>;
|
||||
using ELayout = std::tuple_element_t<7, Tuple>;
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = Multiply;
|
||||
using CDEElementOp = std::tuple_element_t<8, Tuple>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1; // integer value initialization
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
static constexpr int n_warmup_ = 0;
|
||||
static constexpr int n_iter_ = 1;
|
||||
|
||||
std::vector<int> k_batches_ = {1};
|
||||
|
||||
private:
|
||||
template <typename Layout>
|
||||
void SetStrides(std::vector<int>& strides,
|
||||
const std::vector<int>& rows,
|
||||
const std::vector<int>& cols) const
|
||||
{
|
||||
if(std::is_same_v<Layout, Row>)
|
||||
{
|
||||
for(const auto c : cols)
|
||||
{
|
||||
strides.emplace_back(c);
|
||||
}
|
||||
}
|
||||
else if(std::is_same_v<Layout, Col>)
|
||||
{
|
||||
for(const auto r : rows)
|
||||
{
|
||||
strides.emplace_back(r);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Layouts>
|
||||
void SetTupleStrides(std::vector<int>& strides,
|
||||
const std::vector<int>& rows,
|
||||
const std::vector<int>& cols) const
|
||||
{
|
||||
if constexpr(Layouts::Size() > 0)
|
||||
{
|
||||
// As of now multi ABD implementation supports only tensors with matching layouts.
|
||||
using Layout = ck::remove_cvref_t<ck::tuple_element_t<ck::Number<0>{}, Layouts>>;
|
||||
SetStrides<Layout>(strides, rows, cols);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs = {},
|
||||
const std::vector<int>& StrideBs = {},
|
||||
const std::vector<int>& StrideDs = {},
|
||||
const std::vector<int>& StrideE = {})
|
||||
{
|
||||
std::vector<int> stride_as = StrideAs;
|
||||
std::vector<int> stride_bs = StrideBs;
|
||||
std::vector<int> stride_ds = StrideDs;
|
||||
std::vector<int> stride_e = StrideE;
|
||||
|
||||
if(stride_as.empty())
|
||||
{
|
||||
SetTupleStrides<AsLayout>(stride_as, Ms, Ks);
|
||||
}
|
||||
if(stride_bs.empty())
|
||||
{
|
||||
SetTupleStrides<BsLayout>(stride_bs, Ks, Ns);
|
||||
}
|
||||
if(stride_ds.empty())
|
||||
{
|
||||
SetTupleStrides<DsLayout>(stride_ds, Ms, Ns);
|
||||
}
|
||||
if(stride_e.empty())
|
||||
{
|
||||
SetStrides<ELayout>(stride_e, Ms, Ns);
|
||||
}
|
||||
|
||||
RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_ds, stride_e);
|
||||
}
|
||||
|
||||
void RunSingle(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideDs,
|
||||
const std::vector<int>& StrideE)
|
||||
{
|
||||
bool pass =
|
||||
ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
k_batches_,
|
||||
n_warmup_,
|
||||
n_iter_);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedGemmMultiABDFixedNK, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestGroupedGemmMultiABDFixedNK, TinyCases)
|
||||
{
|
||||
const std::vector<int> Ms{3, 4};
|
||||
constexpr int N = 8;
|
||||
constexpr int K = 64;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedGemmMultiABDFixedNK, SmallCases)
|
||||
{
|
||||
const std::vector<int> Ms{3, 5, 16, 7, 8};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedGemmMultiABDFixedNK, MidCases)
|
||||
{
|
||||
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedGemmMultiABDFixedNK, Regular)
|
||||
{
|
||||
const std::vector<int> Ms{64, 128, 256};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 320;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if(argc == 1)
|
||||
{
|
||||
// Run with default arguments.
|
||||
}
|
||||
else if(argc == 3)
|
||||
{
|
||||
param_mask = strtol(argv[1], nullptr, 0);
|
||||
instance_index = atoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
Reference in New Issue
Block a user