mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Grouped Gemm with Fixed K and N with SplitK (#818)
* move all arguments into device
* add b2c_tile_map
* add examples
* add SetDeviceKernelArgs
* dedicated fixed_nk solution
* init client api
* add grouped_gemm_bias example
* add a instance
* add instances
* formatting
* fixed cmake
* Update EnableCompilerWarnings.cmake
* Update cmake-ck-dev.sh
* clean; fixed comments
* fixed comment
* add instances for fp32 output
* add instances for fp32 output
* add fp32 out client example
* fixed CI
* init commit for kbatch
* add splitk gridwise
* format
* fixed
* clean deviceop
* clean code
* finish splitk
* fixed instances
* change m_loops to tile_loops
* add setkbatch
* clean code
* add splitK+bias
* add instances
* opt mk_nk instances
* clean examples
* fixed CI
* remove zero
* finished non-zero
* clean
* clean code
* optimized global_barrier
* fixed ci
* fixed CI
* removed AddBias
* format
* fixed CI
* fixed CI
* move 20_grouped_gemm to 21_grouped_gemm
---------
Co-authored-by: Jing Zhang <jizha@amd.com>
[ROCm/composable_kernel commit: f5ec04f091]
This commit is contained in:
2
client_example/21_grouped_gemm_bias/CMakeLists.txt
Normal file
2
client_example/21_grouped_gemm_bias/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations)
|
||||
@@ -0,0 +1,244 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <random>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType>;
|
||||
using EDataType = F32;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<D0Layout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = Add;
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
int main()
|
||||
{
|
||||
std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideEs;
|
||||
|
||||
int sum_of_m = 0;
|
||||
|
||||
Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
|
||||
|
||||
int group_count = Ms.size();
|
||||
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
Ns.push_back(768);
|
||||
Ks.push_back(4608);
|
||||
|
||||
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
|
||||
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
|
||||
StrideEs.push_back(std::is_same<Row, ELayout>::value ? Ns[i] : Ms[i]);
|
||||
|
||||
sum_of_m += Ms[i];
|
||||
}
|
||||
|
||||
auto f_matrix_space_size =
|
||||
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
|
||||
using Layout = decltype(layout);
|
||||
|
||||
if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return (nRow - 1) * stride + nCol;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (nCol - 1) * stride + nRow;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<SimpleDeviceMem> a_dev_bufs, b_dev_bufs, d0_dev_bufs, e_dev_bufs;
|
||||
|
||||
a_dev_bufs.reserve(group_count);
|
||||
b_dev_bufs.reserve(group_count);
|
||||
d0_dev_bufs.reserve(group_count);
|
||||
e_dev_bufs.reserve(group_count);
|
||||
|
||||
std::vector<void*> p_e;
|
||||
|
||||
p_e.reserve(group_count);
|
||||
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
std::vector<ck::tensor_operation::device::GroupedGemmKernelArgument<1>>
|
||||
grouped_gemm_kernel_args_;
|
||||
grouped_gemm_kernel_args_.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
a_dev_bufs.emplace_back(sizeof(ADataType) *
|
||||
f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{}));
|
||||
b_dev_bufs.emplace_back(sizeof(BDataType) *
|
||||
f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{}));
|
||||
d0_dev_bufs.emplace_back(sizeof(D0DataType) *
|
||||
f_matrix_space_size(Ms[i], Ns[i], 0, D0Layout{}));
|
||||
e_dev_bufs.emplace_back(sizeof(EDataType) *
|
||||
f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{}));
|
||||
|
||||
gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}});
|
||||
|
||||
p_e.push_back(e_dev_bufs[i].GetDeviceBuffer());
|
||||
|
||||
grouped_gemm_kernel_args_.push_back(
|
||||
{a_dev_bufs[i].GetDeviceBuffer(),
|
||||
b_dev_bufs[i].GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d0_dev_bufs[i].GetDeviceBuffer()},
|
||||
e_dev_bufs[i].GetDeviceBuffer(),
|
||||
Ms[i],
|
||||
Ns[i],
|
||||
Ks[i],
|
||||
StrideAs[i],
|
||||
StrideBs[i],
|
||||
std::array<ck::index_t, 1>{0},
|
||||
StrideEs[i]});
|
||||
}
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto cde_element_op = CDEElementOp{};
|
||||
|
||||
std::string best_op_name;
|
||||
bool found = false;
|
||||
int best_op_id = -1;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
std::vector<const void*> p_a = {}, p_b = {};
|
||||
std::vector<std::array<const void*, 1>> p_ds = {};
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
SimpleDeviceMem grouped_gemm_kernel_args_dev(
|
||||
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
|
||||
|
||||
SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(),
|
||||
grouped_gemm_kernel_args_.data(),
|
||||
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(),
|
||||
grouped_gemm_workspace_dev.GetDeviceBuffer());
|
||||
|
||||
op_ptr->SetDeviceKernelArgs(argument_ptr.get(),
|
||||
grouped_gemm_kernel_args_dev.GetDeviceBuffer());
|
||||
|
||||
op_ptr->SetKBatch(argument_ptr.get(), 2);
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(std::size_t j = 0; j < gemm_descs.size(); ++j)
|
||||
{
|
||||
flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j];
|
||||
|
||||
num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] +
|
||||
sizeof(EDataType) * Ms[j] * Ns[j];
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
found = true;
|
||||
best_op_id = i;
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
add_custom_target(example_grouped_gemm_xdl)
|
||||
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
|
||||
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
|
||||
@@ -7,10 +8,14 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp)
|
||||
add_dependencies(example_grouped_gemm_xdl
|
||||
example_grouped_gemm_xdl_fp16
|
||||
example_grouped_gemm_multiple_d_dl_fp16
|
||||
example_grouped_gemm_xdl_splitk_fp16)
|
||||
example_grouped_gemm_xdl_splitk_fp16
|
||||
example_grouped_gemm_xdl_fixed_nk_fp16
|
||||
example_grouped_gemm_xdl_fixed_nk_bias_fp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp)
|
||||
|
||||
353
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
Normal file
353
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
Normal file
@@ -0,0 +1,353 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType>;
|
||||
using EDataType = F32;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<D0Layout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
|
||||
using CDEElementOp = Add;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MPadding;
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
int k_batch = 1;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
int sum_of_m = 0;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Tensor<ADataType>> a_tensors;
|
||||
std::vector<Tensor<BDataType>> b_tensors;
|
||||
std::vector<Tensor<D0DataType>> d0_tensors;
|
||||
std::vector<Tensor<EDataType>> c_host_tensors;
|
||||
std::vector<Tensor<EDataType>> c_device_tensors;
|
||||
|
||||
a_tensors.reserve(group_count);
|
||||
b_tensors.reserve(group_count);
|
||||
d0_tensors.reserve(group_count);
|
||||
c_host_tensors.reserve(group_count);
|
||||
c_device_tensors.reserve(group_count);
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, d0_tensors_device,
|
||||
c_tensors_device;
|
||||
|
||||
a_tensors_device.reserve(group_count);
|
||||
b_tensors_device.reserve(group_count);
|
||||
d0_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
sum_of_m += problem_size.Ms[i];
|
||||
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
|
||||
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
|
||||
d0_tensors.push_back(Tensor<D0DataType>(
|
||||
f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{})));
|
||||
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_device_tensors[i].mDesc << std::endl;
|
||||
|
||||
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
|
||||
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
|
||||
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
|
||||
sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() +
|
||||
sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
case 2:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
}
|
||||
|
||||
d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
}
|
||||
|
||||
using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>;
|
||||
|
||||
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
|
||||
grouped_gemm_kernel_args_.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(ADataType) * sum_of_m * problem_size.Ks[i]));
|
||||
|
||||
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i]));
|
||||
|
||||
d0_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(D0DataType) * problem_size.Ns[i]));
|
||||
|
||||
c_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
|
||||
|
||||
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(),
|
||||
a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType));
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
|
||||
b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType));
|
||||
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data());
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
gemm_descs.push_back({sum_of_m,
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
1,
|
||||
problem_size.stride_Bs[i],
|
||||
1,
|
||||
{0}});
|
||||
|
||||
grouped_gemm_kernel_args_.push_back(
|
||||
{a_tensors_device[i]->GetDeviceBuffer(),
|
||||
b_tensors_device[i]->GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d0_tensors_device[i]->GetDeviceBuffer()},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
std::array<ck::index_t, 1>{0},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
std::vector<const void*> p_As = {};
|
||||
std::vector<const void*> p_Bs = {};
|
||||
std::vector<std::array<const void*, 1>> p_Ds = {};
|
||||
std::vector<void*> p_Cs = {};
|
||||
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
|
||||
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
|
||||
|
||||
DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument));
|
||||
hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
|
||||
grouped_gemm_kernel_args_.data(),
|
||||
gemm.GetDeviceKernelArgSize(&argument),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer());
|
||||
gemm.SetKBatch(argument, config.k_batch);
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
|
||||
c_device_tensors[i].mDesc.GetElementSize() *
|
||||
sizeof(EDataType));
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
|
||||
b_tensors[i],
|
||||
c_host_tensors[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < problem_size.Ms[i]; ++m)
|
||||
{
|
||||
for(int n = 0; n < problem_size.Ns[i]; ++n)
|
||||
{
|
||||
cde_element_op(
|
||||
c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n));
|
||||
}
|
||||
}
|
||||
|
||||
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
|
||||
}
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ns.push_back(768);
|
||||
problem_size.Ks.push_back(4608);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ns[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
}
|
||||
|
||||
if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.k_batch = std::stoi(argv[4]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: k_batch (>0)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
329
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
Normal file
329
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
Normal file
@@ -0,0 +1,329 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = F32;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
int k_batch = 1;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
std::vector<void*> p_Cs;
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
int sum_of_m = 0;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Tensor<ADataType>> a_tensors;
|
||||
std::vector<Tensor<BDataType>> b_tensors;
|
||||
std::vector<Tensor<EDataType>> c_host_tensors;
|
||||
std::vector<Tensor<EDataType>> c_device_tensors;
|
||||
|
||||
a_tensors.reserve(group_count);
|
||||
b_tensors.reserve(group_count);
|
||||
c_host_tensors.reserve(group_count);
|
||||
c_device_tensors.reserve(group_count);
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
|
||||
|
||||
a_tensors_device.reserve(group_count);
|
||||
b_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
sum_of_m += problem_size.Ms[i];
|
||||
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
|
||||
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
|
||||
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
|
||||
<< std::endl;
|
||||
|
||||
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
|
||||
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
|
||||
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
|
||||
sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
case 2:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
}
|
||||
}
|
||||
|
||||
using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>;
|
||||
|
||||
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
|
||||
grouped_gemm_kernel_args_.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(ADataType) * sum_of_m * problem_size.Ks[i]));
|
||||
|
||||
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i]));
|
||||
|
||||
c_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
|
||||
|
||||
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(),
|
||||
a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType));
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
|
||||
b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType));
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
|
||||
|
||||
gemm_descs.push_back({sum_of_m,
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
1,
|
||||
problem_size.stride_Bs[i],
|
||||
1,
|
||||
{}});
|
||||
|
||||
grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(),
|
||||
b_tensors_device[i]->GetDeviceBuffer(),
|
||||
{},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
{},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
std::vector<const void*> p_As = {};
|
||||
std::vector<const void*> p_Bs = {};
|
||||
std::vector<std::array<const void*, 0>> p_Ds = {};
|
||||
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
|
||||
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
|
||||
|
||||
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
|
||||
|
||||
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
|
||||
grouped_gemm_kernel_args_.data(),
|
||||
gemm.GetDeviceKernelArgSize(&argument),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
gemm.SetKBatch(argument, config.k_batch);
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
|
||||
c_device_tensors[i].mDesc.GetElementSize() *
|
||||
sizeof(EDataType));
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
|
||||
b_tensors[i],
|
||||
c_host_tensors[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
|
||||
}
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
problem_size.Ms = {
|
||||
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ns.push_back(768);
|
||||
problem_size.Ks.push_back(4608);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
}
|
||||
|
||||
if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.k_batch = std::stoi(argv[4]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: k_batch (> 0)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <array>
|
||||
|
||||
#include "device_grouped_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t NumDTensor = 0>
|
||||
struct GroupedGemmKernelArgument
|
||||
{
|
||||
const void* p_a_grid;
|
||||
const void* p_b_grid;
|
||||
std::array<const void*, NumDTensor> p_ds_grid;
|
||||
void* p_e_grid;
|
||||
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
|
||||
index_t StrideA;
|
||||
index_t StrideB;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
index_t StrideE;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0;
|
||||
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
|
||||
virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,836 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename GemmDesc,
|
||||
GemmSpecialization GemmSpec,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename DsDataType,
|
||||
typename Block2ETileMap,
|
||||
typename GroupedGemmBlock2ETileMap,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
uint32_t* barrier_count,
|
||||
const index_t barrier_size_grp,
|
||||
const index_t group_count,
|
||||
const index_t grid_size_grp,
|
||||
const index_t KBatch,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation c_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
|
||||
const auto gemm_desc_ptr =
|
||||
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
|
||||
const index_t group_id = block_id / grid_size_grp;
|
||||
|
||||
if(group_id >= group_count)
|
||||
return;
|
||||
|
||||
const index_t M = gemm_desc_ptr[group_id].M;
|
||||
const index_t N = gemm_desc_ptr[group_id].N;
|
||||
const index_t K = gemm_desc_ptr[group_id].K;
|
||||
|
||||
if(M * N * K == 0)
|
||||
return;
|
||||
|
||||
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
|
||||
const auto StrideB = gemm_desc_ptr[group_id].StrideB;
|
||||
const auto StrideDs = gemm_desc_ptr[group_id].StrideDs;
|
||||
const auto StrideE = gemm_desc_ptr[group_id].StrideE;
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
|
||||
|
||||
const index_t BlockStart = group_id * grid_size_grp;
|
||||
|
||||
const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch};
|
||||
|
||||
const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n);
|
||||
|
||||
constexpr auto NumDTensor = DsDataType::Size();
|
||||
|
||||
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
|
||||
|
||||
DsGridPointer p_ds_grid_;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
// D pointer
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
|
||||
});
|
||||
|
||||
index_t id_off = 0;
|
||||
index_t id_local = get_block_1d_id() - BlockStart;
|
||||
|
||||
const index_t mn_blocks = local_grid_size / KBatch;
|
||||
|
||||
while(id_local < local_grid_size)
|
||||
{
|
||||
const auto block_2_etile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
|
||||
|
||||
auto barrier_count_finished =
|
||||
barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout>(gemm_desc_ptr[group_id].p_a_grid,
|
||||
gemm_desc_ptr[group_id].p_b_grid,
|
||||
p_ds_grid_,
|
||||
gemm_desc_ptr[group_id].p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
block_2_etile_map);
|
||||
|
||||
id_off += grid_size_grp;
|
||||
id_local += grid_size_grp;
|
||||
}
|
||||
#else
|
||||
ignore = gemm_descs_const;
|
||||
ignore = barrier_count;
|
||||
ignore = barrier_size_grp;
|
||||
ignore = group_count;
|
||||
ignore = grid_size_grp;
|
||||
ignore = KBatch;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumPrefetch,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedGemm_Xdl_Fixed_NK;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
NumPrefetch, // NumGemmKPrefetchStage
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
template <typename UnderlyingBlockToCTileMap>
|
||||
struct OffsettedBlockToCTileMapMLoops
|
||||
{
|
||||
using underlying_type = UnderlyingBlockToCTileMap;
|
||||
|
||||
__host__ __device__ OffsettedBlockToCTileMapMLoops(
|
||||
UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
|
||||
{
|
||||
block_to_ctile_map_ = block_to_ctile_map;
|
||||
block_start_ = block_start;
|
||||
id_off_ = id_off;
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
|
||||
make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));
|
||||
|
||||
return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]);
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
|
||||
const CTileDim& c_tile_dim) const
|
||||
{
|
||||
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
UnderlyingBlockToCTileMap block_to_ctile_map_;
|
||||
index_t block_start_;
|
||||
index_t id_off_;
|
||||
};
|
||||
|
||||
template <index_t MPerBlock_, index_t NPerBlock_>
|
||||
struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
|
||||
|
||||
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
|
||||
const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
|
||||
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
|
||||
BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
|
||||
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
|
||||
operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
|
||||
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
|
||||
operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
|
||||
|
||||
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M,
|
||||
index_t N,
|
||||
index_t KBatch,
|
||||
index_t M01 = 8)
|
||||
: M_(M), N_(N), KBatch_(KBatch), M01_(M01)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
|
||||
const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
|
||||
: BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
|
||||
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
return M0 * N0 * KBatch_;
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ constexpr index_t
|
||||
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
|
||||
{
|
||||
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
auto block_1d_id = idx_top[I0];
|
||||
|
||||
const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
|
||||
const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
|
||||
|
||||
block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
|
||||
|
||||
const index_t idx_ksplit = block_1d_id / (M0 * N0);
|
||||
block_1d_id = block_1d_id % (M0 * N0);
|
||||
|
||||
index_t idx_N0 = block_1d_id % N0;
|
||||
index_t idx_M0 = block_1d_id / N0;
|
||||
|
||||
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
|
||||
|
||||
index_t idx_M00 = idx_M0 / M01_;
|
||||
index_t idx_M01 = idx_M0 % M01_;
|
||||
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
|
||||
|
||||
return make_tuple(idx_ksplit,
|
||||
idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
|
||||
idx_N0_M01_local / M01_adapt);
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
|
||||
const CTileDim& /* c_tile_dim */) const
|
||||
{
|
||||
return true; // always valid provided that user gets grid size from CalculateGridSize()
|
||||
}
|
||||
|
||||
private:
|
||||
index_t M_;
|
||||
index_t N_;
|
||||
index_t KBatch_;
|
||||
index_t M01_;
|
||||
};
|
||||
|
||||
using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
|
||||
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
|
||||
|
||||
struct GemmBiasTransKernelArg
|
||||
{
|
||||
// pointers
|
||||
const void* a_ptr_;
|
||||
const void* b_ptr_;
|
||||
std::array<const void*, NumDTensor> ds_ptr_;
|
||||
void* e_ptr_;
|
||||
|
||||
index_t M_, N_, K_;
|
||||
index_t StrideA_, StrideB_;
|
||||
std::array<index_t, NumDTensor> StrideDs_;
|
||||
index_t StrideE_;
|
||||
};
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
|
||||
void UpdateKBatch(index_t k_batch)
|
||||
{
|
||||
k_batch_ = k_batch;
|
||||
|
||||
if(k_batch_ < 1)
|
||||
{
|
||||
|
||||
throw std::runtime_error("wrong! k_batch must be > 0");
|
||||
}
|
||||
|
||||
const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_);
|
||||
|
||||
const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE_;
|
||||
const index_t N = gemm_desc_kernel_arg_[0].N_;
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
|
||||
AverM, N, StrideE);
|
||||
|
||||
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};
|
||||
|
||||
grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
|
||||
|
||||
grid_size_ = grid_size_grp_ * group_count_;
|
||||
}
|
||||
|
||||
Argument(std::vector<const void*>&,
|
||||
std::vector<const void*>&,
|
||||
std::vector<std::array<const void*, NumDTensor>>&,
|
||||
std::vector<void*>&,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op)
|
||||
: a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
|
||||
{
|
||||
grid_size_ = 0;
|
||||
|
||||
k_batch_ = 1;
|
||||
|
||||
grouped_gemm_kernel_args_dev = nullptr;
|
||||
|
||||
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
|
||||
|
||||
gemm_desc_kernel_arg_.reserve(group_count_);
|
||||
|
||||
index_t group_id = 0;
|
||||
|
||||
sum_of_m = gemm_descs[0].M_;
|
||||
const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_);
|
||||
const index_t N = gemm_descs[0].N_;
|
||||
const index_t K = gemm_descs[0].K_;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_)
|
||||
{
|
||||
throw std::runtime_error("wrong! M/N/K is not identical");
|
||||
}
|
||||
|
||||
a_mtx_mraw_kraw_.emplace_back(sum_of_m, K);
|
||||
b_mtx_nraw_kraw_.emplace_back(N, K);
|
||||
|
||||
const index_t StrideA = gemm_descs[i].stride_A_;
|
||||
const index_t StrideB = gemm_descs[i].stride_B_;
|
||||
const index_t StrideE = gemm_descs[i].stride_C_;
|
||||
|
||||
// pointer
|
||||
std::array<const void*, NumDTensor> p_ds_grid;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; });
|
||||
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
// using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
|
||||
|
||||
if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
|
||||
}
|
||||
|
||||
StrideDs[j] = gemm_descs[i].stride_Ds_[j];
|
||||
});
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
|
||||
AverM, N, StrideE);
|
||||
|
||||
// block-to-e-tile map
|
||||
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};
|
||||
|
||||
grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
|
||||
|
||||
if(group_id * grid_size_grp_ != grid_size_)
|
||||
{
|
||||
throw std::runtime_error("wrong! grid_size_grp_ is not identical!");
|
||||
}
|
||||
|
||||
grid_size_ += grid_size_grp_;
|
||||
|
||||
// check block-to-E-tile
|
||||
if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n))
|
||||
{
|
||||
throw std::runtime_error("wrong! block_2_etile_map validation failed");
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::
|
||||
template CheckValidity<ALayout, BLayout, DsLayout, ELayout, GemmSpec>(
|
||||
AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
|
||||
}
|
||||
|
||||
gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
|
||||
nullptr,
|
||||
nullptr,
|
||||
p_ds_grid,
|
||||
nullptr,
|
||||
AverM,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
});
|
||||
|
||||
group_id++;
|
||||
}
|
||||
|
||||
const auto e_grid_desc_sum_m_n =
|
||||
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
|
||||
sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_);
|
||||
|
||||
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1};
|
||||
|
||||
barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n);
|
||||
}
|
||||
|
||||
// private:
|
||||
index_t group_count_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CDEElementwiseOperation c_element_op_;
|
||||
|
||||
std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
|
||||
std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
|
||||
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
|
||||
|
||||
const void* grouped_gemm_kernel_args_dev;
|
||||
|
||||
index_t grid_size_;
|
||||
index_t grid_size_grp_;
|
||||
index_t barrier_size_grp_;
|
||||
index_t sum_of_m;
|
||||
|
||||
index_t k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
bool has_main_k_block_loop = true;
|
||||
|
||||
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
||||
{
|
||||
const auto KPad =
|
||||
GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, arg.k_batch_);
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop)
|
||||
{
|
||||
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
|
||||
}
|
||||
}
|
||||
|
||||
if(arg.grouped_gemm_kernel_args_dev == nullptr)
|
||||
{
|
||||
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
|
||||
GroupedGemmKernelArgument<NumDTensor>,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
DsDataType,
|
||||
Block2ETileMap,
|
||||
GroupedGemmBlock2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
e_global_memory_operation_,
|
||||
has_main_k_block_loop_>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
|
||||
reinterpret_cast<uint32_t*>(arg.p_workspace_),
|
||||
arg.barrier_size_grp_,
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.grid_size_grp_,
|
||||
arg.k_batch_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
};
|
||||
|
||||
constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
|
||||
constexpr auto Set = InMemoryDataOperationEnum::Set;
|
||||
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
ave_time =
|
||||
launch_kernel(integral_constant<bool, true>{},
|
||||
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
launch_kernel(integral_constant<bool, false>{},
|
||||
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
ave_time = launch_kernel(integral_constant<bool, true>{},
|
||||
integral_constant<InMemoryDataOperationEnum, Set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = launch_kernel(integral_constant<bool, false>{},
|
||||
integral_constant<InMemoryDataOperationEnum, Set>{});
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
bool supported = true;
|
||||
|
||||
// If we use padding we do not support vector loads for dimensions not divisible by vector
|
||||
// load size.
|
||||
if constexpr(GemmSpec != GemmSpecialization::Default)
|
||||
{
|
||||
// [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
|
||||
// thus we have to adapt it to the {M,K} or {N,K} layout.
|
||||
const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
|
||||
const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
|
||||
|
||||
for(index_t i = 0; i < arg.group_count_; ++i)
|
||||
{
|
||||
const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
|
||||
const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
|
||||
|
||||
supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
|
||||
supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
|
||||
}
|
||||
}
|
||||
|
||||
return supported;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc> gemm_descs,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{
|
||||
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(
|
||||
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedGemm_Xdl_Fixed_NK"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< MPerXDL << ", "
|
||||
<< NPerXDL << ", "
|
||||
<< MXdlPerWave << ", "
|
||||
<< NXdlPerWave << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle << ", "
|
||||
<< getGemmSpecializationString(GemmSpec)
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
|
||||
{
|
||||
arg.grouped_gemm_kernel_args_dev = kernel_args;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override
|
||||
{
|
||||
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), kernel_args);
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = *dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t);
|
||||
}
|
||||
|
||||
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = *dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
return arg.group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override
|
||||
{
|
||||
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
|
||||
p_arg_->p_workspace_ = p_workspace;
|
||||
|
||||
hip_check_error(hipMemset(p_workspace, 0, GetWorkSpaceSize(p_arg)));
|
||||
}
|
||||
|
||||
static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
|
||||
|
||||
// polymorphic
|
||||
void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
|
||||
{
|
||||
return SetKBatch(*dynamic_cast<Argument*>(p_arg), k_batch);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -36,6 +36,13 @@ struct Add
|
||||
y = x0 + type_convert<half_t>(x1);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const float& x0, const float& x1) const
|
||||
{
|
||||
y = type_convert<half_t>(x0 + x1);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
|
||||
|
||||
@@ -587,7 +587,8 @@ struct OffsettedBlockToCTileMap
|
||||
{
|
||||
using underlying_type = UnderlyingBlockToCTileMap;
|
||||
|
||||
OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
|
||||
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
|
||||
index_t block_start)
|
||||
{
|
||||
block_to_ctile_map_ = block_to_ctile_map;
|
||||
block_start_ = block_start;
|
||||
|
||||
@@ -15,6 +15,9 @@
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GEMM:
|
||||
@@ -74,6 +77,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
@@ -330,6 +335,94 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
|
||||
using DsGridPointer = decltype(MakeDsGridPointer());
|
||||
|
||||
template <typename ALayout, GemmSpecialization GemmSpec>
|
||||
__host__ __device__ static auto
|
||||
MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
constexpr auto matrix_padder =
|
||||
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
|
||||
MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(StrideA, I1));
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(I1, StrideA));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
|
||||
}
|
||||
|
||||
template <typename BLayout, GemmSpecialization GemmSpec>
|
||||
__host__ __device__ static auto
|
||||
MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
constexpr auto matrix_padder =
|
||||
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
|
||||
MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
|
||||
}
|
||||
|
||||
template <typename ELayout, GemmSpecialization GemmSpec>
|
||||
__host__ __device__ static auto
|
||||
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
|
||||
{
|
||||
constexpr auto matrix_padder =
|
||||
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
|
||||
MPerBlock, NPerBlock, KPerBlock};
|
||||
const auto e_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(StrideE, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(I1, StrideE));
|
||||
}
|
||||
}();
|
||||
|
||||
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
template <typename DsLayout, GemmSpecialization GemmSpec>
|
||||
__host__ __device__ static auto
|
||||
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
|
||||
const std::array<index_t, NumDTensor>& NRaws,
|
||||
const std::array<index_t, NumDTensor>& DsStride)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
return MakeEGridDescriptor_M_N<DLayout, GemmSpec>(MRaws[i], NRaws[i], DsStride[i]);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
}
|
||||
|
||||
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
@@ -758,6 +851,85 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
GemmSpecialization GemmSpec,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename Block2ETileMap>
|
||||
__device__ static void Run(const void* __restrict__ p_a_grid_,
|
||||
const void* __restrict__ p_b_grid_,
|
||||
DsGridPointer p_ds_grid,
|
||||
void* __restrict__ p_e_grid_,
|
||||
void* __restrict__ p_shared,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t StrideA,
|
||||
const index_t StrideB,
|
||||
const std::array<index_t, NumDTensor> StrideDs,
|
||||
const index_t StrideE,
|
||||
const Block2ETileMap& block_2_etile_map)
|
||||
{
|
||||
const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
|
||||
const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
|
||||
const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
|
||||
const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
|
||||
|
||||
using DsGridDesc_M_N =
|
||||
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
|
||||
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
|
||||
|
||||
ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
|
||||
});
|
||||
|
||||
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
|
||||
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
|
||||
});
|
||||
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
|
||||
|
||||
Run<HasMainKBlockLoop>(p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -101,6 +101,7 @@ using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
|
||||
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
|
||||
using Gelu = ck::tensor_operation::element_wise::Gelu;
|
||||
using Swish = ck::tensor_operation::element_wise::Swish;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
template <typename Activation>
|
||||
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// fp16_output
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F32_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F32_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>>& instances);
|
||||
|
||||
// fp32_output
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F32_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F32_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>>& instances);
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGroupedGemmFixedNK<ALayout,
|
||||
BLayout,
|
||||
Row_Tuple,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
F32_Tuple,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedGemmFixedNK<ALayout,
|
||||
BLayout,
|
||||
Row_Tuple,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
F32_Tuple,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
// fp16_output
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
// fp32_output
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<EDataType, float>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -26,7 +26,9 @@ struct DeviceMem
|
||||
void* GetDeviceBuffer() const;
|
||||
std::size_t GetBufferSize() const;
|
||||
void ToDevice(const void* p) const;
|
||||
void ToDevice(const void* p, const std::size_t cpySize) const;
|
||||
void FromDevice(void* p) const;
|
||||
void FromDevice(void* p, const std::size_t cpySize) const;
|
||||
void SetZero() const;
|
||||
template <typename T>
|
||||
void SetValue(T x) const;
|
||||
|
||||
@@ -83,7 +83,6 @@ using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instanc
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2>,
|
||||
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2>
|
||||
// clang-format on
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
add_instance_library(device_grouped_gemm_bias_instance
|
||||
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
|
||||
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp
|
||||
)
|
||||
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using D0DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType>;
|
||||
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<D0Layout>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_irregular_tile_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_irregular_tile_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using D0DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType>;
|
||||
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<D0Layout>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_irregular_tile_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
|
||||
Col,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_irregular_tile_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,83 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using D0DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType>;
|
||||
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<D0Layout>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_irregular_tile_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 32, 128, 32, 8, 8, 32, 32, 1, 1, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 32, 256, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S< 1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_irregular_tile_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,83 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using D0DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType>;
|
||||
|
||||
using D0Layout = Row;
|
||||
using DsLayout = ck::Tuple<D0Layout>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_irregular_tile_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 32, 128, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F32, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
|
||||
Col,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_irregular_tile_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -37,6 +37,11 @@ void DeviceMem::ToDevice(const void* p) const
|
||||
}
|
||||
}
|
||||
|
||||
void DeviceMem::ToDevice(const void* p, const std::size_t cpySize) const
|
||||
{
|
||||
hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
void DeviceMem::FromDevice(void* p) const
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
@@ -49,6 +54,11 @@ void DeviceMem::FromDevice(void* p) const
|
||||
}
|
||||
}
|
||||
|
||||
void DeviceMem::FromDevice(void* p, const std::size_t cpySize) const
|
||||
{
|
||||
hip_check_error(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
void DeviceMem::SetZero() const
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
|
||||
@@ -16,4 +16,3 @@ cmake
|
||||
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
|
||||
-D USE_BITINT_EXTENSION_INT4=OFF \
|
||||
${MY_PROJECT_SOURCE}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user