Improve external interface for GEMM and GEMM+add+add+fastgelu (#311)

* interface for GEMM and GEMM+add+add+fastgelu

* rename namespace (e.g. `device_normalization_instance` -> `instance`)

* instance factory

* fix build

* fix build; add GEMM client example

* clean up

[ROCm/composable_kernel commit: 0dcb3496cf]
This commit is contained in:
Chao Liu
2022-06-30 22:11:00 -05:00
committed by GitHub
parent 7094d7c910
commit 74b6e85eaf
259 changed files with 2915 additions and 2969 deletions

View File

@@ -18,7 +18,7 @@
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_normalization_instance {
namespace instance {
void add_device_softmax_f16_f16_rank3_instances(std::vector<DeviceNormalizationPtr>&);
void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationPtr>&);
@@ -26,7 +26,7 @@ void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationP
void add_device_softmax_f32_f32_rank3_instances(std::vector<DeviceNormalizationPtr>&);
void add_device_softmax_f32_f32_rank4_instances(std::vector<DeviceNormalizationPtr>&);
} // namespace device_normalization_instance
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -109,23 +109,23 @@ void profile_normalization_impl(int do_verification,
is_same<AccDataType, float>::value)
{
if(in_length.size() == 3)
tensor_operation::device::device_normalization_instance::
add_device_softmax_f16_f16_rank3_instances(instances);
tensor_operation::device::instance::add_device_softmax_f16_f16_rank3_instances(
instances);
if(in_length.size() == 4)
tensor_operation::device::device_normalization_instance::
add_device_softmax_f16_f16_rank4_instances(instances);
tensor_operation::device::instance::add_device_softmax_f16_f16_rank4_instances(
instances);
}
else if constexpr(is_same<InDataType, float>::value && is_same<OutDataType, float>::value &&
is_same<AccDataType, float>::value)
{
if(in_length.size() == 3)
tensor_operation::device::device_normalization_instance::
add_device_softmax_f32_f32_rank3_instances(instances);
tensor_operation::device::instance::add_device_softmax_f32_f32_rank3_instances(
instances);
if(in_length.size() == 4)
tensor_operation::device::device_normalization_instance::
add_device_softmax_f32_f32_rank4_instances(instances);
tensor_operation::device::instance::add_device_softmax_f32_f32_rank4_instances(
instances);
}
}