mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Softmax client example (#396)
* Update Softmax device operation interface.
* Update ckProfiler.
* Update Softmax UT.
* Update example.
* Client example.
* Clang format
Co-authored-by: Adam Osewski <aosewski@amd.com>
[ROCm/composable_kernel commit: 3da5c19e62]
This commit is contained in:
@@ -6,25 +6,36 @@
|
||||
#include <iomanip>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_softmax.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_softmax_f16_f16_rank3_instances(std::vector<DeviceNormalizationPtr>&);
|
||||
void add_device_softmax_f16_f16_rank4_instances(std::vector<DeviceNormalizationPtr>&);
|
||||
namespace {
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
} // namespace
|
||||
|
||||
void add_device_softmax_f32_f32_rank3_instances(std::vector<DeviceNormalizationPtr>&);
|
||||
void add_device_softmax_f32_f32_rank4_instances(std::vector<DeviceNormalizationPtr>&);
|
||||
void add_device_softmax_f16_f16_rank3_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 3>>&);
|
||||
void add_device_softmax_f16_f16_rank4_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F16, F32, F16, PassThrough, PassThrough, 4>>&);
|
||||
|
||||
void add_device_softmax_f32_f32_rank3_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 3>>&);
|
||||
void add_device_softmax_f32_f32_rank4_instances(
|
||||
std::vector<DeviceSoftmaxPtr<F32, F32, F32, PassThrough, PassThrough, 4>>&);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
@@ -57,7 +68,7 @@ template <> std::string type_to_string<int8_t>() { return "int8"; }
|
||||
template <> std::string type_to_string<int32_t>() { return "int32"; }
|
||||
// clang-format on
|
||||
|
||||
template <typename InDataType, typename AccDataType, typename OutDataType>
|
||||
template <typename InDataType, typename AccDataType, typename OutDataType, index_t Rank>
|
||||
void profile_normalization_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
@@ -69,6 +80,11 @@ void profile_normalization_impl(int do_verification,
|
||||
AccDataType beta,
|
||||
NormType norm_type)
|
||||
{
|
||||
if(Rank != in_length.size())
|
||||
{
|
||||
throw std::runtime_error("Input tensor rank is different from template argument Rank!");
|
||||
}
|
||||
|
||||
Tensor<InDataType> in = in_strides.empty() ? Tensor<InDataType>(in_length)
|
||||
: Tensor<InDataType>(in_length, in_strides);
|
||||
Tensor<OutDataType> out(in.mDesc);
|
||||
@@ -99,30 +115,31 @@ void profile_normalization_impl(int do_verification,
|
||||
std::vector<index_t> i_in_lengths(in.mDesc.GetLengths().begin(), in.mDesc.GetLengths().end());
|
||||
std::vector<index_t> i_in_strides(in.mDesc.GetStrides().begin(), in.mDesc.GetStrides().end());
|
||||
|
||||
// add device normalization instances
|
||||
std::vector<tensor_operation::device::DeviceNormalizationPtr> instances;
|
||||
// add device softmax instances
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using DeviceOpPtr = tensor_operation::device::
|
||||
DeviceSoftmaxPtr<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>;
|
||||
std::vector<DeviceOpPtr> instances;
|
||||
|
||||
if(norm_type == NormType::SOFTMAX)
|
||||
{
|
||||
if constexpr(is_same<InDataType, half_t>::value && is_same<OutDataType, half_t>::value &&
|
||||
is_same<AccDataType, float>::value)
|
||||
{
|
||||
if(in_length.size() == 3)
|
||||
if constexpr(Rank == 3)
|
||||
tensor_operation::device::instance::add_device_softmax_f16_f16_rank3_instances(
|
||||
instances);
|
||||
|
||||
if(in_length.size() == 4)
|
||||
else if constexpr(Rank == 4)
|
||||
tensor_operation::device::instance::add_device_softmax_f16_f16_rank4_instances(
|
||||
instances);
|
||||
}
|
||||
else if constexpr(is_same<InDataType, float>::value && is_same<OutDataType, float>::value &&
|
||||
is_same<AccDataType, float>::value)
|
||||
{
|
||||
if(in_length.size() == 3)
|
||||
if constexpr(Rank == 3)
|
||||
tensor_operation::device::instance::add_device_softmax_f32_f32_rank3_instances(
|
||||
instances);
|
||||
|
||||
if(in_length.size() == 4)
|
||||
else if constexpr(Rank == 4)
|
||||
tensor_operation::device::instance::add_device_softmax_f32_f32_rank4_instances(
|
||||
instances);
|
||||
}
|
||||
@@ -137,6 +154,8 @@ void profile_normalization_impl(int do_verification,
|
||||
float best_avg_time = std::numeric_limits<float>::max();
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
for(auto& inst_ptr : instances)
|
||||
{
|
||||
// Is it the user's responsibility to check if problem mismatches kernel instance (i.e. rank 3
|
||||
@@ -153,7 +172,9 @@ void profile_normalization_impl(int do_verification,
|
||||
&alpha,
|
||||
&beta,
|
||||
in_dev.GetDeviceBuffer(),
|
||||
out_dev.GetDeviceBuffer());
|
||||
out_dev.GetDeviceBuffer(),
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
if(!inst_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user