Padded Generic Kernel Instance (#730)

* Add NumReduceDim template parameter to DeviceSoftmax and Softmax client API to simplify instances collecting

* Move the generic kernel instance to be the first of the instance list for elementwise op of normalization

* Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax

* Add testing of GetGenericInstance() in client_example of Softmax

* Revert "Add testing of GetGenericInstance() in client_example of Softmax"

This reverts commit f629cd9a93.

* Revert "Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax"

This reverts commit a9f0d000eb.

* Support generic kernel instance to be the first instance returned by GetInstances() for GroupNorm

* Move generic kernel instance to separate tuple for elementwise op of normalization

* Remove un-used files for softmax instance

* Store generic kernel instance to separate tuple for softmax

* Add IsSupported checking for generic instance to client example of softmax

* Replace the get_device_normalize_from_mean_meansquare_instances() by the DeviceOperationInstanceFactory class for elementwise-normalization

* clang-format fix

* Remove int8 from softmax instances

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
Qianfeng
2023-06-17 12:43:11 +08:00
committed by GitHub
parent d140bdc9fa
commit 0d9118226b
76 changed files with 552 additions and 790 deletions

View File

@@ -172,18 +172,19 @@ int main()
BLayout,
CLayout>();
const auto normalize_ptrs =
ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances<
CDataType,
ReduceDataType,
ReduceDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType>();
std::cout << "found " << gemm_reduce_ptrs.size()
<< " gemm_reduceMean_reduceSquareMean instances" << std::endl;
using NormalizeDeviceOp = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<CDataType, ReduceDataType, ReduceDataType, GammaDataType, BetaDataType>,
ck::Tuple<LayerNormOutDataType>,
ck::tensor_operation::element_wise::Normalize,
2>;
const auto normalize_ptrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
NormalizeDeviceOp>::GetInstances();
std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl;
auto f_matrix_space_size =

View File

@@ -53,12 +53,35 @@ int main(int argc, char* argv[])
SimpleDeviceMem in(sizeof(InDataType) * num_elements);
SimpleDeviceMem out(sizeof(OutDataType) * num_elements);
using DeviceOp = ck::tensor_operation::device::
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>;
using DeviceOp = ck::tensor_operation::device::DeviceSoftmax<InDataType,
AccDataType,
OutDataType,
PassThrough,
PassThrough,
Rank,
NumReduceDim>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
auto& generic_op_ptr = op_ptrs[0];
auto generic_argument_ptr = generic_op_ptr->MakeArgumentPointer(in_lengths,
in_strides,
reduce_dims,
alpha,
beta,
in.GetDeviceBuffer(),
out.GetDeviceBuffer(),
PassThrough{},
PassThrough{});
if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))
{
throw std::runtime_error(
"The generic kernel instance should be able to support any input shapes");
};
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
@@ -74,11 +97,6 @@ int main(int argc, char* argv[])
{
auto& op_ptr = op_ptrs[i];
if(op_ptr->GetRank() != Rank || op_ptr->GetNumReduceDim() != NumReduceDim)
{
continue;
}
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
in_strides,
reduce_dims,

View File

@@ -72,6 +72,30 @@ int main(int argc, char* argv[])
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
const auto& generic_op_ptr = op_ptrs[0];
auto generic_argument_ptr =
generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
xy_strides, // xStrides
gamma_beta_strides, // gammaStrides
gamma_beta_strides, // betaStrides
xy_strides, // yStrides
{1, 2, 4}, // reduceDims
1e-6,
x_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
y_device_buf.GetDeviceBuffer(),
nullptr,
nullptr,
Swish{});
if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))
{
throw std::runtime_error(
"The generic kernel instance should be able to support any input shapes");
};
std::string best_op_name;
bool found = false;
int best_op_id = -1;