mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
* Include <variant> in elementwise_common.hpp
* Disallow bf16_t for both UnarySquare and UnaryConvert in elementwise_example_unary.cpp
28 lines
653 B
C++
28 lines
653 B
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#include <stdexcept>
#include <string>
#include <variant>

#include "ck_tile/core/arch/arch.hpp"
|
|
|
|
auto string_to_datatype(const std::string& datatype)
|
|
{
|
|
using PrecVariant = std::variant<ck_tile::half_t, ck_tile::bf16_t, float>;
|
|
|
|
if(datatype == "fp16")
|
|
{
|
|
return PrecVariant{ck_tile::half_t{}};
|
|
}
|
|
else if(datatype == "bf16")
|
|
{
|
|
return PrecVariant{ck_tile::bf16_t{}};
|
|
}
|
|
else if(datatype == "fp32")
|
|
{
|
|
return PrecVariant{float{}};
|
|
}
|
|
else
|
|
{
|
|
throw std::runtime_error("Unsupported data type: " + datatype);
|
|
}
|
|
};
|