mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Merge remote-tracking branch 'origin/develop' into samremes/ck_tile_mx_gemm
This commit is contained in:
@@ -119,7 +119,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
|
||||
@@ -119,7 +119,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
|
||||
@@ -131,6 +131,9 @@ template <ck::index_t NDimSpatial,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp,
|
||||
typename DeviceConvNDFwdInstance,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename ComputeDataType = OutDataType>
|
||||
bool run_grouped_conv_fwd(int do_verification,
|
||||
int init_method,
|
||||
@@ -283,31 +286,25 @@ bool run_grouped_conv_fwd(int do_verification,
|
||||
DeviceMem out_device_ref_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
|
||||
out_device_ref_buf.SetZero();
|
||||
|
||||
// Extract dimensions using helper function
|
||||
ck::ref::ConvDims dims = ck::utils::conv::extract_conv_dims(conv_param, NDimSpatial);
|
||||
|
||||
// Launch GPU reference kernel
|
||||
constexpr ck::index_t block_size = 256;
|
||||
const ck::long_index_t output_length = dims.N * dims.Do * dims.Ho * dims.Wo * dims.K;
|
||||
const ck::index_t grid_size = (output_length + block_size - 1) / block_size;
|
||||
|
||||
auto gpu_ref_kernel = ck::ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ComputeDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
gpu_ref_kernel<<<dim3(grid_size), dim3(block_size), 0, nullptr>>>(
|
||||
// Call GPU reference with ConvParam directly, using the correct layout types
|
||||
ck::ref::naive_conv_fwd<InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>(
|
||||
reinterpret_cast<const InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<OutDataType*>(out_device_ref_buf.GetDeviceBuffer()),
|
||||
dims);
|
||||
conv_param);
|
||||
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
std::cout << "GPU reference kernel completed successfully, copying results..." << std::endl;
|
||||
std::cout << "GPU reference function completed successfully, copying results..."
|
||||
<< std::endl;
|
||||
|
||||
// Copy GPU reference result to host
|
||||
out_device_ref_buf.FromDevice(out_host.mData.data());
|
||||
|
||||
@@ -12,7 +12,7 @@ bool run_convnd_fwd_example(int argc, char* argv[])
|
||||
{
|
||||
print_helper_msg();
|
||||
|
||||
int do_verification = 1; // 0=no, 1=CPU, 2=GPU
|
||||
int do_verification = 2; // 0=no, 1=CPU, 2=GPU
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
@@ -71,6 +71,9 @@ bool run_convnd_fwd_example(int argc, char* argv[])
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<ndim_spatial_value, InLayout, WeiLayout, OutLayout>,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
ComputeDataType>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
|
||||
@@ -31,7 +31,7 @@ class SimpleAppArgs
|
||||
bool do_verification = true;
|
||||
int data_type = 1;
|
||||
int init_method = 2;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
public:
|
||||
void show_usage(const char* cmd)
|
||||
|
||||
@@ -31,7 +31,7 @@ class SimpleAppArgs
|
||||
bool do_verification = true;
|
||||
int data_type = 1;
|
||||
int init_method = 2;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
public:
|
||||
void show_usage(const char* cmd)
|
||||
|
||||
@@ -31,7 +31,7 @@ class SimpleAppArgs
|
||||
bool do_verification = true;
|
||||
int data_type = 1;
|
||||
int init_method = 2;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
public:
|
||||
void show_usage(const char* cmd)
|
||||
|
||||
@@ -53,7 +53,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
do_verification = true;
|
||||
init_method = 1;
|
||||
time_kernel = true;
|
||||
time_kernel = false;
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
|
||||
@@ -44,6 +44,9 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl
|
||||
add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16)
|
||||
|
||||
add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16)
|
||||
|
||||
list(APPEND gpu_list_tf32 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
|
||||
@@ -90,7 +90,7 @@ struct ExecutionConfig final
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
int k_batch = 128;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/tuple.hpp>
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::hip_check_error;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddAdd = ck::tensor_operation::element_wise::AddAdd;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = F16;
|
||||
using DsDataType = ck::Tuple<DDataType, DDataType>;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DLayout = Row;
|
||||
using DsLayout = ck::Tuple<DLayout, DLayout>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AddAdd;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
static constexpr int NumDs = 2;
|
||||
|
||||
using DeviceGemmInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<4, 4, 4>>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_gemm_multiple_d_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
@@ -71,339 +71,6 @@ using DeviceGemmInstance =
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>;
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
#include "run_grouped_gemm_multiple_d_example.inc"
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<std::vector<ck::index_t>> stride_Ds;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
|
||||
using GemmDesc = ck::tensor_operation::device::GemmDesc;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<GemmDesc> gemm_descs;
|
||||
std::vector<KernelArguments> ggemm_kargs;
|
||||
std::vector<void*> p_Cs;
|
||||
std::vector<const void*> p_As;
|
||||
std::vector<const void*> p_Bs;
|
||||
std::vector<std::array<const void*, NumDs>> p_Ds = {};
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
ggemm_kargs.reserve(group_count);
|
||||
p_As.reserve(group_count);
|
||||
p_Bs.reserve(group_count);
|
||||
p_Ds.reserve(group_count);
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Tensor<ADataType>> a_tensors;
|
||||
std::vector<Tensor<BDataType>> b_tensors;
|
||||
std::vector<std::array<Tensor<DDataType>, NumDs>> d_tensors;
|
||||
std::vector<Tensor<EDataType>> c_host_tensors;
|
||||
std::vector<Tensor<EDataType>> c_device_result_tensors;
|
||||
|
||||
a_tensors.reserve(group_count);
|
||||
b_tensors.reserve(group_count);
|
||||
d_tensors.reserve(group_count);
|
||||
c_host_tensors.reserve(group_count);
|
||||
c_device_result_tensors.reserve(group_count);
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
|
||||
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
|
||||
|
||||
a_tensors_device.reserve(group_count);
|
||||
b_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
d_tensors_device.resize(group_count); // reserve and update vector size
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
|
||||
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
|
||||
|
||||
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
|
||||
std::array<Tensor<DDataType>, NumDs> d_tens = {d0_tensor, d1_tensor};
|
||||
d_tensors.push_back(d_tens);
|
||||
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
|
||||
|
||||
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
|
||||
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
|
||||
sizeof(BDataType) * b_tensors[i].GetElementSize() +
|
||||
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs +
|
||||
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
}
|
||||
break;
|
||||
case 2:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
|
||||
b_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
|
||||
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
|
||||
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
|
||||
}
|
||||
|
||||
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
|
||||
}
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Ds.push_back(
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
|
||||
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
|
||||
|
||||
// The device op does not have to know M problem size at lunch time.
|
||||
gemm_descs.push_back({0,
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
problem_size.stride_Cs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]}});
|
||||
ggemm_kargs.push_back(
|
||||
{a_tensors_device[i]->GetDeviceBuffer(),
|
||||
b_tensors_device[i]->GetDeviceBuffer(),
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
|
||||
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
|
||||
ggemm_kargs.data(),
|
||||
gemm.GetDeviceKernelArgSize(&argument),
|
||||
hipMemcpyHostToDevice));
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
auto karg = ggemm_kargs[i];
|
||||
auto dev_res_tensor =
|
||||
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{}));
|
||||
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data());
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
|
||||
b_tensors[i],
|
||||
d_tensors[i],
|
||||
c_host_tensors[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
|
||||
}
|
||||
|
||||
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::vector<int> argToIntArray(char* input)
|
||||
{
|
||||
std::vector<int> out;
|
||||
std::istringstream in(input);
|
||||
std::string item;
|
||||
|
||||
while(std::getline(in, item, ','))
|
||||
{
|
||||
out.push_back(std::stoi(item));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
if(argc < 10)
|
||||
{
|
||||
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
|
||||
problem_size.group_count = Ms.size();
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(Ms[i]);
|
||||
problem_size.Ns.push_back(252);
|
||||
problem_size.Ks.push_back(4608);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
|
||||
problem_size.stride_Ds.push_back({});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "Usage:\n"
|
||||
<< "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
|
||||
"64,64 64,64 128,128)\n"
|
||||
<< "... setting default values." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.Ms = argToIntArray(argv[4]);
|
||||
problem_size.Ns = argToIntArray(argv[5]);
|
||||
problem_size.Ks = argToIntArray(argv[6]);
|
||||
|
||||
problem_size.stride_As = argToIntArray(argv[7]);
|
||||
problem_size.stride_Bs = argToIntArray(argv[8]);
|
||||
problem_size.stride_Cs = argToIntArray(argv[9]);
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
|
||||
}
|
||||
|
||||
problem_size.group_count = problem_size.Ms.size();
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
|
||||
@@ -58,11 +58,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -57,11 +57,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -323,8 +323,8 @@ bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: async hargs (0=n0, 1=yes)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4: async hargs (0=no, 1=yes)\n");
|
||||
printf("arg5: group count (default=16)\n");
|
||||
#if defined(EXAMPLE_USE_SPLITK)
|
||||
printf("arg6: k-batch count (default=1)\n");
|
||||
|
||||
341
example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc
Normal file
341
example/15_grouped_gemm/run_grouped_gemm_multiple_d_example.inc
Normal file
@@ -0,0 +1,341 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
std::vector<ck::index_t> Ms;
|
||||
std::vector<ck::index_t> Ns;
|
||||
std::vector<ck::index_t> Ks;
|
||||
|
||||
std::vector<ck::index_t> stride_As;
|
||||
std::vector<ck::index_t> stride_Bs;
|
||||
std::vector<std::vector<ck::index_t>> stride_Ds;
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
auto group_count = problem_size.group_count;
|
||||
|
||||
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
|
||||
using GemmDesc = ck::tensor_operation::device::GemmDesc;
|
||||
|
||||
// GEMM shape
|
||||
std::vector<GemmDesc> gemm_descs;
|
||||
std::vector<KernelArguments> ggemm_kargs;
|
||||
std::vector<void*> p_Cs;
|
||||
std::vector<const void*> p_As;
|
||||
std::vector<const void*> p_Bs;
|
||||
std::vector<std::array<const void*, NumDs>> p_Ds = {};
|
||||
|
||||
gemm_descs.reserve(group_count);
|
||||
ggemm_kargs.reserve(group_count);
|
||||
p_As.reserve(group_count);
|
||||
p_Bs.reserve(group_count);
|
||||
p_Ds.reserve(group_count);
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<Tensor<ADataType>> a_tensors;
|
||||
std::vector<Tensor<BDataType>> b_tensors;
|
||||
std::vector<std::array<Tensor<DDataType>, NumDs>> d_tensors;
|
||||
std::vector<Tensor<EDataType>> c_host_tensors;
|
||||
std::vector<Tensor<EDataType>> c_device_result_tensors;
|
||||
|
||||
a_tensors.reserve(group_count);
|
||||
b_tensors.reserve(group_count);
|
||||
d_tensors.reserve(group_count);
|
||||
c_host_tensors.reserve(group_count);
|
||||
c_device_result_tensors.reserve(group_count);
|
||||
|
||||
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
|
||||
|
||||
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
|
||||
std::vector<std::vector<DeviceMemPtr>> d_tensors_device;
|
||||
|
||||
a_tensors_device.reserve(group_count);
|
||||
b_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
d_tensors_device.resize(group_count); // reserve and update vector size
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
|
||||
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
|
||||
|
||||
auto d0_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
auto d1_tensor = Tensor<DDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{}));
|
||||
|
||||
std::array<Tensor<DDataType>, NumDs> d_tens = {d0_tensor, d1_tensor};
|
||||
d_tensors.push_back(d_tens);
|
||||
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
c_device_result_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
|
||||
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
|
||||
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl;
|
||||
|
||||
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
|
||||
num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() +
|
||||
sizeof(BDataType) * b_tensors[i].GetElementSize() +
|
||||
sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDs +
|
||||
sizeof(EDataType) * c_device_result_tensors[i].GetElementSize();
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
}
|
||||
break;
|
||||
case 2:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3<DDataType>{0.0, 1.0});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
|
||||
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<DDataType, 0>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
a_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType)));
|
||||
b_tensors_device.emplace_back(
|
||||
std::make_unique<DeviceMem>(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType)));
|
||||
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
|
||||
c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType)));
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i].emplace_back(std::make_unique<DeviceMem>(
|
||||
d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType)));
|
||||
}
|
||||
|
||||
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
|
||||
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data());
|
||||
}
|
||||
c_tensors_device[i]->SetZero();
|
||||
|
||||
p_As.push_back(a_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer());
|
||||
p_Ds.push_back(
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()});
|
||||
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
|
||||
|
||||
// The device op does not have to know M problem size at lunch time.
|
||||
gemm_descs.push_back({0,
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
problem_size.stride_Cs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]}});
|
||||
ggemm_kargs.push_back(
|
||||
{a_tensors_device[i]->GetDeviceBuffer(),
|
||||
b_tensors_device[i]->GetDeviceBuffer(),
|
||||
{d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()},
|
||||
c_tensors_device[i]->GetDeviceBuffer(),
|
||||
problem_size.Ms[i],
|
||||
problem_size.Ns[i],
|
||||
problem_size.Ks[i],
|
||||
problem_size.stride_As[i],
|
||||
problem_size.stride_Bs[i],
|
||||
{problem_size.stride_Cs[i], problem_size.stride_Cs[i]},
|
||||
problem_size.stride_Cs[i]});
|
||||
}
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
// do GEMM
|
||||
auto argument = gemm.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
|
||||
hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
|
||||
ggemm_kargs.data(),
|
||||
gemm.GetDeviceKernelArgSize(&argument),
|
||||
hipMemcpyHostToDevice));
|
||||
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceGemmMultipleD<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); i++)
|
||||
{
|
||||
auto karg = ggemm_kargs[i];
|
||||
auto dev_res_tensor =
|
||||
Tensor<float>(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideE, ELayout{}));
|
||||
c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data());
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
|
||||
b_tensors[i],
|
||||
d_tensors[i],
|
||||
c_host_tensors[i],
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]);
|
||||
}
|
||||
|
||||
std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::vector<int> argToIntArray(char* input)
|
||||
{
|
||||
std::vector<int> out;
|
||||
std::istringstream in(input);
|
||||
std::string item;
|
||||
|
||||
while(std::getline(in, item, ','))
|
||||
{
|
||||
out.push_back(std::stoi(item));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
if(argc < 10)
|
||||
{
|
||||
std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
|
||||
problem_size.group_count = Ms.size();
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(Ms[i]);
|
||||
problem_size.Ns.push_back(252);
|
||||
problem_size.Ks.push_back(4608);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
|
||||
problem_size.stride_Ds.push_back({});
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds[i].push_back(problem_size.Ns[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "Usage:\n"
|
||||
<< "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: time kernel (0=n0, 1=yes)\n"
|
||||
<< "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
|
||||
"64,64 64,64 128,128)\n"
|
||||
<< "... setting default values." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.Ms = argToIntArray(argv[4]);
|
||||
problem_size.Ns = argToIntArray(argv[5]);
|
||||
problem_size.Ks = argToIntArray(argv[6]);
|
||||
|
||||
problem_size.stride_As = argToIntArray(argv[7]);
|
||||
problem_size.stride_Bs = argToIntArray(argv[8]);
|
||||
problem_size.stride_Cs = argToIntArray(argv[9]);
|
||||
|
||||
for(int j = 0; j < NumDs; ++j)
|
||||
{
|
||||
problem_size.stride_Ds.push_back(problem_size.stride_Cs);
|
||||
}
|
||||
|
||||
problem_size.group_count = problem_size.Ms.size();
|
||||
}
|
||||
|
||||
return run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
@@ -268,7 +268,7 @@ int main()
|
||||
pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2);
|
||||
}
|
||||
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
if(time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
@@ -302,7 +302,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -106,7 +106,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -108,7 +108,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -105,7 +105,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -112,7 +112,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -112,7 +112,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -112,7 +112,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -18,7 +18,8 @@
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
@@ -81,7 +82,10 @@ template <ck::index_t NDimSpatial,
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp,
|
||||
typename DeviceConvNdBwdDataInstance>
|
||||
typename DeviceConvNdBwdDataInstance,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
int run_conv_bwd_data(int do_verification,
|
||||
int init_method,
|
||||
bool time_kernel,
|
||||
@@ -225,50 +229,52 @@ int run_conv_bwd_data(int do_verification,
|
||||
}
|
||||
else if(do_verification == 2)
|
||||
{
|
||||
// GPU verification
|
||||
// GPU verification using naive GPU reference
|
||||
std::cout << "Running GPU verification..." << std::endl;
|
||||
|
||||
// Allocate and ZERO GPU memory for reference input
|
||||
DeviceMem in_device_ref_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize());
|
||||
in_device_ref_buf.SetZero();
|
||||
|
||||
// Extract dimensions using helper function
|
||||
ck::ref::ConvDims dims = ck::utils::conv::extract_conv_dims(conv_param, NDimSpatial);
|
||||
|
||||
constexpr ck::index_t block_size = 256;
|
||||
const ck::long_index_t input_length = dims.N * dims.Di * dims.Hi * dims.Wi * dims.C;
|
||||
const ck::index_t grid_size = (input_length + block_size - 1) / block_size;
|
||||
|
||||
auto gpu_ref_kernel = ck::ref::naive_conv_bwd_data_ndhwc_kzyxc_ndhwk<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
float,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
gpu_ref_kernel<<<dim3(grid_size), dim3(block_size), 0, nullptr>>>(
|
||||
// Call GPU reference with ConvParam directly, using the correct layout types
|
||||
ck::ref::naive_conv_bwd_data<InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>(
|
||||
reinterpret_cast<InDataType*>(in_device_ref_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<const OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
dims);
|
||||
conv_param,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
std::cout << "GPU reference kernel completed, copying results..." << std::endl;
|
||||
std::cout << "GPU reference function completed successfully, copying results..."
|
||||
<< std::endl;
|
||||
|
||||
// Copy GPU reference result
|
||||
// Copy GPU reference result to host
|
||||
Tensor<InDataType> in_gpu_ref(in_host.mDesc);
|
||||
in_device_ref_buf.FromDevice(in_gpu_ref.mData.data());
|
||||
|
||||
// Copy optimized kernel result
|
||||
// Copy GPU kernel result to host
|
||||
in_device_buf.FromDevice(in_device.mData.data());
|
||||
|
||||
std::cout << "Comparing GPU kernel output vs GPU reference..." << std::endl;
|
||||
|
||||
// Compare: Optimized kernel result vs GPU reference result
|
||||
bool pass = ck::utils::check_err(in_device,
|
||||
in_gpu_ref,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<InDataType, float>(),
|
||||
get_atol<InDataType, float>());
|
||||
|
||||
std::cout << "GPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
return pass ? 0 : 1;
|
||||
|
||||
@@ -92,16 +92,19 @@ int main(int argc, char* argv[])
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceConvNdBwdDataInstance<1>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
DeviceConvNdBwdDataInstance<1>,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 2)
|
||||
{
|
||||
@@ -128,16 +131,19 @@ int main(int argc, char* argv[])
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceConvNdBwdDataInstance<2>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
DeviceConvNdBwdDataInstance<2>,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 3)
|
||||
{
|
||||
@@ -164,16 +170,19 @@ int main(int argc, char* argv[])
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceConvNdBwdDataInstance<3>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
DeviceConvNdBwdDataInstance<3>,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -119,16 +119,19 @@ int main(int argc, char* argv[])
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceConvNdBwdDataInstance<1>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
DeviceConvNdBwdDataInstance<1>,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 2)
|
||||
{
|
||||
@@ -155,16 +158,19 @@ int main(int argc, char* argv[])
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceConvNdBwdDataInstance<2>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
DeviceConvNdBwdDataInstance<2>,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 3)
|
||||
{
|
||||
@@ -191,16 +197,19 @@ int main(int argc, char* argv[])
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceConvNdBwdDataInstance<3>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
DeviceConvNdBwdDataInstance<3>,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -11,8 +11,11 @@ add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bw
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_v3_wmma_fp16 grouped_conv_bwd_weight_v3_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_v3_wmma_fp16)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_weight_v3_wmma_bf16 grouped_conv_bwd_weight_v3_wmma_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_v3_wmma_bf16)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16)
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using InDataType = BF16;
|
||||
// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
|
||||
using WeiDataType = F32;
|
||||
using OutDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvBwdWeightInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3<
|
||||
NDimSpatial,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::NHWGC,
|
||||
ck::tensor_layout::convolution::NDHWGC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::NHWGK,
|
||||
ck::tensor_layout::convolution::NDHWGK>>,
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
AccDataType, // AccDataType
|
||||
InElementOp, // InElementwiseOperation
|
||||
WeiElementOp, // WeiElementwiseOperation
|
||||
OutElementOp, // OutElementwiseOperation
|
||||
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // K1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
1, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
2, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
1, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
2, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 32, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
#include "run_grouped_conv_bwd_weight_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExecutionConfig config;
|
||||
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
|
||||
|
||||
if(!parse_cmd_args(argc, argv, config, conv_param))
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
switch(conv_param.num_dim_spatial_)
|
||||
{
|
||||
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
|
||||
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
|
||||
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
|
||||
default: break;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using InDataType = F16;
|
||||
using WeiDataType = F16;
|
||||
@@ -16,11 +16,20 @@ using OutElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvBwdWeightInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3<
|
||||
NDimSpatial,
|
||||
ck::tensor_layout::convolution::GNDHWC,
|
||||
ck::tensor_layout::convolution::GKZYXC,
|
||||
ck::tensor_layout::convolution::GNDHWK,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::NHWGC,
|
||||
ck::tensor_layout::convolution::NDHWGC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::NHWGK,
|
||||
ck::tensor_layout::convolution::NDHWGK>>,
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
@@ -32,30 +41,30 @@ using DeviceConvBwdWeightInstance =
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
32, // KPerBlock
|
||||
8, // K1
|
||||
16, // MPerWMMA
|
||||
16, // NPerWMMA
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1>, // ABlockTransferSrcAccessOrder
|
||||
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
1, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
true, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1>, // BBlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferDstScalarPerVector_K1
|
||||
false, // ABlockLdsAddExtraM
|
||||
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
1, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
true, // BBlockLdsExtraN
|
||||
4,
|
||||
2,
|
||||
S<1, 32, 1, 8>,
|
||||
1>;
|
||||
2, // BBlockTransferDstScalarPerVector_K1
|
||||
false, // BBlockLdsAddExtraN
|
||||
1, // CShuffleMRepeatPerShuffle
|
||||
1, // CShuffleNRepeatPerShuffle
|
||||
S<1, 32, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
@@ -80,6 +89,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
switch(conv_param.num_dim_spatial_)
|
||||
{
|
||||
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
|
||||
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
|
||||
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
|
||||
default: break;
|
||||
}
|
||||
@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial>
|
||||
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
const ck::utils::conv::ConvParam& conv_param)
|
||||
{
|
||||
// Dl and WMMA ops don't support split_k > 1
|
||||
// Dl ops don't support split_k > 1
|
||||
constexpr ck::index_t split_k = 1;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
@@ -131,59 +131,71 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
|
||||
wei_device_buf.FromDevice(wei_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData);
|
||||
float max_accumulated_value =
|
||||
*std::max_element(wei_host_result.mData.begin(), wei_host_result.mData.end());
|
||||
|
||||
const ck::index_t num_accums = out.GetElementSize() / conv_param.K_;
|
||||
const ck::index_t num_accums_split_k = split_k;
|
||||
double rtol = ck::utils::get_relative_threshold<InDataType, WeiDataType, AccDataType>(
|
||||
num_accums / num_accums_split_k);
|
||||
double atol = ck::utils::get_absolute_threshold<InDataType, WeiDataType, AccDataType>(
|
||||
max_accumulated_value / num_accums_split_k, num_accums / num_accums_split_k);
|
||||
|
||||
return ck::utils::check_err(wei_device_result.mData,
|
||||
wei_host_result.mData,
|
||||
"Error: Incorrect results!",
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
else if(config.do_verification == 2)
|
||||
{
|
||||
// GPU verification (only supports G=1, standard convolution)
|
||||
if(conv_param.G_ != 1)
|
||||
{
|
||||
std::cout << "GPU verification only supports G=1 (standard convolution)" << std::endl;
|
||||
std::cout << "Current G=" << conv_param.G_ << " not supported." << std::endl;
|
||||
std::cout << "Use do_verification=1 for CPU verification with grouped convolution."
|
||||
<< std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::cout << "Running GPU verification (G=1)..." << std::endl;
|
||||
// GPU verification using naive GPU reference
|
||||
std::cout << "Running GPU verification..." << std::endl;
|
||||
|
||||
// Allocate and ZERO GPU memory for reference weights
|
||||
DeviceMem wei_device_ref_buf(sizeof(WeiDataType) *
|
||||
wei_device_result.mDesc.GetElementSpaceSize());
|
||||
wei_device_ref_buf.SetZero();
|
||||
|
||||
// Extract dimensions using helper function (G=1, standard convolution)
|
||||
ck::ref::ConvDims dims = ck::utils::conv::extract_conv_dims(conv_param, NDimSpatial, false);
|
||||
// Call GPU reference function with ConvParam and layout types
|
||||
using InLayout = InputLayout<NDimSpatial>;
|
||||
using WeiLayout = WeightLayout<NDimSpatial>;
|
||||
using OutLayout = OutputLayout<NDimSpatial>;
|
||||
|
||||
constexpr ck::index_t block_size = 256;
|
||||
const ck::long_index_t weight_length = dims.K * dims.Z * dims.Y * dims.X * dims.C;
|
||||
const ck::index_t grid_size = (weight_length + block_size - 1) / block_size;
|
||||
|
||||
auto gpu_ref_kernel = ck::ref::naive_conv_bwd_weight_ndhwc_kzyxc_ndhwk<InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
float,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
gpu_ref_kernel<<<dim3(grid_size), dim3(block_size), 0, nullptr>>>(
|
||||
ck::ref::naive_conv_bwd_weight<InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>(
|
||||
reinterpret_cast<const InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<WeiDataType*>(wei_device_ref_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<const OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
dims);
|
||||
conv_param);
|
||||
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
std::cout << "GPU reference kernel completed, copying results..." << std::endl;
|
||||
std::cout << "GPU reference function completed successfully, copying results..."
|
||||
<< std::endl;
|
||||
|
||||
// Copy GPU reference result to host
|
||||
wei_device_ref_buf.FromDevice(wei_host_result.mData.data());
|
||||
|
||||
// Copy GPU kernel result to host
|
||||
wei_device_buf.FromDevice(wei_device_result.mData.data());
|
||||
|
||||
std::cout << "Comparing GPU kernel output vs GPU reference..." << std::endl;
|
||||
|
||||
// Compare: Optimized kernel result vs GPU reference result
|
||||
bool pass = ck::utils::check_err(wei_device_result.mData,
|
||||
wei_host_result.mData,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<WeiDataType, float>(),
|
||||
get_atol<WeiDataType, float>());
|
||||
|
||||
std::cout << "GPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
return pass;
|
||||
|
||||
@@ -81,7 +81,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// CGEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -65,7 +65,7 @@ class SimpleAppArgs
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 2;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
public:
|
||||
void show_usage(const char* cmd)
|
||||
|
||||
@@ -27,7 +27,7 @@ struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
|
||||
@@ -248,7 +248,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::index_t G0 = 1;
|
||||
ck::index_t G1 = 2;
|
||||
|
||||
@@ -92,7 +92,7 @@ struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
#define DefaultConvParam \
|
||||
|
||||
@@ -92,7 +92,7 @@ struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
#define DefaultConvParam \
|
||||
|
||||
@@ -40,7 +40,7 @@ class SimpleAppArgs
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 2;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
public:
|
||||
SimpleAppArgs()
|
||||
|
||||
@@ -44,7 +44,7 @@ struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 2;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
template <ck::index_t... Is>
|
||||
|
||||
@@ -56,7 +56,7 @@ template<> struct emb_kernel<ck::half_t, 8192> { using kernel_type = DeviceInsta
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::index_t num_rows = 65536;
|
||||
constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{};
|
||||
|
||||
@@ -195,7 +195,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -9,8 +9,29 @@ add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_
|
||||
add_example_executable(example_grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_data_wmma_v3_fp16_comp_bf8_fp8 grouped_conv_bwd_data_wmma_v3_fp16_comp_bf8_fp8.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_v3_fp16_comp_bf8_fp8)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_data_bias_relu_wmma_v3_fp16 grouped_conv_bwd_data_bias_relu_wmma_v3_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_wmma_v3_fp16)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_data_wmma_v3_bf16 grouped_conv_bwd_data_wmma_v3_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_v3_bf16)
|
||||
|
||||
add_example_executable(example_grouped_conv3d_bwd_data_wmma_v3_bf16 grouped_conv3d_bwd_data_wmma_v3_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv3d_bwd_data_wmma_v3_bf16)
|
||||
|
||||
add_example_executable(example_grouped_conv3d_bwd_data_wmma_v3_fp16 grouped_conv3d_bwd_data_wmma_v3_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv3d_bwd_data_wmma_v3_fp16)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_data_wmma_v3_fp16 grouped_conv_bwd_data_wmma_v3_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_v3_fp16)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -37,7 +37,11 @@ static inline constexpr ck::index_t NDimSpatial = 2;
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
using FP16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using FP32 = float;
|
||||
using FP8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
|
||||
116
example/38_grouped_conv_bwd_data_multiple_d/common_conv3d.hpp
Normal file
116
example/38_grouped_conv_bwd_data_multiple_d/common_conv3d.hpp
Normal file
@@ -0,0 +1,116 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <initializer_list>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp"
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::hip_check_error;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static inline constexpr ck::index_t NDimSpatial = 3;
|
||||
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
using FP16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using FP32 = float;
|
||||
using FP8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
#define DefaultConvParams \
|
||||
ck::utils::conv::ConvParam \
|
||||
{ \
|
||||
NDimSpatial, 32, 4, 192, 192, {3, 3, 3}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, \
|
||||
{ \
|
||||
1, 1, 1 \
|
||||
} \
|
||||
}
|
||||
|
||||
inline void print_help_msg()
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: time kernel (0=no, 1=yes)\n"
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
|
||||
}
|
||||
|
||||
inline bool parse_cmd_args(int argc,
|
||||
char* argv[],
|
||||
ExecutionConfig& config,
|
||||
ck::utils::conv::ConvParam& conv_params)
|
||||
{
|
||||
constexpr int num_execution_config_args =
|
||||
3; // arguments for do_verification, init_method, time_kernel
|
||||
constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_
|
||||
|
||||
constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args;
|
||||
constexpr int threshold_to_catch_all_args =
|
||||
threshold_to_catch_partial_args + num_conv_param_leading_args;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default
|
||||
config = ExecutionConfig{};
|
||||
}
|
||||
// catch only ExecutionConfig arguments
|
||||
else if(argc == threshold_to_catch_partial_args)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
// catch both ExecutionConfig & ConvParam arguments
|
||||
else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0))
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
|
||||
conv_params = ck::utils::conv::parse_conv_param(
|
||||
num_dim_spatial, threshold_to_catch_partial_args + 1, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
print_help_msg();
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "common_conv3d.hpp"
|
||||
|
||||
using OutDataType = BF16;
|
||||
using WeiDataType = BF16;
|
||||
using AccDataType = FP32;
|
||||
using CShuffleDataType = BF16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using InDataType = BF16;
|
||||
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// clang-format off
|
||||
using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
|
||||
// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat | NRepeat | _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_conv3d_bwd_data_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); }
|
||||
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "common_conv3d.hpp"
|
||||
using OutDataType = FP16;
|
||||
using WeiDataType = FP16;
|
||||
using AccDataType = FP32;
|
||||
using CShuffleDataType = FP16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using InDataType = FP16;
|
||||
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
// clang-format off
|
||||
using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
|
||||
// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat | NRepeat | _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, PassThrough, PassThrough, PassThrough, ConvBwdDataDefault, true, true, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_conv3d_bwd_data_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); }
|
||||
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
using OutDataType = FP16;
|
||||
using WeiDataType = FP16;
|
||||
using AccDataType = FP32;
|
||||
using CShuffleDataType = FP16;
|
||||
using BiasDataType = FP16; // bias
|
||||
using InDataType = FP16;
|
||||
|
||||
using OutLayout = ck::tensor_layout::convolution::GNHWK;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using BiasLayout = ck::Tuple<ck::tensor_layout::convolution::G_C>;
|
||||
using InLayout = ck::tensor_layout::convolution::GNHWC;
|
||||
|
||||
using OutElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using InElementOp = ck::tensor_operation::element_wise::AddRelu;
|
||||
|
||||
// clang-format off
|
||||
using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
|
||||
// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple<BiasDataType>, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_conv_bwd_data_bias_relu_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_bias_relu_example(argc, argv); }
|
||||
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
using OutDataType = BF16;
|
||||
using WeiDataType = BF16;
|
||||
using AccDataType = FP32;
|
||||
using CShuffleDataType = BF16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using InDataType = BF16;
|
||||
|
||||
using OutLayout = ck::tensor_layout::convolution::GNHWK;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using InLayout = ck::tensor_layout::convolution::GNHWC;
|
||||
|
||||
using OutElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using InElementOp = PassThrough;
|
||||
|
||||
// clang-format off
|
||||
using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
|
||||
// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat | NRepeat | _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_conv_bwd_data_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); }
|
||||
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
using OutDataType = FP16;
|
||||
using WeiDataType = FP16;
|
||||
using AccDataType = FP32;
|
||||
using CShuffleDataType = FP16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using InDataType = FP16;
|
||||
|
||||
using OutLayout = ck::tensor_layout::convolution::GNHWK;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using InLayout = ck::tensor_layout::convolution::GNHWC;
|
||||
|
||||
using OutElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using InElementOp = PassThrough;
|
||||
|
||||
// clang-format off
|
||||
using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
|
||||
// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat | NRepeat | _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_conv_bwd_data_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); }
|
||||
@@ -0,0 +1,47 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
using OutDataType = FP16;
|
||||
using WeiDataType = FP16;
|
||||
using AccDataType = FP32;
|
||||
using CShuffleDataType = FP16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using InDataType = FP16;
|
||||
using AComputeType = BF8;
|
||||
using BComputeType = FP8;
|
||||
|
||||
using OutLayout = ck::tensor_layout::convolution::GNHWK;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using InLayout = ck::tensor_layout::convolution::GNHWC;
|
||||
|
||||
using OutElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using InElementOp = PassThrough;
|
||||
|
||||
static constexpr auto BlkGemmPipeSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto BlkGemmPipelineVer = ck::BlockGemmPipelineVersion::v1;
|
||||
|
||||
// clang-format off
|
||||
using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
|
||||
// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| Loop| ACompute| BCompute|
|
||||
// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| Scheduler| Type| Type|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>, BlkGemmPipeSched,BlkGemmPipelineVer, AComputeType, BComputeType , false , false>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_conv_bwd_data_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// temp disable on gfx11
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
return run_grouped_conv_bwd_data_example(argc, argv);
|
||||
}
|
||||
@@ -0,0 +1,192 @@
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using InElementOp = PassThrough;
|
||||
|
||||
bool run_conv_bwd_data(const ExecutionConfig& config,
|
||||
const ck::utils::conv::ConvParam& conv_params,
|
||||
const HostTensorDescriptor& out_g_n_k_wos_desc,
|
||||
const HostTensorDescriptor& wei_g_k_c_xs_desc,
|
||||
const HostTensorDescriptor& in_g_n_c_wis_desc,
|
||||
const OutElementOp& out_element_op,
|
||||
const WeiElementOp& wei_element_op,
|
||||
const InElementOp& in_element_op)
|
||||
{
|
||||
|
||||
Tensor<OutDataType> out(out_g_n_k_wos_desc);
|
||||
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
|
||||
Tensor<InDataType> in_host(in_g_n_c_wis_desc);
|
||||
Tensor<InDataType> in_device(in_g_n_c_wis_desc);
|
||||
|
||||
std::cout << "out: " << out.mDesc << std::endl;
|
||||
std::cout << "wei: " << wei.mDesc << std::endl;
|
||||
std::cout << "in: " << in_host.mDesc << std::endl;
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
out.GenerateTensorValue(GeneratorTensor_2<OutDataType>{-5, 5});
|
||||
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize());
|
||||
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
|
||||
DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize());
|
||||
|
||||
out_device_buf.ToDevice(out.mData.data());
|
||||
wei_device_buf.ToDevice(wei.mData.data());
|
||||
|
||||
// reset input to zero
|
||||
in_device_buf.SetZero();
|
||||
|
||||
std::array<ck::index_t, NDimSpatial + 3> a_g_n_k_wos_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> a_g_n_k_wos_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> e_g_n_c_wis_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> e_g_n_c_wis_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads{};
|
||||
|
||||
auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
|
||||
|
||||
copy(out_g_n_k_wos_desc.GetLengths(), a_g_n_k_wos_lengths);
|
||||
copy(out_g_n_k_wos_desc.GetStrides(), a_g_n_k_wos_strides);
|
||||
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
|
||||
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
|
||||
copy(in_g_n_c_wis_desc.GetLengths(), e_g_n_c_wis_lengths);
|
||||
copy(in_g_n_c_wis_desc.GetStrides(), e_g_n_c_wis_strides);
|
||||
copy(conv_params.conv_filter_strides_, conv_filter_strides);
|
||||
copy(conv_params.conv_filter_dilations_, conv_filter_dilations);
|
||||
copy(conv_params.input_left_pads_, input_left_pads);
|
||||
copy(conv_params.input_right_pads_, input_right_pads);
|
||||
|
||||
static_assert(std::is_default_constructible_v<DeviceConvInstance>);
|
||||
// do conv
|
||||
auto conv = DeviceConvInstance{};
|
||||
auto invoker = conv.MakeInvoker();
|
||||
auto argument = conv.MakeArgument(out_device_buf.GetDeviceBuffer(),
|
||||
wei_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, 0>{},
|
||||
in_device_buf.GetDeviceBuffer(),
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{},
|
||||
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{},
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
out_element_op,
|
||||
wei_element_op,
|
||||
in_element_op);
|
||||
|
||||
if(!conv.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << "wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem"
|
||||
<< std::endl;
|
||||
|
||||
return false;
|
||||
}
|
||||
std::string op_name = conv.GetTypeString();
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t flop = conv_params.GetFlops();
|
||||
std::size_t num_btype = conv_params.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
if(config.do_verification)
|
||||
{
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
WeiElementOp,
|
||||
OutElementOp>();
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_conv.MakeArgument(in_host,
|
||||
wei,
|
||||
out,
|
||||
conv_params.conv_filter_strides_,
|
||||
conv_params.conv_filter_dilations_,
|
||||
conv_params.input_left_pads_,
|
||||
conv_params.input_right_pads_,
|
||||
PassThrough{},
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
in_device_buf.FromDevice(in_device.mData.data());
|
||||
return ck::utils::check_err(in_device.mData, in_host.mData);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int run_grouped_conv_bwd_data_example(int argc, char* argv[])
|
||||
{
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
|
||||
ExecutionConfig config;
|
||||
ck::utils::conv::ConvParam conv_params = DefaultConvParams;
|
||||
|
||||
if(!parse_cmd_args(argc, argv, config, conv_params))
|
||||
{
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{};
|
||||
|
||||
if(conv_params.num_dim_spatial_ != NDimSpatial)
|
||||
{
|
||||
std::cerr << "unsupported # of spatials dimensions" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
// output image: GNHWK
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_params);
|
||||
|
||||
// weight: GKYXC
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_params);
|
||||
|
||||
// input image: GNHWC
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_params);
|
||||
|
||||
return !run_conv_bwd_data(config,
|
||||
conv_params,
|
||||
out_g_n_k_wos_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
in_element_op);
|
||||
}
|
||||
@@ -86,7 +86,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -87,7 +87,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -84,7 +84,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -90,7 +90,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -88,7 +88,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -88,7 +88,7 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -12,7 +12,7 @@ int run_groupnorm_fwd_example(int argc, char* argv[])
|
||||
ck::index_t C = 128;
|
||||
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
bool log_kernel = true;
|
||||
|
||||
if(argc == 1)
|
||||
|
||||
@@ -53,7 +53,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
std::vector<std::size_t> nchw = {16, 128, 32, 64};
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -50,7 +50,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -50,7 +50,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -49,7 +49,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -50,7 +50,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -121,7 +121,7 @@ void reference_scale_permute_amax(Tensor<InputDataType>& input,
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
const float scale = 2.f;
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
|
||||
@@ -84,7 +84,7 @@ void host_elementwise2D(HostTensorC& C,
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::index_t M = 48 * 256;
|
||||
ck::index_t N = 1024;
|
||||
|
||||
@@ -11,3 +11,12 @@ add_example_executable(example_conv_fwd_xdl_scaleadd_ab_bf16 conv_fwd_xdl_scalea
|
||||
add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_bf16)
|
||||
add_example_executable(example_conv_fwd_xdl_scaleadd_ab_int8 conv_fwd_xdl_scaleadd_ab_int8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_int8)
|
||||
|
||||
add_custom_target(example_convnd_activ_multi_ab_wmma_cshufflev3)
|
||||
# ScaleAdd on A and B
|
||||
add_example_executable(example_conv_fwd_wmma_cshufflev3_scaleadd_ab_fp16 conv_fwd_wmma_cshufflev3_scaleadd_ab_fp16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_multi_ab_wmma_cshufflev3 example_conv_fwd_wmma_cshufflev3_scaleadd_ab_fp16)
|
||||
add_example_executable(example_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16 conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16.cpp)
|
||||
add_example_dependencies(example_convnd_activ_multi_ab_wmma_cshufflev3 example_conv_fwd_wmma_cshufflev3_scaleadd_ab_bf16)
|
||||
add_example_executable(example_conv_fwd_wmma_cshufflev3_scaleadd_ab_int8 conv_fwd_wmma_cshufflev3_scaleadd_ab_int8.cpp)
|
||||
add_example_dependencies(example_convnd_activ_multi_ab_wmma_cshufflev3 example_conv_fwd_wmma_cshufflev3_scaleadd_ab_int8)
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#define EXAMPLE_USE_WMMA
|
||||
#include "convnd_fwd_activ_multi_ab_common.hpp"
|
||||
|
||||
using DataType = ck::bhalf_t;
|
||||
using AccDataType = float;
|
||||
using InDataType = DataType;
|
||||
using WeiDataType = DataType;
|
||||
using OutDataType = DataType;
|
||||
using ADataTypes = ck::Tuple<DataType, DataType>;
|
||||
using BDataTypes = ck::Tuple<DataType, DataType>;
|
||||
|
||||
using InElementOp = ck::tensor_operation::element_wise::ScaleAdd;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;
|
||||
|
||||
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
|
||||
AccDataType,
|
||||
ADataTypes,
|
||||
BDataTypes,
|
||||
InElementOp,
|
||||
WeiElementOp>;
|
||||
|
||||
#include "../run_convnd_activ_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#define EXAMPLE_USE_WMMA
|
||||
#include "convnd_fwd_activ_multi_ab_common.hpp"
|
||||
|
||||
using DataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using InDataType = DataType;
|
||||
using WeiDataType = DataType;
|
||||
using OutDataType = DataType;
|
||||
using ADataTypes = ck::Tuple<DataType, DataType>;
|
||||
using BDataTypes = ck::Tuple<DataType, DataType>;
|
||||
|
||||
using InElementOp = ck::tensor_operation::element_wise::ScaleAdd;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;
|
||||
|
||||
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
|
||||
AccDataType,
|
||||
ADataTypes,
|
||||
BDataTypes,
|
||||
InElementOp,
|
||||
WeiElementOp>;
|
||||
|
||||
#include "../run_convnd_activ_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#define EXAMPLE_USE_WMMA
|
||||
#include "convnd_fwd_activ_multi_ab_common.hpp"
|
||||
|
||||
using DataType = int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using InDataType = DataType;
|
||||
using WeiDataType = DataType;
|
||||
using OutDataType = DataType;
|
||||
using ADataTypes = ck::Tuple<DataType, DataType>;
|
||||
using BDataTypes = ck::Tuple<DataType, DataType>;
|
||||
|
||||
using InElementOp = ck::tensor_operation::element_wise::ScaleAdd;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd;
|
||||
|
||||
using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance<DataType,
|
||||
AccDataType,
|
||||
ADataTypes,
|
||||
BDataTypes,
|
||||
InElementOp,
|
||||
WeiElementOp>;
|
||||
|
||||
#include "../run_convnd_activ_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); }
|
||||
@@ -9,7 +9,11 @@
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#ifdef EXAMPLE_USE_WMMA
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"
|
||||
#else
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
#endif
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
@@ -41,6 +45,62 @@ static constexpr auto ConvSpec =
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
#ifdef EXAMPLE_USE_WMMA
|
||||
template <typename DataType,
|
||||
typename AccDataType,
|
||||
typename InDataTypes,
|
||||
typename WeiDataTypes,
|
||||
typename InElementOp,
|
||||
typename WeiElementOp>
|
||||
using DeviceGroupedConvNDMultiABFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
|
||||
NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>,
|
||||
OutLayout,
|
||||
InDataTypes,
|
||||
WeiDataTypes,
|
||||
AccDataType,
|
||||
DataType,
|
||||
ck::Tuple<>,
|
||||
DataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
4, // MWmmaPerWave
|
||||
4, // NWmmaPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v1>;
|
||||
#else
|
||||
template <typename DataType,
|
||||
typename AccDataType,
|
||||
typename InDataTypes,
|
||||
@@ -94,6 +154,7 @@ using DeviceGroupedConvNDMultiABFwdInstance =
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
4>;
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
template <ck::index_t NDimSpatial,
|
||||
@@ -261,6 +322,8 @@ bool run_grouped_conv(bool do_verification,
|
||||
|
||||
out_device_buf.FromDevice(out_device.mData.data());
|
||||
|
||||
printf("Running verification\n");
|
||||
|
||||
return ck::utils::check_err(out_device, out_host, "Error: incorrect results!");
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp)
|
||||
add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp)
|
||||
add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_blockscale.cpp)
|
||||
add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp)
|
||||
add_example_executable(example_moe_gemm1_xdl_fp8_blockscale_splitk moe_gemm1_xdl_fp8_blockscale_splitk.cpp)
|
||||
|
||||
list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx11-generic gfx12-generic)
|
||||
set(target 0)
|
||||
|
||||
@@ -205,7 +205,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096;
|
||||
|
||||
@@ -171,7 +171,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceM
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, false, MulRoutedWeight, int32_t, A0DataType>;
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
@@ -185,7 +185,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, false, MulRoutedWeight, int32_t, A0DataType>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
@@ -193,7 +193,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
#if 1
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096;
|
||||
|
||||
@@ -0,0 +1,543 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale_splitk.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using F32 = float;
|
||||
using I64 = int64_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Bypass = ck::tensor_layout::BypassLayoutVerification;
|
||||
|
||||
using A0DataType = F8;
|
||||
using A1DataType = F32;
|
||||
using B0DataType = F8;
|
||||
using B1DataType = F32;
|
||||
using EDataType = F32;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = EDataType;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D2Layout>;
|
||||
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D2>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
|
||||
{
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
};
|
||||
|
||||
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 16 / sizeof(B0DataType);
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
|
||||
int K0 = K / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
int tempk;
|
||||
for(I64 n = 0; n < N; ++n)
|
||||
{
|
||||
for(I64 k = 0; k < K; ++k)
|
||||
{
|
||||
I64 n0 = n / NLane;
|
||||
I64 n1 = n % NLane;
|
||||
|
||||
I64 k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
I64 k1 = tempk / KPack;
|
||||
I64 k2 = tempk % KPack;
|
||||
|
||||
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * static_cast<I64>(K) + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr ck::index_t Scale_Block_M = 1;
|
||||
static constexpr ck::index_t Scale_Block_N = 128;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t IsInputGemm = true; // splitk gemm1 goes to gemm2 pipeline.
|
||||
static constexpr ck::index_t IsSplitK = true; // splitk gemm1
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr bool MulRoutedWeight = false; // splitk gemm1 does not do routedWeight.
|
||||
|
||||
#if 1
|
||||
static constexpr ck::index_t MPerBlock = 64;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
|
||||
static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4);
|
||||
static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave;
|
||||
static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
|
||||
// clang-format off
|
||||
< Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
//threadnum, mblock, nblock, kblock
|
||||
BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
MPerBlock, NPerBlock, KPerBlock,
|
||||
// ak1, bk1
|
||||
AK1, BK1,
|
||||
// mn_perxdl
|
||||
MNPerXDL, MNPerXDL,
|
||||
// mn_xdlperwave
|
||||
MXDLPerWave, NXDLPerWave,
|
||||
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight,
|
||||
int32_t, A0DataType, A0DataType, A0DataType, A0DataType, true>;
|
||||
#else
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
MPerBlock, 128, 128,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight,
|
||||
int32_t, A0DataType, A0DataType, A0DataType, A0DataType, false>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
#if 1
|
||||
// GEMM shape
|
||||
ck::index_t N = 1536;
|
||||
ck::index_t K = 4096;
|
||||
// ck::index_t N = 4096;
|
||||
// ck::index_t K = 6144;
|
||||
// ck::index_t N = 128;
|
||||
// ck::index_t K = 512;
|
||||
ck::index_t experts = 16;
|
||||
ck::index_t topk = 8;
|
||||
// ck::index_t sorted_tile_num = 515;
|
||||
// ck::index_t valid_tile_num = 512;
|
||||
// ck::index_t tokens = 208;
|
||||
// ck::index_t sorted_tile_num = 15;
|
||||
// ck::index_t valid_tile_num = 13;
|
||||
// ck::index_t sorted_tile_num = 259;
|
||||
// ck::index_t valid_tile_num = 256;
|
||||
// ck::index_t tokens = 4096;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 16;
|
||||
ck::index_t tokens = 4;
|
||||
#else
|
||||
// deepseek
|
||||
ck::index_t N = 2048;
|
||||
ck::index_t K = 7168;
|
||||
ck::index_t experts = 256;
|
||||
ck::index_t topk = 8;
|
||||
ck::index_t tokens = 4096;
|
||||
ck::index_t sorted_tile_num = 261;
|
||||
ck::index_t valid_tile_num = 256;
|
||||
#endif
|
||||
ck::index_t KBatch = 1;
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 2)
|
||||
{
|
||||
KBatch = std::stoi(argv[1]);
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else if(argc == 9)
|
||||
{
|
||||
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
sorted_tile_num = std::stoi(argv[7]);
|
||||
valid_tile_num = std::stoi(argv[8]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: N, K, tokens\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N * 2;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0};
|
||||
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
max_token_id.mData = {valid_size};
|
||||
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
|
||||
}
|
||||
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile && tokenid < tokens * topk)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
|
||||
Tensor<A1DataType> a1_t_k(HostTensorDescriptor(
|
||||
{tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1}, Row{}));
|
||||
Tensor<B0DataType> b0_e_n_k(
|
||||
HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{}));
|
||||
Tensor<B1DataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts,
|
||||
(K + Scale_Block_K - 1) / Scale_Block_K,
|
||||
(N + Scale_Block_N - 1) / Scale_Block_N * 2},
|
||||
{(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN},
|
||||
Col{}));
|
||||
Tensor<B0DataType> b0_preshuffled(
|
||||
HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{}));
|
||||
Tensor<EDataType> e_t_n_host_result(
|
||||
HostTensorDescriptor({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{}));
|
||||
Tensor<EDataType> e_t_n_device_result(
|
||||
HostTensorDescriptor({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{}));
|
||||
e_t_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
|
||||
std::cout << "k_batch:" << KBatch << std::endl;
|
||||
std::cout << "init_method:" << init_method << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-1.0, 1.0});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0.0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
|
||||
sorted_token_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k.mData.data());
|
||||
a1_device_buf.ToDevice(a1_t_k.mData.data());
|
||||
b1_device_buf.ToDevice(b1_e_n_k.mData.data());
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
int NPerXdl = device_op.GetPreShuffleParameters();
|
||||
|
||||
preShuffleBuffer(
|
||||
b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl);
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{nullptr},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
if(time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K;
|
||||
std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K +
|
||||
sizeof(B0DataType) * K * N * 2 * experts +
|
||||
sizeof(EDataType) * valid_tile_num * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s.\n"
|
||||
<< device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// use atomic, so need to reinit outputs
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<float> a_t_k({tokens, K});
|
||||
Tensor<float> b_e_n_k({experts, K, N * 2});
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
Tensor<float> c_t_k_n({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{});
|
||||
|
||||
// handle scale before ref.
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
a_t_k(t, k) = ck::type_convert<float>(a0_t_k(t, k)) * a1_t_k(t, k / Scale_Block_K);
|
||||
}
|
||||
}
|
||||
|
||||
for(int e = 0; e < experts; ++e)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
for(int n = 0; n < N * 2; ++n)
|
||||
{
|
||||
b_e_n_k(e, k, n) = ck::type_convert<float>(b0_e_n_k(e, k, n)) *
|
||||
b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N);
|
||||
}
|
||||
}
|
||||
}
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeGemm1BlockScaleSplitK<float,
|
||||
float,
|
||||
float,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a_t_k,
|
||||
b_e_n_k,
|
||||
c_t_k_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int m = 0; m < valid_size; ++m)
|
||||
{
|
||||
|
||||
const int fuse_t = sorted_token_ids.mData[m];
|
||||
const int t = fuse_t & 0xffffff;
|
||||
const int topk_id = (fuse_t & 0xff000000) >> 24;
|
||||
|
||||
if(t >= tokens)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for(int n = 0; n < 2 * N; ++n)
|
||||
{
|
||||
e_t_n_host_result(t, topk_id, n) =
|
||||
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
auto status =
|
||||
ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
|
||||
? 0
|
||||
: 1;
|
||||
if(status == 0)
|
||||
{
|
||||
printf("Validation Pass.\n");
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -194,7 +194,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -185,7 +185,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -165,7 +165,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, MulRoutedWeight, int32_t, A0DataType>;
|
||||
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
@@ -180,7 +180,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, int32_t, A0DataType>;
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, false, MulRoutedWeight, int32_t, A0DataType>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
@@ -188,7 +188,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// tokens = 1
|
||||
// topk = 1
|
||||
|
||||
@@ -164,7 +164,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -178,7 +178,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -178,7 +178,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -208,7 +208,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -171,7 +171,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -171,7 +171,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -204,7 +204,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
|
||||
@@ -6,6 +6,35 @@ include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/library/include
|
||||
)
|
||||
|
||||
if(WIN32)
|
||||
# On Windows, HIP uses -nostdlib which prevents C runtime linking
|
||||
# We need legacy_stdio_definitions.lib to provide vfprintf and other legacy C functions
|
||||
# This is mainly needed for the getopt library.
|
||||
set(LEGACY_STDIO_SEARCH_PATHS)
|
||||
|
||||
# Try to use Visual C++ Tools environment variable (if build executes from Visual Studio Developer Command Prompt)
|
||||
if(DEFINED ENV{VCToolsInstallDir})
|
||||
list(APPEND LEGACY_STDIO_SEARCH_PATHS "$ENV{VCToolsInstallDir}/lib/x64")
|
||||
endif()
|
||||
|
||||
# Fallback: Search common Visual Studio installation locations
|
||||
file(GLOB MSVC_LIB_PATHS "C:/Program Files/Microsoft Visual Studio/*/*/VC/Tools/MSVC/*/lib/x64")
|
||||
list(APPEND LEGACY_STDIO_SEARCH_PATHS ${MSVC_LIB_PATHS})
|
||||
|
||||
# Use find_library to locate the library
|
||||
find_library(LEGACY_STDIO_LIB legacy_stdio_definitions
|
||||
PATHS ${LEGACY_STDIO_SEARCH_PATHS}
|
||||
NO_DEFAULT_PATH
|
||||
)
|
||||
|
||||
if(LEGACY_STDIO_LIB)
|
||||
message(STATUS "Found legacy_stdio_definitions.lib: ${LEGACY_STDIO_LIB}")
|
||||
add_link_options("SHELL:-Xlinker \"${LEGACY_STDIO_LIB}\"")
|
||||
else()
|
||||
message(WARNING "Could not find legacy_stdio_definitions.lib - examples may fail to link.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_custom_target(examples)
|
||||
|
||||
|
||||
@@ -216,6 +245,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt)
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS})
|
||||
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
|
||||
|
||||
@@ -47,7 +47,7 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--targets ${FMHA_TARGETS_ARG}
|
||||
--api ${FMHA_FWD_APIS}
|
||||
--optdim 32,64,128,256
|
||||
--optdim 32,64,80,128,256
|
||||
# --filter fmha_fwd...
|
||||
)
|
||||
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
||||
|
||||
@@ -24,11 +24,31 @@ from codegen.cpp_symbol_map import (
|
||||
)
|
||||
from codegen.utils import update_file
|
||||
|
||||
|
||||
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8": 8,
|
||||
"fp8bf16": 8,
|
||||
"fp8fp32": 8,
|
||||
"bf8": 8,
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
|
||||
|
||||
SUPPORTED_PAGE_SIZE = [1, 128, 256, 1024]
|
||||
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
|
||||
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
|
||||
KV_MEMORY_LAYOUT_ENUM_MAP = {
|
||||
"vectorized": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT",
|
||||
"linear": "ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT",
|
||||
}
|
||||
KV_LOOKUP_TABLE_ENUM_MAP = {
|
||||
"vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D",
|
||||
"sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D",
|
||||
}
|
||||
|
||||
|
||||
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
|
||||
"qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
|
||||
}
|
||||
@@ -52,7 +72,7 @@ using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
@@ -62,13 +82,17 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_qscale},
|
||||
{F_occupancy}>;
|
||||
{F_occupancy},
|
||||
false,
|
||||
{F_page_size},
|
||||
{F_kv_memory_layout},
|
||||
{F_kv_lookup_table}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBatchPrefillPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
@@ -85,6 +109,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
false,
|
||||
{F_page_size},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
@@ -98,8 +123,8 @@ using fmha_epilogue_{F_idx} =
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -108,7 +133,7 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
std::cout << ", {F_kname}" << std::flush;
|
||||
auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
@@ -177,8 +202,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -223,12 +248,15 @@ class FmhaFwdApiTrait:
|
||||
dpad: str
|
||||
dvpad: str
|
||||
constraint: CppConstraint
|
||||
kv_memory_layout: str
|
||||
kv_lookup_table: str
|
||||
page_size: int = 1 # page block size
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return (
|
||||
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
|
||||
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}"
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -315,6 +343,8 @@ class FmhaFwdPipeline:
|
||||
F_dropout: str #
|
||||
F_qscale: str # no/pertensor
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_kv_memory_layout: str #
|
||||
F_kv_lookup_table: str #
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
@@ -375,6 +405,8 @@ class FmhaFwdPipeline:
|
||||
n += f"_{self.F_qscale}"
|
||||
else:
|
||||
n += "_nqscale"
|
||||
|
||||
n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
|
||||
return n
|
||||
|
||||
|
||||
@@ -433,6 +465,13 @@ class FmhaFwdApiPool:
|
||||
F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[dtype],
|
||||
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
|
||||
trait.kv_memory_layout
|
||||
],
|
||||
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
|
||||
trait.kv_lookup_table
|
||||
],
|
||||
F_page_size=trait.page_size,
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
@@ -490,10 +529,12 @@ class FmhaFwdKernel:
|
||||
F_tile: FmhaFwdTileSize
|
||||
F_pipeline: FmhaFwdPipeline
|
||||
mask_impl: str
|
||||
F_page_size: int = 1 # page block size
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
|
||||
F_kname=self.name,
|
||||
F_idx=self.F_idx,
|
||||
F_hdim=self.F_hdim,
|
||||
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
|
||||
@@ -526,17 +567,24 @@ class FmhaFwdKernel:
|
||||
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale],
|
||||
F_occupancy=self.F_tile.F_occupancy,
|
||||
F_kv_memory_layout=KV_MEMORY_LAYOUT_ENUM_MAP[
|
||||
self.F_pipeline.F_kv_memory_layout
|
||||
],
|
||||
F_kv_lookup_table=KV_LOOKUP_TABLE_ENUM_MAP[
|
||||
self.F_pipeline.F_kv_lookup_table
|
||||
],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_page_size=self.F_page_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return (
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
|
||||
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_"
|
||||
+ self.F_tile.name
|
||||
+ "_"
|
||||
+ self.F_pipeline.name
|
||||
@@ -570,16 +618,23 @@ class FmhaFwdKernel:
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
|
||||
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
|
||||
page_size=self.F_page_size,
|
||||
)
|
||||
|
||||
|
||||
class KernelComponentFactory:
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
return {
|
||||
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
} # fmt: skip
|
||||
elif dtype in ["fp8bf16"]:
|
||||
return {
|
||||
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
} # fmt: skip
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -589,20 +644,45 @@ class KernelComponentFactory:
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
qscale = "no"
|
||||
pipelines = []
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(
|
||||
qscale = "no"
|
||||
for (
|
||||
logits,
|
||||
mask,
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
elif dtype in ["fp8bf16"]:
|
||||
# no need lse/dropout kernels
|
||||
for (
|
||||
logits,
|
||||
qscale,
|
||||
mask,
|
||||
bias,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
["t", "f"],
|
||||
["pertensor"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
@@ -612,7 +692,7 @@ class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == "fp16" or dtype == "bf16":
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
if 128 in result.keys():
|
||||
result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate"))) # fmt: skip
|
||||
return result
|
||||
@@ -654,70 +734,75 @@ def get_fwd_blobs(
|
||||
or pipeline.F_logits == "f"
|
||||
):
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
# Generate kernels for both page_size=16 and page_size=1024
|
||||
for page_size in SUPPORTED_PAGE_SIZE:
|
||||
if page_size == 1 and pipeline.F_kv_memory_layout != "linear":
|
||||
continue
|
||||
k = FmhaFwdKernel(
|
||||
F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl,
|
||||
F_page_size=page_size,
|
||||
)
|
||||
if kernel_filter != "":
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "alibi"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_bias in ["no", "bias"]
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ["fp16", "bf16"]
|
||||
cond &= mode == "batch"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ["fp16", "bf16", "fp8bf16"]
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == "row"
|
||||
cond &= pipeline.F_qscale == "no"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == "fp32"
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
|
||||
@@ -40,7 +40,16 @@ DTYPE_BITS = {
|
||||
"bf8": 8,
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {32: 32, 48: 48, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256}
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32: 32,
|
||||
48: 48,
|
||||
64: 64,
|
||||
80: 96,
|
||||
96: 128,
|
||||
128: 128,
|
||||
192: 192,
|
||||
256: 256,
|
||||
}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
@@ -202,11 +211,10 @@ float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream
|
||||
const bool can_dispatch_v3 =
|
||||
(device_name.compare(0, 6, "gfx950") == 0) and
|
||||
(traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
|
||||
traits.is_v_rowmajor and (not traits.has_logits_soft_cap) and
|
||||
(traits.bias_type == bias_enum::no_bias) and (not traits.has_lse) and
|
||||
(not traits.has_dropout) and (traits.qscale_type == quant_scale_enum::no_scale) and
|
||||
(not is_swa) and (args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and
|
||||
(args.hdim_v == 128);
|
||||
traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
|
||||
(not traits.has_lse) and (not traits.has_dropout) and
|
||||
(traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and
|
||||
(args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);
|
||||
if ({F_is_v3_enabled} and can_dispatch_v3) {{
|
||||
return fmha_fwd_v3(traits, args, config);
|
||||
}} else {{
|
||||
@@ -930,6 +938,7 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
( 64, 64) : [FmhaFwdTileSize( 16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize( 32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
( 80, 96) : [FmhaFwdTileSize(128, 128, 16, 96, 32, 80, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
@@ -1008,14 +1017,18 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
|
||||
elif dtype in cls._DT_FP8BF16 or dtype in cls._DT_FP8FP32:
|
||||
# no need lse/dropout kernels
|
||||
for logits, qscale, mask, bias, sink in itertools.product(
|
||||
["f"],
|
||||
["t", "f"],
|
||||
["no", "pertensor"],
|
||||
get_mask_map(mask_impl).keys(),
|
||||
["no"],
|
||||
["f", "t"],
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
if hdim == 64:
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip
|
||||
elif dtype in ["fp8", "fp8fp16", "bf8"]:
|
||||
# TODO
|
||||
pass
|
||||
@@ -1068,9 +1081,9 @@ class KernelComponentFactoryGfx950(
|
||||
# qr_async_trload_v3 only supports hdim=hdim_v=128 for now
|
||||
if (hdim, hdim_v) == (128, 128):
|
||||
# qr_async_trload_v3 only supports (generic) causal mask
|
||||
for mask in ["no", "causal"]:
|
||||
for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
|
||||
F_logits="f", F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
|
||||
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
|
||||
|
||||
return pipelines
|
||||
|
||||
|
||||
@@ -114,7 +114,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("kv_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.");
|
||||
"Comma-separated list of length 'b'. If empty, no override.")
|
||||
.insert("init_sink", "0", "value to init the output tensor sink value for validation");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -157,6 +158,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int init_sink_value = arg_parser.get_int("init_sink");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
@@ -203,6 +205,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
init_method,
|
||||
seed,
|
||||
do_validation,
|
||||
init_sink_value,
|
||||
stream_config,
|
||||
json);
|
||||
}
|
||||
|
||||
@@ -621,8 +621,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision
|
||||
ck_tile::HostTensor<AccDataType> p_dropped_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision
|
||||
ck_tile::HostTensor<GemmDataType> p_lp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision
|
||||
|
||||
// p_lp_g_m_n low precision used for fwd (with rp_undrop)
|
||||
ck_tile::HostTensor<GemmDataType> p_fwd_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
||||
// p_lp_g_m_n low precision used for bwd (no rp_undrop)
|
||||
ck_tile::HostTensor<GemmDataType> p_lp_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
||||
|
||||
ck_tile::index_t nr = nhead / nhead_k;
|
||||
|
||||
@@ -762,8 +765,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
ck_tile::reference_batched_dropout_randval(
|
||||
randval_host_ref, wb, drop_seed, drop_offset);
|
||||
ck_tile::reference_batched_dropout(
|
||||
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, 1.f);
|
||||
p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
p_dropped_hp_host_ref.ForEach(
|
||||
[&](auto& self, const auto& idx) { self(idx) *= rp_undrop; });
|
||||
p_fwd_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_result(
|
||||
{nhead, real_seqlen_q, real_seqlen_k});
|
||||
@@ -789,12 +795,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
}
|
||||
else
|
||||
{
|
||||
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
p_fwd_host_ref = p_lp_host_ref;
|
||||
}
|
||||
|
||||
// O = P * V
|
||||
ck_tile::reference_batched_gemm<GemmDataType, VDataType, AccDataType, ODataType>(
|
||||
p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n
|
||||
p_fwd_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n
|
||||
|
||||
// clang-format off
|
||||
// permute
|
||||
@@ -900,7 +907,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
if(p_drop > 0)
|
||||
{
|
||||
ck_tile::reference_batched_dropout(
|
||||
dp_hp_host_ref, randval_host_refs[ref_idx], p_undrop_in_uint8_t, rp_undrop);
|
||||
dp_hp_host_ref, randval_host_refs[ref_idx], p_undrop_in_uint8_t, 1.f);
|
||||
}
|
||||
|
||||
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
|
||||
@@ -911,7 +918,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
{
|
||||
do_dot_o +=
|
||||
ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[ref_idx](i0, i1, o));
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[ref_idx](i0, i1, o)) *
|
||||
p_undrop;
|
||||
}
|
||||
ds_hp_host_ref(i0, i1, i2) =
|
||||
ck_tile::type_convert<AccDataType>(p_hp_host_refs[ref_idx](i0, i1, i2) *
|
||||
@@ -935,7 +943,12 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m
|
||||
ck_tile::
|
||||
reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
|
||||
p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
|
||||
p_t_lp_host_ref,
|
||||
do_t_host_ref,
|
||||
dv_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(rp_undrop)); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
|
||||
|
||||
// dQ = scale * dS@K^T
|
||||
auto k_t_host_ref = k_host_refs[ref_idx].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
|
||||
@@ -945,7 +958,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dq_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n
|
||||
ck_tile::scales(scale * rp_undrop)); // dq_g_m_k = ds_g_m_n@k_g_k_n
|
||||
|
||||
// dK = scale * dS^T@Q^T
|
||||
auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
|
||||
@@ -956,7 +969,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
|
||||
dk_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m
|
||||
ck_tile::scales(scale * rp_undrop)); // dk_g_n_k = ds_g_n_m@q_g_k_m
|
||||
|
||||
ck_tile::HostTensor<QGradDataType> dq_host_result(
|
||||
{nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
|
||||
|
||||
@@ -230,6 +230,7 @@ struct fmha_fwd_args
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -317,6 +318,7 @@ struct fmha_fwd_pagedkv_args
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -400,6 +402,7 @@ struct fmha_fwd_splitkv_args
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -476,6 +479,7 @@ struct fmha_fwd_appendkv_args
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
|
||||
const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
@@ -500,6 +504,9 @@ struct fmha_batch_prefill_args
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
const void* q_descale_ptr;
|
||||
const void* k_descale_ptr;
|
||||
const void* v_descale_ptr;
|
||||
void* rand_val_ptr;
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
@@ -516,6 +523,7 @@ struct fmha_batch_prefill_args
|
||||
// 1) +
|
||||
// kargs.kv_last_page_lens[b]
|
||||
const void* seqstart_q_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -526,14 +534,25 @@ struct fmha_batch_prefill_args
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
// SGLang-style page table
|
||||
int32_t num_total_pages;
|
||||
void* kv_indptr;
|
||||
void* kv_page_indices;
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
void* kv_last_page_lens;
|
||||
ck_tile::index_t page_block_size;
|
||||
#endif
|
||||
// KV cache page table fields (kv_lookup_table selects interpretation):
|
||||
// - SGLANG_PAGE_TABLE_1D:
|
||||
// kv_indptr: prefix-sum [batch+1] into kv_page_indices
|
||||
// kv_page_indices: 1D list of physical page ids, length = num_total_pages
|
||||
// kv_last_page_lens: per-batch last page lengths [batch]
|
||||
// - VLLM_BLOCK_TABLE_2D:
|
||||
// kv_page_indices: block_table [batch, max_blocks_per_seq] (2D)
|
||||
// batch_stride_block_table: row stride for block_table
|
||||
// seqlen_k_ptr: per-batch seqlen_k [batch]
|
||||
int32_t num_total_pages; // total physical pages in KV cache (SGLang/vLLM)
|
||||
ck_tile::index_t page_block_size; // tokens per page (SGLang/vLLM)
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum
|
||||
kv_memory_layout; // KV memory layout (SGLang/vLLM)
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table; // lookup table layout selector
|
||||
void* kv_indptr; // SGLang: prefix-sum; vLLM: unused
|
||||
void* kv_page_indices; // SGLang: 1D page list; vLLM: block_table 2D
|
||||
void* kv_last_page_lens; // SGLang: last page lengths; vLLM: unused
|
||||
void* seqlen_k_ptr; // vLLM: per-batch seqlen_k; SGLang: unused
|
||||
ck_tile::index_t batch_stride_block_table; // vLLM: row stride; SGLang: unused
|
||||
|
||||
float scale_s;
|
||||
float scale_p;
|
||||
@@ -624,7 +643,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -674,7 +694,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -728,6 +749,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
@@ -758,6 +780,7 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
@@ -832,7 +855,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q);
|
||||
args.min_seqlen_q,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -877,7 +901,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -944,7 +969,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -992,7 +1018,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1108,6 +1135,22 @@ template <typename FmhaKernel>
|
||||
auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
using PageTableKargs = typename FmhaKernel::PageBlockTableKargs;
|
||||
const PageTableKargs page_table = [&]() {
|
||||
if constexpr(FmhaKernel::kKVLookupTable ==
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
{
|
||||
return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_indptr),
|
||||
reinterpret_cast<const int32_t*>(args.kv_page_indices),
|
||||
reinterpret_cast<const int32_t*>(args.kv_last_page_lens)};
|
||||
}
|
||||
else
|
||||
{
|
||||
return PageTableKargs{reinterpret_cast<const int32_t*>(args.kv_page_indices),
|
||||
args.batch_stride_block_table,
|
||||
reinterpret_cast<const int32_t*>(args.seqlen_k_ptr)};
|
||||
}
|
||||
}();
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
@@ -1116,6 +1159,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.q_descale_ptr,
|
||||
args.k_descale_ptr,
|
||||
args.v_descale_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
@@ -1125,12 +1171,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
page_table,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
@@ -1156,7 +1198,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -1164,6 +1207,9 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.q_descale_ptr,
|
||||
args.k_descale_ptr,
|
||||
args.v_descale_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
@@ -1173,12 +1219,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
page_table,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
@@ -1209,7 +1251,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1270,6 +1313,65 @@ struct fmha_fwd_traits_
|
||||
static constexpr bool kHasSink = kHasSink_;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
bool kHasLogitsSoftCap_,
|
||||
typename FmhaMask_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLse_,
|
||||
bool kHasDropout_,
|
||||
ck_tile::BlockAttentionQuantScaleEnum QScaleEnum_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kUseTrLoad_,
|
||||
bool kSkipMinSeqlenQ_ = false,
|
||||
ck_tile::index_t kPageBlockSize_ = 1,
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kKVLookupTable_ =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D>
|
||||
struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
|
||||
DataType_,
|
||||
kIsGroupMode_,
|
||||
kM0_,
|
||||
kN0_,
|
||||
kK0_,
|
||||
kN1_,
|
||||
kK1_,
|
||||
kK0BlockLength_,
|
||||
kIsVLayoutRowMajor_,
|
||||
FmhaPipelineEnum_,
|
||||
kHasLogitsSoftCap_,
|
||||
FmhaMask_,
|
||||
BiasEnum_,
|
||||
kStoreLse_,
|
||||
kHasDropout_,
|
||||
QScaleEnum_,
|
||||
kPadS_,
|
||||
kPadSK_,
|
||||
kPadD_,
|
||||
kPadDv_,
|
||||
kUseTrLoad_,
|
||||
kSkipMinSeqlenQ_,
|
||||
false>
|
||||
{
|
||||
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
|
||||
static constexpr auto kKVLookupTable = kKVLookupTable_;
|
||||
static constexpr ck_tile::index_t kPageBlockSize = kPageBlockSize_;
|
||||
static_assert(kIsVLayoutRowMajor_, "Batch prefill only supports row-major V layout");
|
||||
};
|
||||
|
||||
template <typename Traits_, typename Arch = void>
|
||||
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
@@ -1516,7 +1618,15 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
|
||||
fmha_fwd_appendkv_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
using fmha_batch_prefill_traits = fmha_fwd_traits;
|
||||
struct fmha_batch_prefill_traits : public fmha_fwd_traits
|
||||
{
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kv_memory_layout =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT;
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
|
||||
int page_size = 1;
|
||||
};
|
||||
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits,
|
||||
fmha_batch_prefill_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
@@ -149,6 +149,28 @@ int override_num_splits_if_necessary(
|
||||
return num_splits;
|
||||
}
|
||||
|
||||
template <typename SMPLComputeDataType>
|
||||
void copy_attention_scores_with_sink(const ck_tile::HostTensor<SMPLComputeDataType>& s_host_ref,
|
||||
const ck_tile::HostTensor<SMPLComputeDataType>& sink_host,
|
||||
ck_tile::HostTensor<SMPLComputeDataType>& s_with_sinks_ref,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t real_seqlen_q,
|
||||
ck_tile::index_t real_seqlen_k)
|
||||
{
|
||||
for(auto i_h = 0; i_h < nhead; i_h++)
|
||||
{
|
||||
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
|
||||
{
|
||||
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
|
||||
{
|
||||
s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c);
|
||||
}
|
||||
// Append sink token at the end of each row
|
||||
s_with_sinks_ref(i_h, i_r, real_seqlen_k) = sink_host(i_h);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataTypeConfig>
|
||||
fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
@@ -184,6 +206,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
std::string init_method,
|
||||
uint32_t seed,
|
||||
int do_validation,
|
||||
int init_sink_value,
|
||||
const ck_tile::stream_config& stream_config,
|
||||
std::optional<std::string> json = std::nullopt)
|
||||
{
|
||||
@@ -527,6 +550,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<SMPLComputeDataType> sink_host({nhead});
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
0 < page_block_size
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
|
||||
@@ -609,6 +633,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, next_seed()}(
|
||||
bias_host);
|
||||
}
|
||||
|
||||
else if(init_method == "ni")
|
||||
{
|
||||
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
|
||||
@@ -695,10 +720,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
|
||||
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
|
||||
|
||||
if(init_sink_value != 0)
|
||||
{
|
||||
// sink is initialized to a fixed integer value for easy debugging and use 30 to 60 range
|
||||
// for close to rowmax values.
|
||||
ck_tile::FillUniformDistributionIntegerValue<SMPLComputeDataType>{30.f, 60.f, next_seed()}(
|
||||
sink_host);
|
||||
}
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sink_buf(sink_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
@@ -743,6 +775,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
v_buf.ToDevice(v_host.data());
|
||||
sink_buf.ToDevice(sink_host.data());
|
||||
knew_buf.ToDevice(knew_host.data());
|
||||
vnew_buf.ToDevice(vnew_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
@@ -971,7 +1004,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
|
||||
if(init_sink_value != 0)
|
||||
args.sink_ptr = sink_buf.GetDeviceBuffer();
|
||||
else
|
||||
args.sink_ptr = nullptr;
|
||||
args.batch = batch;
|
||||
args.seqlen_q = shape_seqlen_q; // unused in group mode
|
||||
args.hdim_q = hdim_q;
|
||||
@@ -1351,8 +1387,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t> && supports_qscale)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o_host});
|
||||
return ck_tile::make_composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o_host});
|
||||
else if constexpr(supports_qscale)
|
||||
return ck_tile::scales{scale_o_host};
|
||||
else
|
||||
@@ -1675,19 +1711,57 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
const ck_tile::HostTensor<SaccDataType> masked_s_host_ref = s_host_ref;
|
||||
if(lse)
|
||||
if(init_sink_value != 0)
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
|
||||
// Create extended tensor with sink token
|
||||
ck_tile::HostTensor<SMPLComputeDataType> s_with_sinks_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k + 1});
|
||||
|
||||
// Copy original attention scores and append sink values
|
||||
copy_attention_scores_with_sink(
|
||||
s_host_ref, sink_host, s_with_sinks_ref, nhead, real_seqlen_q, real_seqlen_k);
|
||||
|
||||
// Compute softmax on extended tensor
|
||||
ck_tile::HostTensor<PDataType> p_extended(
|
||||
{nhead, real_seqlen_q, real_seqlen_k + 1});
|
||||
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_with_sinks_ref, p_extended, p_compute_element_func, lse_host_ref);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_with_sinks_ref, p_extended, p_compute_element_func);
|
||||
}
|
||||
|
||||
// Extract only the original columns (exclude sink token column)
|
||||
p_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { self(idx) = p_extended(idx[0], idx[1], idx[2]); });
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
// No sink tokens - compute softmax directly
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func);
|
||||
}
|
||||
}
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
|
||||
|
||||
@@ -84,3 +84,10 @@ $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -l
|
||||
# 1 1 1 1 1 1 1 1 1 1
|
||||
# l=2/r=0(br) l=2/r=0/s=2(br)
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user