mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add GetGenericInstance() interface for DeviceOperationInstanceFactory class of DeviceSoftmax
This commit is contained in:
@@ -29,6 +29,20 @@ void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_ins
|
||||
});
|
||||
}
|
||||
|
||||
template <typename BaseOp, typename NewOpInstances>
|
||||
void get_first_device_operation_instance(std::unique_ptr<BaseOp>& op_instance,
|
||||
const NewOpInstances& new_op_instances)
|
||||
{
|
||||
const auto first_op_instance = std::get<0>(new_op_instances);
|
||||
|
||||
using FirstOpInstance = remove_cvref_t<decltype(first_op_instance)>;
|
||||
|
||||
static_assert(std::is_base_of_v<BaseOp, FirstOpInstance>,
|
||||
"wrong! FirstOpInstance should be derived from BaseOp");
|
||||
|
||||
op_instance = std::make_unique<FirstOpInstance>(first_op_instance);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -37,6 +37,86 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceSoftma
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
static auto GetGenericInstance()
|
||||
{
|
||||
std::unique_ptr<DeviceOp> op_ptr;
|
||||
|
||||
if constexpr(std::is_same_v<InDataType, F16> && std::is_same_v<AccDataType, F32> &&
|
||||
std::is_same_v<OutDataType, F16>)
|
||||
{
|
||||
if constexpr(Rank == 3)
|
||||
{
|
||||
if constexpr(NumReduceDim == 1)
|
||||
get_device_softmax_f16_f16_rank3_reduce1_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 2)
|
||||
get_device_softmax_f16_f16_rank3_reduce2_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 3)
|
||||
get_device_softmax_f16_f16_rank3_reduce3_generic_instance(op_ptr);
|
||||
}
|
||||
else if constexpr(Rank == 4)
|
||||
{
|
||||
if constexpr(NumReduceDim == 1)
|
||||
get_device_softmax_f16_f16_rank4_reduce1_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 2)
|
||||
get_device_softmax_f16_f16_rank4_reduce2_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 3)
|
||||
get_device_softmax_f16_f16_rank4_reduce3_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 4)
|
||||
get_device_softmax_f16_f16_rank4_reduce4_generic_instance(op_ptr);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<InDataType, F32> && std::is_same_v<AccDataType, F32> &&
|
||||
std::is_same_v<OutDataType, F32>)
|
||||
{
|
||||
if constexpr(Rank == 3)
|
||||
{
|
||||
if constexpr(NumReduceDim == 1)
|
||||
get_device_softmax_f32_f32_rank3_reduce1_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 2)
|
||||
get_device_softmax_f32_f32_rank3_reduce2_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 3)
|
||||
get_device_softmax_f32_f32_rank3_reduce3_generic_instance(op_ptr);
|
||||
}
|
||||
else if constexpr(Rank == 4)
|
||||
{
|
||||
if constexpr(NumReduceDim == 1)
|
||||
get_device_softmax_f32_f32_rank4_reduce1_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 2)
|
||||
get_device_softmax_f32_f32_rank4_reduce2_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 3)
|
||||
get_device_softmax_f32_f32_rank4_reduce3_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 4)
|
||||
get_device_softmax_f32_f32_rank4_reduce4_generic_instance(op_ptr);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<InDataType, I8> && std::is_same_v<AccDataType, F32> &&
|
||||
std::is_same_v<OutDataType, I8>)
|
||||
{
|
||||
if constexpr(Rank == 3)
|
||||
{
|
||||
if constexpr(NumReduceDim == 1)
|
||||
get_device_softmax_i8_i8_rank3_reduce1_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 2)
|
||||
get_device_softmax_i8_i8_rank3_reduce2_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 3)
|
||||
get_device_softmax_i8_i8_rank3_reduce3_generic_instance(op_ptr);
|
||||
}
|
||||
else if constexpr(Rank == 4)
|
||||
{
|
||||
if constexpr(NumReduceDim == 1)
|
||||
get_device_softmax_i8_i8_rank4_reduce1_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 2)
|
||||
get_device_softmax_i8_i8_rank4_reduce2_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 3)
|
||||
get_device_softmax_i8_i8_rank4_reduce3_generic_instance(op_ptr);
|
||||
else if constexpr(NumReduceDim == 4)
|
||||
get_device_softmax_i8_i8_rank4_reduce4_generic_instance(op_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptr;
|
||||
};
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_f16_f16_rank3_reduce1_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 1>>& instances);
|
||||
|
||||
void get_device_softmax_f16_f16_rank3_reduce1_generic_instance(
|
||||
DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 1>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_f16_f16_rank3_reduce2_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 2>>& instances);
|
||||
|
||||
void get_device_softmax_f16_f16_rank3_reduce2_generic_instance(
|
||||
DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 2>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_f16_f16_rank3_reduce3_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 3>>& instances);
|
||||
|
||||
void get_device_softmax_f16_f16_rank3_reduce3_instance(
|
||||
DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 3>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_f16_f16_rank4_reduce1_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 1>>& instances);
|
||||
|
||||
void get_device_softmax_f16_f16_rank4_reduce1_generic_instance(
|
||||
DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 1>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_f16_f16_rank4_reduce2_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 2>>& instances);
|
||||
|
||||
void get_device_softmax_f16_f16_rank4_reduce2_generic_instance(
|
||||
DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 2>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_f16_f16_rank4_reduce3_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 3>>& instances);
|
||||
|
||||
void get_device_softmax_f16_f16_rank4_reduce3_generic_instance(
|
||||
DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 3>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_f16_f16_rank4_reduce4_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 4>>& instances);
|
||||
|
||||
void get_device_softmax_f16_f16_rank4_reduce4_generic_instance(
|
||||
DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 4>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_f32_f32_rank3_reduce1_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3, 1>>& instances);
|
||||
|
||||
void get_device_softmax_f32_f32_rank3_reduce1_generic_instance(
|
||||
DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3, 1>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_f32_f32_rank3_reduce2_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3, 2>>& instances);
|
||||
|
||||
void get_device_softmax_f32_f32_rank3_reduce2_generic_instance(
|
||||
DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3, 2>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_f32_f32_rank3_reduce3_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3, 3>>& instances);
|
||||
|
||||
void get_device_softmax_f32_f32_rank3_reduce3_generic_instance(
|
||||
DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3, 3>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_f32_f32_rank4_reduce1_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4, 1>>& instances);
|
||||
|
||||
void get_device_softmax_f32_f32_rank4_reduce1_generic_instance(
|
||||
DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4, 1>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_f32_f32_rank4_reduce2_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4, 2>>& instances);
|
||||
|
||||
void get_device_softmax_f32_f32_rank4_reduce2_generic_instance(
|
||||
DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4, 2>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_f32_f32_rank4_reduce3_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4, 3>>& instances);
|
||||
|
||||
void get_device_softmax_f32_f32_rank4_reduce3_generic_instance(
|
||||
DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4, 3>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_f32_f32_rank4_reduce4_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4, 4>>& instances);
|
||||
|
||||
void get_device_softmax_f32_f32_rank4_reduce4_generic_instance(
|
||||
DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4, 4>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_i8_i8_rank3_reduce1_instances(
|
||||
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3, 1>>& instances);
|
||||
|
||||
void get_device_softmax_i8_i8_rank3_reduce1_generic_instance(
|
||||
DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3, 1>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_i8_i8_rank3_reduce2_instances(
|
||||
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3, 2>>& instances);
|
||||
|
||||
void get_device_softmax_i8_i8_rank3_reduce2_generic_instance(
|
||||
DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3, 2>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_i8_i8_rank3_reduce3_instances(
|
||||
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3, 3>>& instances);
|
||||
|
||||
void get_device_softmax_i8_i8_rank3_reduce3_generic_instance(
|
||||
DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3, 3>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_i8_i8_rank4_reduce1_instances(
|
||||
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4, 1>>& instances);
|
||||
|
||||
void get_device_softmax_i8_i8_rank4_reduce1_generic_instance(
|
||||
DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4, 1>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_i8_i8_rank4_reduce2_instances(
|
||||
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4, 2>>& instances);
|
||||
|
||||
void get_device_softmax_i8_i8_rank4_reduce2_generic_instance(
|
||||
DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4, 2>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_i8_i8_rank4_reduce3_instances(
|
||||
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4, 3>>& instances);
|
||||
|
||||
void get_device_softmax_i8_i8_rank4_reduce3_generic_instance(
|
||||
DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4, 3>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -16,6 +16,9 @@ namespace instance {
|
||||
void add_device_softmax_i8_i8_rank4_reduce4_instances(
|
||||
std::vector<DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4, 4>>& instances);
|
||||
|
||||
void get_device_softmax_i8_i8_rank4_reduce4_generic_instance(
|
||||
DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4, 4>& instance);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_f16_f16_rank3_reduce1_instances(
|
||||
add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 1>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_f16_f16_rank3_reduce1_generic_instance(
|
||||
DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 1>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_f16_f16_instances<3, 1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_f16_f16_rank3_reduce2_instances(
|
||||
add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 2>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_f16_f16_rank3_reduce2_generic_instance(
|
||||
DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 2>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_f16_f16_instances<3, 2>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_f16_f16_rank3_reduce3_instances(
|
||||
add_device_operation_instances(instances, device_softmax_f16_f16_instances<3, 3>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_f16_f16_rank3_reduce3_generic_instance(
|
||||
DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3, 3>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_f16_f16_instances<3, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_f16_f16_rank4_reduce1_instances(
|
||||
add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 1>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_f16_f16_rank4_reduce1_generic_instance(
|
||||
DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 1>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_f16_f16_instances<4, 1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_f16_f16_rank4_reduce2_instances(
|
||||
add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 2>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_f16_f16_rank4_reduce2_generic_instance(
|
||||
DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 2>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_f16_f16_instances<4, 2>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_f16_f16_rank4_reduce3_instances(
|
||||
add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 3>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_f16_f16_rank4_reduce3_generic_instance(
|
||||
DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 3>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_f16_f16_instances<4, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_f16_f16_rank4_reduce4_instances(
|
||||
add_device_operation_instances(instances, device_softmax_f16_f16_instances<4, 4>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_f16_f16_rank4_reduce4_generic_instance(
|
||||
DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4, 4>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_f16_f16_instances<4, 4>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_f32_f32_rank3_reduce1_instances(
|
||||
add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 1>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_f32_f32_rank3_reduce1_generic_instance(
|
||||
DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3, 1>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_f32_f32_instances<3, 1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_f32_f32_rank3_reduce2_instances(
|
||||
add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 2>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_f32_f32_rank3_reduce2_generic_instance(
|
||||
DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3, 2>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_f32_f32_instances<3, 2>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_f32_f32_rank3_reduce3_instances(
|
||||
add_device_operation_instances(instances, device_softmax_f32_f32_instances<3, 3>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_f32_f32_rank3_reduce3_generic_instance(
|
||||
DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3, 3>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_f32_f32_instances<3, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_f32_f32_rank4_reduce1_instances(
|
||||
add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 1>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_f32_f32_rank4_reduce1_generic_instance(
|
||||
DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4, 1>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_f32_f32_instances<4, 1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_f32_f32_rank4_reduce2_instances(
|
||||
add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 2>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_f32_f32_rank4_reduce2_generic_instance(
|
||||
DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4, 2>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_f32_f32_instances<4, 2>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_f32_f32_rank4_reduce3_instances(
|
||||
add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 3>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_f32_f32_rank4_reduce3_generic_instance(
|
||||
DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4, 3>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_f32_f32_instances<4, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_f32_f32_rank4_reduce4_instances(
|
||||
add_device_operation_instances(instances, device_softmax_f32_f32_instances<4, 4>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_f32_f32_rank4_reduce4_generic_instance(
|
||||
DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4, 4>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_f32_f32_instances<4, 4>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_i8_i8_rank3_reduce1_instances(
|
||||
add_device_operation_instances(instances, device_softmax_i8_i8_instances<3, 1>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_i8_i8_rank3_reduce1_generic_instance(
|
||||
DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3, 1>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_i8_i8_instances<3, 1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_i8_i8_rank3_reduce2_instances(
|
||||
add_device_operation_instances(instances, device_softmax_i8_i8_instances<3, 2>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_i8_i8_rank3_reduce2_generic_instance(
|
||||
DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3, 2>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_i8_i8_instances<3, 2>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_i8_i8_rank3_reduce3_instances(
|
||||
add_device_operation_instances(instances, device_softmax_i8_i8_instances<3, 3>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_i8_i8_rank3_reduce3_generic_instance(
|
||||
DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 3, 3>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_i8_i8_instances<3, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_i8_i8_rank4_reduce1_instances(
|
||||
add_device_operation_instances(instances, device_softmax_i8_i8_instances<4, 1>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_i8_i8_rank4_reduce1_generic_instance(
|
||||
DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4, 1>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_i8_i8_instances<4, 1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_i8_i8_rank4_reduce2_instances(
|
||||
add_device_operation_instances(instances, device_softmax_i8_i8_instances<4, 2>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_i8_i8_rank4_reduce2_generic_instance(
|
||||
DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4, 2>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_i8_i8_instances<4, 2>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_i8_i8_rank4_reduce3_instances(
|
||||
add_device_operation_instances(instances, device_softmax_i8_i8_instances<4, 3>{});
|
||||
}
|
||||
|
||||
void get_device_softmax_i8_i8_rank4_reduce3_generic_instance(
|
||||
DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4, 3>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_i8_i8_instances<4, 3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -19,6 +19,12 @@ void add_device_softmax_i8_i8_rank4_reduce4_instances(
|
||||
add_device_operation_instances(instances, device_softmax_i8_i8_instances<4, 4>{});
|
||||
}
|
||||
|
||||
void add_device_softmax_i8_i8_rank4_reduce4_instances(
|
||||
DeviceSoftmaxPtr<I8, F32, I8, PassThrough, PassThrough, 4, 4>& instance)
|
||||
{
|
||||
get_first_device_operation_instance(instance, device_softmax_i8_i8_instances<4, 4>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
Reference in New Issue
Block a user