mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
Improve external interface for GEMM and GEMM+add+add+fastgelu (#311)
* interface for GEMM and GEMM+add+add+fastgelu
* rename namespace
* instance factory
* fix build
* fix build; add GEMM client example
* clean
[ROCm/composable_kernel commit: 0dcb3496cf]
This commit is contained in:
@@ -18,7 +18,7 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
-namespace device_normalization_instance {
|
||||
+namespace instance {
|
||||
|
||||
void add_device_softmax_f16_f16_rank3_instances(std::vector<DeviceNormalizationPtr>&);
|
||||
void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationPtr>&);
|
||||
@@ -26,7 +26,7 @@ void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationP
|
||||
void add_device_softmax_f32_f32_rank3_instances(std::vector<DeviceNormalizationPtr>&);
|
||||
void add_device_softmax_f32_f32_rank4_instances(std::vector<DeviceNormalizationPtr>&);
|
||||
|
||||
-} // namespace device_normalization_instance
|
||||
+} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -109,23 +109,23 @@ void profile_normalization_impl(int do_verification,
|
||||
is_same<AccDataType, float>::value)
|
||||
{
|
||||
if(in_length.size() == 3)
|
||||
-tensor_operation::device::device_normalization_instance::
|
||||
-add_device_softmax_f16_f16_rank3_instances(instances);
|
||||
+tensor_operation::device::instance::add_device_softmax_f16_f16_rank3_instances(
|
||||
+instances);
|
||||
|
||||
if(in_length.size() == 4)
|
||||
-tensor_operation::device::device_normalization_instance::
|
||||
-add_device_softmax_f16_f16_rank4_instances(instances);
|
||||
+tensor_operation::device::instance::add_device_softmax_f16_f16_rank4_instances(
|
||||
+instances);
|
||||
}
|
||||
else if constexpr(is_same<InDataType, float>::value && is_same<OutDataType, float>::value &&
|
||||
is_same<AccDataType, float>::value)
|
||||
{
|
||||
if(in_length.size() == 3)
|
||||
-tensor_operation::device::device_normalization_instance::
|
||||
-add_device_softmax_f32_f32_rank3_instances(instances);
|
||||
+tensor_operation::device::instance::add_device_softmax_f32_f32_rank3_instances(
|
||||
+instances);
|
||||
|
||||
if(in_length.size() == 4)
|
||||
-tensor_operation::device::device_normalization_instance::
|
||||
-add_device_softmax_f32_f32_rank4_instances(instances);
|
||||
+tensor_operation::device::instance::add_device_softmax_f32_f32_rank4_instances(
|
||||
+instances);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user