mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
add more datatype to gemm+gemm and conv+conv example (#397)
* refactor
* refactor
* adding int4/int8/fp16/bf16 for conv+conv and gemm+gemm
* adding int4/int8/fp16/bf16 for conv+conv and gemm+gemm
* clean
[ROCm/composable_kernel commit: 204ef976ca]
This commit is contained in:
@@ -43,30 +43,28 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
}
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
KernelCDataType
|
||||
#else
|
||||
CDataType
|
||||
#endif
|
||||
>
|
||||
c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
|
||||
c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
|
||||
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
|
||||
#else
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
#endif
|
||||
@@ -80,13 +78,13 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
reinterpret_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#else
|
||||
reinterpret_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
reinterpret_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#endif
|
||||
M,
|
||||
N,
|
||||
@@ -128,13 +126,17 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
const Tensor<CDataType> c_m_n_device_result_converted(c_m_n_device_result);
|
||||
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
|
||||
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
|
||||
|
||||
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
|
||||
|
||||
return ck::utils::check_err(c_m_n_device_result_converted.mData, c_m_n_host_result.mData);
|
||||
#else
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -34,16 +34,16 @@ template <ck::index_t NDimSpatial,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp,
|
||||
typename DeviceConvNDFwdInstance>
|
||||
int run_grouped_conv_fwd(bool do_verification,
|
||||
int init_method,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
const HostTensorDescriptor& in_g_n_c_wis_desc,
|
||||
const HostTensorDescriptor& wei_g_k_c_xs_desc,
|
||||
const HostTensorDescriptor& out_g_n_k_wos_desc,
|
||||
const InElementOp& in_element_op,
|
||||
const WeiElementOp& wei_element_op,
|
||||
const OutElementOp& out_element_op)
|
||||
bool run_grouped_conv_fwd(bool do_verification,
|
||||
int init_method,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
const HostTensorDescriptor& in_g_n_c_wis_desc,
|
||||
const HostTensorDescriptor& wei_g_k_c_xs_desc,
|
||||
const HostTensorDescriptor& out_g_n_k_wos_desc,
|
||||
const InElementOp& in_element_op,
|
||||
const WeiElementOp& wei_element_op,
|
||||
const OutElementOp& out_element_op)
|
||||
{
|
||||
Tensor<InDataType> in(in_g_n_c_wis_desc);
|
||||
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
|
||||
@@ -164,10 +164,8 @@ int run_grouped_conv_fwd(bool do_verification,
|
||||
out_device_buf.FromDevice(out_device.mData.data());
|
||||
|
||||
return ck::utils::check_err(
|
||||
out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f)
|
||||
? 0
|
||||
: 1;
|
||||
out_device.mData, out_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
}
|
||||
|
||||
return 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -74,154 +74,6 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
#include "run_convnd_fwd_example.inc"
|
||||
|
||||
print_helper_msg();
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::utils::conv::ConvParam conv_param{
|
||||
2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
|
||||
|
||||
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
|
||||
}
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{};
|
||||
|
||||
if(conv_param.num_dim_spatial_ == 1)
|
||||
{
|
||||
using InLayout = ctc::GNWC;
|
||||
using WeiLayout = ctc::GKXC;
|
||||
using OutLayout = ctc::GNWK;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
1,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<1, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 2)
|
||||
{
|
||||
using InLayout = ctc::GNHWC;
|
||||
using WeiLayout = ctc::GKYXC;
|
||||
using OutLayout = ctc::GNHWK;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
2,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<2, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 3)
|
||||
{
|
||||
using InLayout = ctc::GNDHWC;
|
||||
using WeiLayout = ctc::GKZYXC;
|
||||
using OutLayout = ctc::GNDHWK;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
3,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<3, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
|
||||
|
||||
@@ -74,95 +74,6 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
print_helper_msg();
|
||||
#include "run_convnd_fwd_example.inc"
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::utils::conv::ConvParam conv_param{
|
||||
2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
|
||||
|
||||
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
|
||||
}
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{};
|
||||
|
||||
const auto run = [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto out_layout) {
|
||||
constexpr ck::index_t ndim_spatial_value = ndim_spatial.value;
|
||||
|
||||
using InLayout = decltype(in_layout);
|
||||
using WeiLayout = decltype(wei_layout);
|
||||
using OutLayout = decltype(out_layout);
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
ndim_spatial_value,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<ndim_spatial_value, InLayout, WeiLayout, OutLayout>>(
|
||||
do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
};
|
||||
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
|
||||
if(conv_param.num_dim_spatial_ == 1)
|
||||
{
|
||||
run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GNWK{});
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 2)
|
||||
{
|
||||
run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GNHWK{});
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 3)
|
||||
{
|
||||
run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GNDHWK{});
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
|
||||
|
||||
@@ -74,154 +74,6 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
S<1, 16, 1, 16>,
|
||||
4>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
#include "run_convnd_fwd_example.inc"
|
||||
|
||||
print_helper_msg();
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::utils::conv::ConvParam conv_param{
|
||||
2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
|
||||
|
||||
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
|
||||
}
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{};
|
||||
|
||||
if(conv_param.num_dim_spatial_ == 1)
|
||||
{
|
||||
using InLayout = ctc::GNWC;
|
||||
using WeiLayout = ctc::GKXC;
|
||||
using OutLayout = ctc::GNWK;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
1,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<1, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 2)
|
||||
{
|
||||
using InLayout = ctc::GNHWC;
|
||||
using WeiLayout = ctc::GKYXC;
|
||||
using OutLayout = ctc::GNHWK;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
2,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<2, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 3)
|
||||
{
|
||||
using InLayout = ctc::GNDHWC;
|
||||
using WeiLayout = ctc::GKZYXC;
|
||||
using OutLayout = ctc::GNDHWK;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
3,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<3, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
|
||||
|
||||
@@ -74,154 +74,6 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
S<1, 16, 1, 16>,
|
||||
1>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
#include "run_convnd_fwd_example.inc"
|
||||
|
||||
print_helper_msg();
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::utils::conv::ConvParam conv_param{
|
||||
2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
|
||||
|
||||
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
|
||||
}
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{};
|
||||
|
||||
if(conv_param.num_dim_spatial_ == 1)
|
||||
{
|
||||
using InLayout = ctc::GNWC;
|
||||
using WeiLayout = ctc::GKXC;
|
||||
using OutLayout = ctc::GNWK;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
1,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<1, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 2)
|
||||
{
|
||||
using InLayout = ctc::GNHWC;
|
||||
using WeiLayout = ctc::GKYXC;
|
||||
using OutLayout = ctc::GNHWK;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
2,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<2, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 3)
|
||||
{
|
||||
using InLayout = ctc::GNDHWC;
|
||||
using WeiLayout = ctc::GKZYXC;
|
||||
using OutLayout = ctc::GNDHWK;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
3,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<3, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
|
||||
|
||||
@@ -74,154 +74,6 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
S<1, 64, 1, 4>,
|
||||
16>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
#include "run_convnd_fwd_example.inc"
|
||||
|
||||
print_helper_msg();
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::utils::conv::ConvParam conv_param{
|
||||
2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
|
||||
|
||||
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
|
||||
}
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{};
|
||||
|
||||
if(conv_param.num_dim_spatial_ == 1)
|
||||
{
|
||||
using InLayout = ctc::GNWC;
|
||||
using WeiLayout = ctc::GKXC;
|
||||
using OutLayout = ctc::GNWK;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
1,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<1, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 2)
|
||||
{
|
||||
using InLayout = ctc::GNHWC;
|
||||
using WeiLayout = ctc::GKYXC;
|
||||
using OutLayout = ctc::GNHWK;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
2,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<2, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 3)
|
||||
{
|
||||
using InLayout = ctc::GNDHWC;
|
||||
using WeiLayout = ctc::GKZYXC;
|
||||
using OutLayout = ctc::GNDHWK;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
3,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<3, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
|
||||
|
||||
97
example/09_convnd_fwd/run_convnd_fwd_example.inc
Normal file
97
example/09_convnd_fwd/run_convnd_fwd_example.inc
Normal file
@@ -0,0 +1,97 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
bool run_convnd_fwd_example(int argc, char* argv[])
|
||||
{
|
||||
print_helper_msg();
|
||||
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::utils::conv::ConvParam conv_param{
|
||||
2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
|
||||
|
||||
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
|
||||
}
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{};
|
||||
|
||||
const auto run = [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto out_layout) {
|
||||
constexpr ck::index_t ndim_spatial_value = ndim_spatial.value;
|
||||
|
||||
using InLayout = decltype(in_layout);
|
||||
using WeiLayout = decltype(wei_layout);
|
||||
using OutLayout = decltype(out_layout);
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
ndim_spatial_value,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<ndim_spatial_value, InLayout, WeiLayout, OutLayout>>(
|
||||
do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
};
|
||||
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
|
||||
if(conv_param.num_dim_spatial_ == 1)
|
||||
{
|
||||
return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GNWK{});
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 2)
|
||||
{
|
||||
return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GNHWK{});
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 3)
|
||||
{
|
||||
return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GNDHWK{});
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -1 +1,8 @@
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
135
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp
Normal file
135
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp
Normal file
@@ -0,0 +1,135 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
/*
|
||||
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
|
||||
|------------|
|
||||
Gemm0
|
||||
|---------------------|
|
||||
Gemm1
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = BF16;
|
||||
using B0DataType = BF16;
|
||||
using B1DataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CDataType = BF16;
|
||||
|
||||
using ALayout = Row;
|
||||
using B0Layout = Col;
|
||||
using B1Layout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = PassThrough;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
|
||||
ALayout,
|
||||
B0Layout,
|
||||
B1Layout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp,
|
||||
GemmDefault,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
32, // Gemm1KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
2, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<8, 32, 1>, // B1BlockTransfer
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
4,
|
||||
2,
|
||||
false,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B0DataType,
|
||||
ADataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
CElementOp>;
|
||||
|
||||
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_batched_gemm_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; }
|
||||
@@ -121,6 +121,7 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
CElementOp>;
|
||||
|
||||
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
@@ -129,244 +130,6 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
|
||||
B1ElementOp,
|
||||
CElementOp>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
#include "run_batched_gemm_gemm_example.inc"
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 64;
|
||||
ck::index_t O = 128;
|
||||
ck::index_t BatchCount = 4;
|
||||
ck::index_t StrideA = -1;
|
||||
ck::index_t StrideB0 = -1;
|
||||
ck::index_t StrideB1 = -1;
|
||||
ck::index_t StrideC = -1;
|
||||
ck::index_t BatchStrideA = -1;
|
||||
ck::index_t BatchStrideB0 = -1;
|
||||
ck::index_t BatchStrideB1 = -1;
|
||||
ck::index_t BatchStrideC = -1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 9)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
O = std::stoi(argv[7]);
|
||||
|
||||
BatchCount = std::stoi(argv[8]);
|
||||
}
|
||||
else if(argc == 17)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
O = std::stoi(argv[7]);
|
||||
|
||||
BatchCount = std::stoi(argv[8]);
|
||||
|
||||
StrideA = std::stoi(argv[9]);
|
||||
StrideB0 = std::stoi(argv[10]);
|
||||
StrideB1 = std::stoi(argv[11]);
|
||||
StrideC = std::stoi(argv[12]);
|
||||
|
||||
BatchStrideA = std::stoi(argv[13]);
|
||||
BatchStrideB0 = std::stoi(argv[14]);
|
||||
BatchStrideB1 = std::stoi(argv[15]);
|
||||
BatchStrideC = std::stoi(argv[16]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
|
||||
"BatchStrideB0, BatchStrideB1, BatchStrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
|
||||
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
|
||||
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? O : M;
|
||||
|
||||
StrideA = (StrideA < 0) ? DefaultStrideA : StrideA;
|
||||
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
|
||||
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
|
||||
StrideC = (StrideC < 0) ? DefaultStrideC : StrideC;
|
||||
|
||||
const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
|
||||
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
|
||||
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
|
||||
const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Col> ? O : M) * StrideC;
|
||||
|
||||
BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
|
||||
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
|
||||
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
|
||||
BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t batch_count,
|
||||
std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
std::size_t batch_stride,
|
||||
auto layout) {
|
||||
if(std::is_same<decltype(layout), Row>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
|
||||
std::vector<std::size_t>({batch_stride, stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
|
||||
std::vector<std::size_t>({batch_stride, 1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
// C_m_o = A_m_k * B0_k_n * B1_n_o
|
||||
Tensor<ADataType> a_g_m_k(
|
||||
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
|
||||
Tensor<B0DataType> b0_g_k_n(
|
||||
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
|
||||
Tensor<B1DataType> b1_g_n_o(
|
||||
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
|
||||
Tensor<CDataType> c_g_m_o_host_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
|
||||
Tensor<CDataType> c_g_m_o_device_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
|
||||
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
|
||||
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
|
||||
break;
|
||||
case 2:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_g_m_o_device_buf(sizeof(CDataType) *
|
||||
c_g_m_o_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
|
||||
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
|
||||
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b0_element_op = B0ElementOp{};
|
||||
auto acc0_element_op = Acc0ElementOp{};
|
||||
auto b1_element_op = B1ElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument =
|
||||
gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
BatchCount,
|
||||
StrideA,
|
||||
StrideB0,
|
||||
StrideB1,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideB1,
|
||||
BatchStrideC,
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc0_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
|
||||
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
|
||||
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
|
||||
BatchCount;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// Output of Gemm0 is input A of Gemm1
|
||||
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
|
||||
|
||||
auto ref_gemm0 = ReferenceGemm0Instance{};
|
||||
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
|
||||
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
|
||||
a_g_m_k, b0_g_k_n, a1_g_m_n, a_element_op, b0_element_op, PassThrough{});
|
||||
|
||||
ref_gemm0_invoker.Run(ref_gemm0_argument);
|
||||
|
||||
auto ref_gemm1 = ReferenceGemm1Instance{};
|
||||
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
|
||||
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
|
||||
a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
|
||||
|
||||
ref_gemm1_invoker.Run(ref_gemm1_argument);
|
||||
|
||||
return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; }
|
||||
|
||||
134
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp
Normal file
134
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp
Normal file
@@ -0,0 +1,134 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
/*
|
||||
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
|
||||
|------------|
|
||||
Gemm0
|
||||
|---------------------|
|
||||
Gemm1
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F32;
|
||||
using B0DataType = F32;
|
||||
using B1DataType = F32;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CDataType = F32;
|
||||
|
||||
using ALayout = Row;
|
||||
using B0Layout = Col;
|
||||
using B1Layout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = PassThrough;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
|
||||
ALayout,
|
||||
B0Layout,
|
||||
B1Layout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp,
|
||||
GemmDefault,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
16, // Gemm1KPerBlock
|
||||
4, // AK1
|
||||
4, // BK1
|
||||
1, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
true,
|
||||
S<8, 32, 1>, // B1BlockTransfer
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
4,
|
||||
1,
|
||||
false,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 16, 1, 16>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B0DataType,
|
||||
ADataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
CElementOp>;
|
||||
|
||||
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_batched_gemm_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; }
|
||||
145
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp
Normal file
145
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp
Normal file
@@ -0,0 +1,145 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
/*
|
||||
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
|
||||
|------------|
|
||||
Gemm0
|
||||
|---------------------|
|
||||
Gemm1
|
||||
*/
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
#error Should compile this file with ck::int4_t support
|
||||
#endif
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = ck::int4_t;
|
||||
using B0DataType = ck::int4_t;
|
||||
using B1DataType = ck::int4_t;
|
||||
using KernelADataType = int8_t;
|
||||
using KernelB0DataType = int8_t;
|
||||
using KernelB1DataType = int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CShuffleDataType = int32_t;
|
||||
using CDataType = ck::int4_t;
|
||||
using KernelCDataType = int8_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using B0Layout = Col;
|
||||
using B1Layout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = PassThrough;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
|
||||
ALayout,
|
||||
B0Layout,
|
||||
B1Layout,
|
||||
CLayout,
|
||||
KernelADataType,
|
||||
KernelB0DataType,
|
||||
KernelB1DataType,
|
||||
KernelCDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp,
|
||||
GemmDefault,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
64, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
64, // Gemm1KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
4, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
16,
|
||||
16,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
16,
|
||||
16,
|
||||
true,
|
||||
S<8, 32, 1>, // B1BlockTransfer
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
4,
|
||||
4,
|
||||
false,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B0DataType,
|
||||
ADataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
CElementOp>;
|
||||
|
||||
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#define BUILD_INT4_EXAMPLE
|
||||
#include "run_batched_gemm_gemm_example.inc"
|
||||
|
||||
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
|
||||
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
|
||||
#endif
|
||||
|
||||
int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; }
|
||||
132
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp
Normal file
132
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp
Normal file
@@ -0,0 +1,132 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
/*
|
||||
Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
|
||||
|------------|
|
||||
Gemm0
|
||||
|---------------------|
|
||||
Gemm1
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = int8_t;
|
||||
using B0DataType = int8_t;
|
||||
using B1DataType = int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CShuffleDataType = int32_t;
|
||||
using CDataType = int8_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using B0Layout = Col;
|
||||
using B1Layout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = PassThrough;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
|
||||
ALayout,
|
||||
B0Layout,
|
||||
B1Layout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp,
|
||||
GemmDefault,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
64, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
64, // Gemm1KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
4, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
16,
|
||||
16,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
16,
|
||||
16,
|
||||
true,
|
||||
S<8, 32, 1>, // B1BlockTransfer
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
4,
|
||||
4,
|
||||
false,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B0DataType,
|
||||
ADataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
CElementOp>;
|
||||
|
||||
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_batched_gemm_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; }
|
||||
277
example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc
Normal file
277
example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc
Normal file
@@ -0,0 +1,277 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
bool run_batched_gemm_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 64;
|
||||
ck::index_t O = 128;
|
||||
ck::index_t BatchCount = 4;
|
||||
ck::index_t StrideA = -1;
|
||||
ck::index_t StrideB0 = -1;
|
||||
ck::index_t StrideB1 = -1;
|
||||
ck::index_t StrideC = -1;
|
||||
ck::index_t BatchStrideA = -1;
|
||||
ck::index_t BatchStrideB0 = -1;
|
||||
ck::index_t BatchStrideB1 = -1;
|
||||
ck::index_t BatchStrideC = -1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 9)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
O = std::stoi(argv[7]);
|
||||
|
||||
BatchCount = std::stoi(argv[8]);
|
||||
}
|
||||
else if(argc == 17)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
O = std::stoi(argv[7]);
|
||||
|
||||
BatchCount = std::stoi(argv[8]);
|
||||
|
||||
StrideA = std::stoi(argv[9]);
|
||||
StrideB0 = std::stoi(argv[10]);
|
||||
StrideB1 = std::stoi(argv[11]);
|
||||
StrideC = std::stoi(argv[12]);
|
||||
|
||||
BatchStrideA = std::stoi(argv[13]);
|
||||
BatchStrideB0 = std::stoi(argv[14]);
|
||||
BatchStrideB1 = std::stoi(argv[15]);
|
||||
BatchStrideC = std::stoi(argv[16]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
|
||||
"BatchStrideB0, BatchStrideB1, BatchStrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
|
||||
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
|
||||
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
|
||||
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? O : M;
|
||||
|
||||
StrideA = (StrideA < 0) ? DefaultStrideA : StrideA;
|
||||
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
|
||||
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
|
||||
StrideC = (StrideC < 0) ? DefaultStrideC : StrideC;
|
||||
|
||||
const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
|
||||
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
|
||||
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
|
||||
const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Col> ? O : M) * StrideC;
|
||||
|
||||
BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA;
|
||||
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
|
||||
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
|
||||
BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC;
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t batch_count,
|
||||
std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
std::size_t batch_stride,
|
||||
auto layout) {
|
||||
if(std::is_same<decltype(layout), Row>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
|
||||
std::vector<std::size_t>({batch_stride, stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
|
||||
std::vector<std::size_t>({batch_stride, 1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
// C_m_o = A_m_k * B0_k_n * B1_n_o
|
||||
Tensor<ADataType> a_g_m_k(
|
||||
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
|
||||
Tensor<B0DataType> b0_g_k_n(
|
||||
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
|
||||
Tensor<B1DataType> b1_g_n_o(
|
||||
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
|
||||
Tensor<CDataType> c_g_m_o_host_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
|
||||
Tensor<CDataType> c_g_m_o_device_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
|
||||
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
|
||||
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
|
||||
break;
|
||||
case 2:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
DeviceMem a_g_m_k_device_buf(sizeof(KernelADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_g_k_n_device_buf(sizeof(KernelB0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_g_n_o_device_buf(sizeof(KernelB1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_g_m_o_device_buf(sizeof(KernelCDataType) *
|
||||
c_g_m_o_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
const Tensor<KernelADataType> a_g_m_k_converted(a_g_m_k);
|
||||
const Tensor<KernelB0DataType> b0_g_k_n_converted(b0_g_k_n);
|
||||
const Tensor<KernelB1DataType> b1_g_n_o_converted(b1_g_n_o);
|
||||
|
||||
a_g_m_k_device_buf.ToDevice(a_g_m_k_converted.mData.data());
|
||||
b0_g_k_n_device_buf.ToDevice(b0_g_k_n_converted.mData.data());
|
||||
b1_g_n_o_device_buf.ToDevice(b1_g_n_o_converted.mData.data());
|
||||
#else
|
||||
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_g_m_o_device_buf(sizeof(CDataType) *
|
||||
c_g_m_o_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data());
|
||||
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
|
||||
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
|
||||
#endif
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b0_element_op = B0ElementOp{};
|
||||
auto acc0_element_op = Acc0ElementOp{};
|
||||
auto b1_element_op = B1ElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
static_cast<KernelADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelB0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelB1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelCDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
|
||||
#else
|
||||
static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()),
|
||||
#endif
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
BatchCount,
|
||||
StrideA,
|
||||
StrideB0,
|
||||
StrideB1,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideB1,
|
||||
BatchStrideC,
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc0_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
|
||||
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
|
||||
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
|
||||
BatchCount;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// Output of Gemm0 is input A of Gemm1
|
||||
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
|
||||
|
||||
auto ref_gemm0 = ReferenceGemm0Instance{};
|
||||
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
|
||||
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
|
||||
a_g_m_k, b0_g_k_n, a1_g_m_n, a_element_op, b0_element_op, PassThrough{});
|
||||
|
||||
ref_gemm0_invoker.Run(ref_gemm0_argument);
|
||||
|
||||
auto ref_gemm1 = ReferenceGemm1Instance{};
|
||||
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
|
||||
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
|
||||
a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
|
||||
|
||||
ref_gemm1_invoker.Run(ref_gemm1_argument);
|
||||
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
Tensor<KernelCDataType> c_g_m_o_device_result_converted(c_g_m_o_host_result.mDesc);
|
||||
|
||||
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result_converted.mData.data());
|
||||
|
||||
c_g_m_o_device_result = c_g_m_o_device_result_converted.CopyAsType<CDataType>();
|
||||
#else
|
||||
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
|
||||
#endif
|
||||
|
||||
return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -1 +1,8 @@
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
|
||||
using In0DataType = ck::bhalf_t;
|
||||
using Wei0DataType = ck::bhalf_t;
|
||||
using Acc0DataType = float;
|
||||
using Wei1DataType = ck::bhalf_t;
|
||||
using Acc1DataType = float;
|
||||
using C1ShuffleDataType = float;
|
||||
using Out1DataType = ck::bhalf_t;
|
||||
|
||||
// This is used for reference code
|
||||
using Out0DataType = ck::bhalf_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using In0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Wei0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Wei1ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Out0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Out1ElementOp = ck::tensor_operation::element_wise::UnaryConvert;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceBatchedGemmGemmInstance =
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
|
||||
Row, // ALayout
|
||||
Col, // B0Layout
|
||||
Col, // B1Layout
|
||||
Row, // CLayout
|
||||
In0DataType, // ADataType,
|
||||
Wei0DataType, // B0DataType,
|
||||
Wei1DataType, // B1DataType,
|
||||
Out1DataType, // CDataType,
|
||||
Acc0DataType, // AccDataType,
|
||||
C1ShuffleDataType, // CShuffleDataType,
|
||||
In0ElementOp, // AElementOp,
|
||||
Wei0ElementOp, // B0ElementOp,
|
||||
Out0ElementOp, // Acc0ElementOp,
|
||||
Wei1ElementOp, // B1ElementOp,
|
||||
Out1ElementOp, // CElementOp,
|
||||
GemmDefault,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
32, // Gemm1KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
4, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>, // B1BlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
true,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
#include "run_grouped_conv_conv_fwd_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; }
|
||||
@@ -1,11 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "grouped_conv_conv_fwd_common.hpp"
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
|
||||
using In0DataType = ck::half_t;
|
||||
using Wei0DataType = ck::half_t;
|
||||
@@ -15,6 +27,9 @@ using Acc1DataType = float;
|
||||
using C1ShuffleDataType = float;
|
||||
using Out1DataType = ck::half_t;
|
||||
|
||||
// This is used for reference code
|
||||
using Out0DataType = ck::half_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
@@ -88,117 +103,6 @@ using DeviceBatchedGemmGemmInstance =
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
#include "run_grouped_conv_conv_fwd_example.inc"
|
||||
|
||||
ck::utils::conv::ConvParam conv0_param{
|
||||
2, 1, 128, 512, 128, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
|
||||
ck::utils::conv::ConvParam conv1_param{
|
||||
2, 1, 128, 128, 512, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
const auto in0_element_op = In0ElementOp{};
|
||||
const auto wei0_element_op = Wei0ElementOp{};
|
||||
const auto wei1_element_op = Wei1ElementOp{};
|
||||
const auto out0_element_op = Out0ElementOp{};
|
||||
const auto out1_element_op = Out1ElementOp{};
|
||||
|
||||
const auto run = [&](auto ndim_spatial,
|
||||
auto in0_layout,
|
||||
auto wei0_layout,
|
||||
auto wei1_layout,
|
||||
auto out1_layout) {
|
||||
constexpr ck::index_t ndim_spatial_value = ndim_spatial.value;
|
||||
|
||||
using In0Layout = decltype(in0_layout);
|
||||
using Wei0Layout = decltype(wei0_layout);
|
||||
using Wei1Layout = decltype(wei1_layout);
|
||||
using Out1Layout = decltype(out1_layout);
|
||||
|
||||
const auto in0_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<In0Layout>(
|
||||
conv0_param);
|
||||
|
||||
const auto wei0_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<Wei0Layout>(
|
||||
conv0_param);
|
||||
|
||||
// out0 doesn't physical exist, any layout for host verification is OK
|
||||
const auto out0_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<Out1Layout>(
|
||||
conv0_param);
|
||||
|
||||
const auto wei1_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<Wei1Layout>(
|
||||
conv1_param);
|
||||
|
||||
const auto out1_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<Out1Layout>(
|
||||
conv1_param);
|
||||
|
||||
return run_grouped_conv_conv_fwd<ndim_spatial_value,
|
||||
In0DataType,
|
||||
Wei0DataType,
|
||||
Acc0DataType,
|
||||
Wei1DataType,
|
||||
Out1DataType,
|
||||
In0ElementOp,
|
||||
Wei0ElementOp,
|
||||
Out0ElementOp,
|
||||
Wei1ElementOp,
|
||||
Out1ElementOp,
|
||||
DeviceBatchedGemmGemmInstance>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv0_param,
|
||||
conv1_param,
|
||||
in0_g_n_c_wis_desc,
|
||||
wei0_g_k_c_xs_desc,
|
||||
out0_g_n_k_wos_desc,
|
||||
wei1_g_k_c_xs_desc,
|
||||
out1_g_n_k_wos_desc,
|
||||
in0_element_op,
|
||||
wei0_element_op,
|
||||
wei1_element_op,
|
||||
out0_element_op,
|
||||
out1_element_op);
|
||||
};
|
||||
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
|
||||
if(conv0_param.num_dim_spatial_ == 1)
|
||||
{
|
||||
run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GKXC{}, ctc::GNWK{});
|
||||
}
|
||||
else if(conv0_param.num_dim_spatial_ == 2)
|
||||
{
|
||||
run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GKYXC{}, ctc::GNHWK{});
|
||||
}
|
||||
else if(conv0_param.num_dim_spatial_ == 3)
|
||||
{
|
||||
run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GKZYXC{}, ctc::GNDHWK{});
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; }
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
|
||||
using In0DataType = float;
|
||||
using Wei0DataType = float;
|
||||
using Acc0DataType = float;
|
||||
using Wei1DataType = float;
|
||||
using Acc1DataType = float;
|
||||
using C1ShuffleDataType = float;
|
||||
using Out1DataType = float;
|
||||
|
||||
// This is used for reference code
|
||||
using Out0DataType = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using In0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Wei0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Wei1ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Out0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Out1ElementOp = ck::tensor_operation::element_wise::UnaryConvert;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceBatchedGemmGemmInstance =
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
|
||||
Row, // ALayout
|
||||
Col, // B0Layout
|
||||
Col, // B1Layout
|
||||
Row, // CLayout
|
||||
In0DataType, // ADataType,
|
||||
Wei0DataType, // B0DataType,
|
||||
Wei1DataType, // B1DataType,
|
||||
Out1DataType, // CDataType,
|
||||
Acc0DataType, // AccDataType,
|
||||
C1ShuffleDataType, // CShuffleDataType,
|
||||
In0ElementOp, // AElementOp,
|
||||
Wei0ElementOp, // B0ElementOp,
|
||||
Out0ElementOp, // Acc0ElementOp,
|
||||
Wei1ElementOp, // B1ElementOp,
|
||||
Out1ElementOp, // CElementOp,
|
||||
GemmDefault,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
16, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
16, // Gemm1KPerBlock
|
||||
4, // AK1
|
||||
4, // BK1
|
||||
2, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
true,
|
||||
S<4, 64, 1>, // B1BlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
true,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 16, 1, 16>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
#include "run_grouped_conv_conv_fwd_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; }
|
||||
@@ -0,0 +1,121 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
#error Should compile this file with ck::int4_t support
|
||||
#endif
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
|
||||
using In0DataType = ck::int4_t;
|
||||
using Wei0DataType = ck::int4_t;
|
||||
using KernelIn0DataType = int8_t;
|
||||
using KernelWei0DataType = int8_t;
|
||||
using Acc0DataType = int32_t;
|
||||
using Wei1DataType = ck::int4_t;
|
||||
using KernelWei1DataType = int8_t;
|
||||
using Acc1DataType = int32_t;
|
||||
using C1ShuffleDataType = int32_t;
|
||||
using Out1DataType = ck::int4_t;
|
||||
using KernelOut1DataType = int8_t;
|
||||
|
||||
// This is used for reference code
|
||||
using Out0DataType = ck::int4_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using In0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Wei0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Wei1ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Out0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Out1ElementOp = ck::tensor_operation::element_wise::UnaryConvert;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceBatchedGemmGemmInstance =
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
|
||||
Row, // ALayout
|
||||
Col, // B0Layout
|
||||
Col, // B1Layout
|
||||
Row, // CLayout
|
||||
KernelIn0DataType, // ADataType,
|
||||
KernelWei0DataType, // B0DataType,
|
||||
KernelWei1DataType, // B1DataType,
|
||||
KernelOut1DataType, // CDataType,
|
||||
Acc0DataType, // AccDataType,
|
||||
C1ShuffleDataType, // CShuffleDataType,
|
||||
In0ElementOp, // AElementOp,
|
||||
Wei0ElementOp, // B0ElementOp,
|
||||
Out0ElementOp, // Acc0ElementOp,
|
||||
Wei1ElementOp, // B1ElementOp,
|
||||
Out1ElementOp, // CElementOp,
|
||||
GemmDefault,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
64, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
64, // Gemm1KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
4, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
16,
|
||||
16,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
16,
|
||||
16,
|
||||
true,
|
||||
S<4, 64, 1>, // B1BlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
true,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
#define BUILD_INT4_EXAMPLE
|
||||
#include "run_grouped_conv_conv_fwd_example.inc"
|
||||
|
||||
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
|
||||
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
|
||||
#endif
|
||||
|
||||
int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; }
|
||||
@@ -0,0 +1,108 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
|
||||
using In0DataType = int8_t;
|
||||
using Wei0DataType = int8_t;
|
||||
using Acc0DataType = int32_t;
|
||||
using Wei1DataType = int8_t;
|
||||
using Acc1DataType = int32_t;
|
||||
using C1ShuffleDataType = int32_t;
|
||||
using Out1DataType = int8_t;
|
||||
|
||||
// This is used for reference code
|
||||
using Out0DataType = int8_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using In0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Wei0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Wei1ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Out0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Out1ElementOp = ck::tensor_operation::element_wise::UnaryConvert;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceBatchedGemmGemmInstance =
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
|
||||
Row, // ALayout
|
||||
Col, // B0Layout
|
||||
Col, // B1Layout
|
||||
Row, // CLayout
|
||||
In0DataType, // ADataType,
|
||||
Wei0DataType, // B0DataType,
|
||||
Wei1DataType, // B1DataType,
|
||||
Out1DataType, // CDataType,
|
||||
Acc0DataType, // AccDataType,
|
||||
C1ShuffleDataType, // CShuffleDataType,
|
||||
In0ElementOp, // AElementOp,
|
||||
Wei0ElementOp, // B0ElementOp,
|
||||
Out0ElementOp, // Acc0ElementOp,
|
||||
Wei1ElementOp, // B1ElementOp,
|
||||
Out1ElementOp, // CElementOp,
|
||||
GemmDefault,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
64, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
64, // Gemm1KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
4, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
16,
|
||||
16,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
16,
|
||||
16,
|
||||
true,
|
||||
S<4, 64, 1>, // B1BlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
true,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
#include "run_grouped_conv_conv_fwd_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; }
|
||||
@@ -1,27 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#pragma once
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename In0DataType,
|
||||
typename Wei0DataType,
|
||||
typename Acc0DataType,
|
||||
typename Out0DataType,
|
||||
typename Wei1DataType,
|
||||
typename Out1DataType,
|
||||
typename In0ElementOp,
|
||||
@@ -30,21 +15,21 @@ template <ck::index_t NDimSpatial,
|
||||
typename Wei1ElementOp,
|
||||
typename Out1ElementOp,
|
||||
typename DeviceOpInstance>
|
||||
int run_grouped_conv_conv_fwd(bool do_verification,
|
||||
int init_method,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv0_param,
|
||||
const ck::utils::conv::ConvParam& conv1_param,
|
||||
const HostTensorDescriptor& in0_g_n_c_wis_desc,
|
||||
const HostTensorDescriptor& wei0_g_k_c_xs_desc,
|
||||
const HostTensorDescriptor& out0_g_n_k_wos_desc,
|
||||
const HostTensorDescriptor& wei1_g_k_c_xs_desc,
|
||||
const HostTensorDescriptor& out1_g_n_k_wos_desc,
|
||||
const In0ElementOp& in0_element_op,
|
||||
const Wei0ElementOp& wei0_element_op,
|
||||
const Wei1ElementOp& wei1_element_op,
|
||||
const Out0ElementOp& out0_element_op,
|
||||
const Out1ElementOp& out1_element_op)
|
||||
bool run_grouped_conv_conv_fwd(bool do_verification,
|
||||
int init_method,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv0_param,
|
||||
const ck::utils::conv::ConvParam& conv1_param,
|
||||
const HostTensorDescriptor& in0_g_n_c_wis_desc,
|
||||
const HostTensorDescriptor& wei0_g_k_c_xs_desc,
|
||||
const HostTensorDescriptor& out0_g_n_k_wos_desc,
|
||||
const HostTensorDescriptor& wei1_g_k_c_xs_desc,
|
||||
const HostTensorDescriptor& out1_g_n_k_wos_desc,
|
||||
const In0ElementOp& in0_element_op,
|
||||
const Wei0ElementOp& wei0_element_op,
|
||||
const Wei1ElementOp& wei1_element_op,
|
||||
const Out0ElementOp& out0_element_op,
|
||||
const Out1ElementOp& out1_element_op)
|
||||
{
|
||||
Tensor<In0DataType> in0(in0_g_n_c_wis_desc);
|
||||
Tensor<Wei0DataType> wei0(wei0_g_k_c_xs_desc);
|
||||
@@ -71,6 +56,20 @@ int run_grouped_conv_conv_fwd(bool do_verification,
|
||||
wei1.GenerateTensorValue(GeneratorTensor_3<Wei1DataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
DeviceMem in0_device_buf(sizeof(KernelIn0DataType) * in0.mDesc.GetElementSpaceSize());
|
||||
DeviceMem wei0_device_buf(sizeof(KernelWei0DataType) * wei0.mDesc.GetElementSpaceSize());
|
||||
DeviceMem wei1_device_buf(sizeof(KernelWei1DataType) * wei1.mDesc.GetElementSpaceSize());
|
||||
DeviceMem out1_device_buf(sizeof(KernelOut1DataType) * out1_device.mDesc.GetElementSpaceSize());
|
||||
|
||||
const Tensor<KernelIn0DataType> in0_converted(in0);
|
||||
const Tensor<KernelWei0DataType> wei0_converted(wei0);
|
||||
const Tensor<KernelWei1DataType> wei1_converted(wei1);
|
||||
|
||||
in0_device_buf.ToDevice(in0_converted.mData.data());
|
||||
wei0_device_buf.ToDevice(wei0_converted.mData.data());
|
||||
wei1_device_buf.ToDevice(wei1_converted.mData.data());
|
||||
#else
|
||||
DeviceMem in0_device_buf(sizeof(In0DataType) * in0.mDesc.GetElementSpaceSize());
|
||||
DeviceMem wei0_device_buf(sizeof(Wei0DataType) * wei0.mDesc.GetElementSpaceSize());
|
||||
DeviceMem wei1_device_buf(sizeof(Wei1DataType) * wei1.mDesc.GetElementSpaceSize());
|
||||
@@ -79,6 +78,7 @@ int run_grouped_conv_conv_fwd(bool do_verification,
|
||||
in0_device_buf.ToDevice(in0.mData.data());
|
||||
wei0_device_buf.ToDevice(wei0.mData.data());
|
||||
wei1_device_buf.ToDevice(wei1.mData.data());
|
||||
#endif
|
||||
|
||||
std::array<ck::index_t, NDimSpatial + 3> a0_g_n_c_wis_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> a0_g_n_c_wis_strides{};
|
||||
@@ -116,7 +116,6 @@ int run_grouped_conv_conv_fwd(bool do_verification,
|
||||
copy(conv1_param.input_left_pads_, input1_left_pads);
|
||||
copy(conv1_param.input_right_pads_, input1_right_pads);
|
||||
|
||||
#if 1
|
||||
// do Conv using GEMM, only works for 1x1 conv for now
|
||||
const ck::index_t gemm_batch = a0_g_n_c_wis_lengths[0];
|
||||
|
||||
@@ -150,29 +149,36 @@ int run_grouped_conv_conv_fwd(bool do_verification,
|
||||
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(static_cast<In0DataType*>(in0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<Wei0DataType*>(wei0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<Wei1DataType*>(wei1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<Out1DataType*>(out1_device_buf.GetDeviceBuffer()),
|
||||
gemm0_m_length,
|
||||
gemm0_n_length,
|
||||
gemm0_k_length,
|
||||
gemm1_n_length,
|
||||
gemm_batch,
|
||||
a0_stride,
|
||||
b0_stride,
|
||||
b1_stride,
|
||||
e1_stride,
|
||||
a0_batch_stride,
|
||||
b0_batch_stride,
|
||||
b1_batch_stride,
|
||||
e1_batch_stride,
|
||||
in0_element_op,
|
||||
wei0_element_op,
|
||||
out0_element_op,
|
||||
wei1_element_op,
|
||||
out1_element_op);
|
||||
auto argument = device_op.MakeArgument(
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
static_cast<KernelIn0DataType*>(in0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelWei0DataType*>(wei0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelWei1DataType*>(wei1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelOut1DataType*>(out1_device_buf.GetDeviceBuffer()),
|
||||
#else
|
||||
static_cast<In0DataType*>(in0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<Wei0DataType*>(wei0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<Wei1DataType*>(wei1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<Out1DataType*>(out1_device_buf.GetDeviceBuffer()),
|
||||
#endif
|
||||
gemm0_m_length,
|
||||
gemm0_n_length,
|
||||
gemm0_k_length,
|
||||
gemm1_n_length,
|
||||
gemm_batch,
|
||||
a0_stride,
|
||||
b0_stride,
|
||||
b1_stride,
|
||||
e1_stride,
|
||||
a0_batch_stride,
|
||||
b0_batch_stride,
|
||||
b1_batch_stride,
|
||||
e1_batch_stride,
|
||||
in0_element_op,
|
||||
wei0_element_op,
|
||||
out0_element_op,
|
||||
wei1_element_op,
|
||||
out1_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
@@ -193,24 +199,23 @@ int run_grouped_conv_conv_fwd(bool do_verification,
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< device_op.GetTypeString() << std::endl;
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
Tensor<Acc0DataType> out0_host(out0_g_n_k_wos_desc);
|
||||
Tensor<Out0DataType> out0_host(out0_g_n_k_wos_desc);
|
||||
|
||||
auto ref_conv0 = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
In0DataType,
|
||||
Wei0DataType,
|
||||
Acc0DataType,
|
||||
Out0DataType,
|
||||
In0ElementOp,
|
||||
Wei0ElementOp,
|
||||
Out0ElementOp>();
|
||||
|
||||
auto ref_conv1 = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
Acc0DataType,
|
||||
Out0DataType,
|
||||
Wei1DataType,
|
||||
Out1DataType,
|
||||
PassThrough,
|
||||
@@ -245,13 +250,134 @@ int run_grouped_conv_conv_fwd(bool do_verification,
|
||||
ref_conv0_invoker.Run(ref_conv0_argument);
|
||||
ref_conv1_invoker.Run(ref_conv1_argument);
|
||||
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
Tensor<KernelOut1DataType> out1_device_converted(out1_host.mDesc);
|
||||
|
||||
out1_device_buf.FromDevice(out1_device_converted.mData.data());
|
||||
|
||||
out1_device = out1_device_converted.CopyAsType<Out1DataType>();
|
||||
#else
|
||||
out1_device_buf.FromDevice(out1_device.mData.data());
|
||||
#endif
|
||||
|
||||
return ck::utils::check_err(
|
||||
out1_device.mData, out1_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f)
|
||||
? 0
|
||||
: 1;
|
||||
out1_device.mData, out1_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
|
||||
}
|
||||
|
||||
return 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool run_grouped_conv_conv_fwd_example(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::utils::conv::ConvParam conv0_param{
|
||||
2, 1, 128, 512, 128, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
|
||||
ck::utils::conv::ConvParam conv1_param{
|
||||
2, 1, 128, 128, 512, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
const auto in0_element_op = In0ElementOp{};
|
||||
const auto wei0_element_op = Wei0ElementOp{};
|
||||
const auto wei1_element_op = Wei1ElementOp{};
|
||||
const auto out0_element_op = Out0ElementOp{};
|
||||
const auto out1_element_op = Out1ElementOp{};
|
||||
|
||||
const auto run = [&](auto ndim_spatial,
|
||||
auto in0_layout,
|
||||
auto wei0_layout,
|
||||
auto wei1_layout,
|
||||
auto out1_layout) {
|
||||
constexpr ck::index_t ndim_spatial_value = ndim_spatial.value;
|
||||
|
||||
using In0Layout = decltype(in0_layout);
|
||||
using Wei0Layout = decltype(wei0_layout);
|
||||
using Wei1Layout = decltype(wei1_layout);
|
||||
using Out1Layout = decltype(out1_layout);
|
||||
|
||||
const auto in0_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<In0Layout>(
|
||||
conv0_param);
|
||||
|
||||
const auto wei0_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<Wei0Layout>(
|
||||
conv0_param);
|
||||
|
||||
// out0 doesn't physical exist, any layout for host verification is OK
|
||||
const auto out0_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<Out1Layout>(
|
||||
conv0_param);
|
||||
|
||||
const auto wei1_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<Wei1Layout>(
|
||||
conv1_param);
|
||||
|
||||
const auto out1_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<Out1Layout>(
|
||||
conv1_param);
|
||||
|
||||
return run_grouped_conv_conv_fwd<ndim_spatial_value,
|
||||
In0DataType,
|
||||
Wei0DataType,
|
||||
Out0DataType,
|
||||
Wei1DataType,
|
||||
Out1DataType,
|
||||
In0ElementOp,
|
||||
Wei0ElementOp,
|
||||
Out0ElementOp,
|
||||
Wei1ElementOp,
|
||||
Out1ElementOp,
|
||||
DeviceBatchedGemmGemmInstance>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv0_param,
|
||||
conv1_param,
|
||||
in0_g_n_c_wis_desc,
|
||||
wei0_g_k_c_xs_desc,
|
||||
out0_g_n_k_wos_desc,
|
||||
wei1_g_k_c_xs_desc,
|
||||
out1_g_n_k_wos_desc,
|
||||
in0_element_op,
|
||||
wei0_element_op,
|
||||
wei1_element_op,
|
||||
out0_element_op,
|
||||
out1_element_op);
|
||||
};
|
||||
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
|
||||
if(conv0_param.num_dim_spatial_ == 1)
|
||||
{
|
||||
return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GKXC{}, ctc::GNWK{});
|
||||
}
|
||||
else if(conv0_param.num_dim_spatial_ == 2)
|
||||
{
|
||||
return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GKYXC{}, ctc::GNHWK{});
|
||||
}
|
||||
else if(conv0_param.num_dim_spatial_ == 3)
|
||||
{
|
||||
return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GKZYXC{}, ctc::GNDHWK{});
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
Reference in New Issue
Block a user