mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Congma/ck tile/remove cpp 20 code (#2873)
* Remove C++20 code

  C++20 features should not be used in CK. Remove all C++20 code.

* fix c++17 build
* format
* fix merge issue

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
This commit is contained in:
@@ -211,7 +211,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
bool result = true;
|
||||
ck_tile::ArgParser arg_parser;
|
||||
std::tie(result, arg_parser) = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
|
||||
@@ -157,7 +157,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
bool result = true;
|
||||
ck_tile::ArgParser arg_parser;
|
||||
std::tie(result, arg_parser) = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
|
||||
@@ -156,7 +156,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
bool result = true;
|
||||
ck_tile::ArgParser arg_parser;
|
||||
std::tie(result, arg_parser) = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
|
||||
@@ -193,7 +193,9 @@ auto string_to_op(const std::string& op)
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
bool result = true;
|
||||
ck_tile::ArgParser arg_parser;
|
||||
std::tie(result, arg_parser) = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
|
||||
@@ -34,8 +34,8 @@ namespace ck {
|
||||
|
||||
struct f8_fnuz_t
|
||||
{
|
||||
using data_type = unsigned char;
|
||||
data_type m_data;
|
||||
using data_type = unsigned char;
|
||||
data_type m_data = data_type{};
|
||||
__host__ __device__ explicit constexpr f8_fnuz_t(data_type in_data) : m_data(in_data) {}
|
||||
__host__ __device__ explicit constexpr f8_fnuz_t() = default;
|
||||
__host__ __device__ bool constexpr operator==(f8_fnuz_t other) const
|
||||
@@ -47,8 +47,8 @@ struct f8_fnuz_t
|
||||
|
||||
struct bf8_fnuz_t
|
||||
{
|
||||
using data_type = unsigned char;
|
||||
data_type m_data;
|
||||
using data_type = unsigned char;
|
||||
data_type m_data = data_type{};
|
||||
__host__ __device__ explicit constexpr bf8_fnuz_t(data_type in_data) : m_data(in_data) {}
|
||||
__host__ __device__ explicit constexpr bf8_fnuz_t() = default;
|
||||
__host__ __device__ bool constexpr operator==(bf8_fnuz_t other) const
|
||||
|
||||
@@ -9,25 +9,9 @@
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T>
|
||||
concept HasDataType = requires { typename T::DataType; };
|
||||
|
||||
template <typename T>
|
||||
struct GetDataType
|
||||
{
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
requires HasDataType<T>
|
||||
struct GetDataType<T>
|
||||
{
|
||||
using type = typename T::DataType; // Use T::ScaleN::DataType
|
||||
};
|
||||
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename DsDataType_,
|
||||
@@ -300,7 +284,7 @@ struct CShuffleEpilogue
|
||||
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
|
||||
}
|
||||
|
||||
template <auto iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
|
||||
template <index_t iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
|
||||
CK_TILE_DEVICE void
|
||||
scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
|
||||
{
|
||||
@@ -334,7 +318,7 @@ struct CShuffleEpilogue
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
constexpr auto step = SFC::get_forward_step(number<iAccess>{});
|
||||
|
||||
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
@@ -342,10 +326,10 @@ struct CShuffleEpilogue
|
||||
}
|
||||
}
|
||||
|
||||
template <auto iAccess, typename OAccTile, typename LdsTile>
|
||||
template <index_t iAccess, typename OAccTile, typename LdsTile>
|
||||
CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile)
|
||||
{
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
constexpr auto idx_y_start = SFC::get_index(number<iAccess>{});
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
@@ -400,13 +384,13 @@ struct CShuffleEpilogue
|
||||
/**
|
||||
* @brief Move both the output and D tensors windows for the next access.
|
||||
*/
|
||||
template <auto iAccess, typename OutDramWindow, typename DDramWindows>
|
||||
template <index_t iAccess, typename OutDramWindow, typename DDramWindows>
|
||||
CK_TILE_DEVICE void move_windows(OutDramWindow& out_dram_window, DDramWindows& d_dram_windows)
|
||||
{
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
constexpr auto step = SFC::get_forward_step(number<iAccess>{});
|
||||
|
||||
// move the output dram window
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
@@ -423,6 +407,18 @@ struct CShuffleEpilogue
|
||||
{
|
||||
};
|
||||
|
||||
// ScaleDataType<T>::DataType names the element type associated with T.
// Primary template: used for any T that does not expose a nested `DataType`
// type; in that case DataType is float.
// NOTE(review): T is inspected as given. For a reference-qualified T the
// expression `typename T::DataType` is ill-formed, so this fallback is picked
// even if the referred-to type has a DataType -- confirm callers pass
// unqualified types.
template <typename, typename = void>
|
||||
struct ScaleDataType
|
||||
{
|
||||
using DataType = float;
|
||||
};
|
||||
|
||||
// Specialization chosen when `typename T::DataType` is well-formed: the
// std::void_t argument then collapses to void and matches the primary
// template's default second argument. DataType forwards T::DataType.
template <typename T>
|
||||
struct ScaleDataType<T, std::void_t<typename T::DataType>>
|
||||
{
|
||||
using DataType = typename T::DataType;
|
||||
};
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
typename DsDramWindows,
|
||||
@@ -475,8 +471,8 @@ struct CShuffleEpilogue
|
||||
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
|
||||
|
||||
// Tiles to hold row/col scales when present
|
||||
using SMType = typename GetDataType<remove_cvref_t<ScaleM>>::type;
|
||||
using SNType = typename GetDataType<remove_cvref_t<ScaleN>>::type;
|
||||
using SMType = typename ScaleDataType<ScaleM>::DataType;
|
||||
using SNType = typename ScaleDataType<ScaleN>::DataType;
|
||||
|
||||
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
|
||||
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
|
||||
|
||||
@@ -18,73 +18,64 @@ namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
// Helper templates for safe type extraction
|
||||
// get_aq_layout_or<T, Default>::type
// Fallback case: `type` is Default. Selected when T has no nested type named
// AQLayout.
template <typename T, typename Default>
|
||||
template <typename, typename Default, typename = void>
|
||||
struct get_aq_layout_or
|
||||
{
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
// Detected case: selected only when `typename T::AQLayout` is well-formed;
// `type` is then T::AQLayout and Default is ignored.
template <typename T, typename Default>
|
||||
requires requires { typename T::AQLayout; }
|
||||
struct get_aq_layout_or<T, Default>
|
||||
struct get_aq_layout_or<T, Default, std::void_t<typename T::AQLayout>>
|
||||
{
|
||||
using type = typename T::AQLayout;
|
||||
};
|
||||
|
||||
// get_bq_layout_or<T, Default>::type
// Fallback case: `type` is Default. Selected when T has no nested type named
// BQLayout.
template <typename T, typename Default>
|
||||
template <typename, typename Default, typename = void>
|
||||
struct get_bq_layout_or
|
||||
{
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
// Detected case: selected only when `typename T::BQLayout` is well-formed;
// `type` is then T::BQLayout and Default is ignored.
template <typename T, typename Default>
|
||||
requires requires { typename T::BQLayout; }
|
||||
struct get_bq_layout_or<T, Default>
|
||||
struct get_bq_layout_or<T, Default, std::void_t<typename T::BQLayout>>
|
||||
{
|
||||
using type = typename T::BQLayout;
|
||||
};
|
||||
|
||||
// get_aq_data_type_or<T, Default>::type
// Fallback case: `type` is Default. Selected when T has no nested type named
// AQDataType.
template <typename T, typename Default>
|
||||
template <typename, typename Default, typename = void>
|
||||
struct get_aq_data_type_or
|
||||
{
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
// Detected case: selected only when `typename T::AQDataType` is well-formed;
// `type` is then T::AQDataType and Default is ignored.
template <typename T, typename Default>
|
||||
requires requires { typename T::AQDataType; }
|
||||
struct get_aq_data_type_or<T, Default>
|
||||
struct get_aq_data_type_or<T, Default, std::void_t<typename T::AQDataType>>
|
||||
{
|
||||
using type = typename T::AQDataType;
|
||||
};
|
||||
|
||||
// get_bq_data_type_or<T, Default>::type
// Fallback case: `type` is Default. Selected when T has no nested type named
// BQDataType.
template <typename T, typename Default>
|
||||
template <typename, typename Default, typename = void>
|
||||
struct get_bq_data_type_or
|
||||
{
|
||||
using type = Default;
|
||||
};
|
||||
|
||||
// Detected case: selected only when `typename T::BQDataType` is well-formed;
// `type` is then T::BQDataType and Default is ignored.
template <typename T, typename Default>
|
||||
requires requires { typename T::BQDataType; }
|
||||
struct get_bq_data_type_or<T, Default>
|
||||
struct get_bq_data_type_or<T, Default, std::void_t<typename T::BQDataType>>
|
||||
{
|
||||
using type = typename T::BQDataType;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept HasStaticPreshuffleQuant = requires {
|
||||
{ T::PreshuffleQuant } -> std::convertible_to<decltype(T::PreshuffleQuant)>;
|
||||
};
|
||||
|
||||
// is_quantpreshuffle_enabled<T>::value reports T's static `PreshuffleQuant`
// flag, or false when T does not declare one.
//
// Primary template: selected when `T::PreshuffleQuant` is not a valid
// expression. Callers name the trait with a single argument, so the second
// parameter is always the default `void`.
template <typename, typename = void>
struct is_quantpreshuffle_enabled
{
    static constexpr bool value = false;
};

// Specialization selected when `T::PreshuffleQuant` is a valid expression;
// `value` mirrors that member.
//
// The second argument must collapse to `void` (via std::void_t) so that it
// matches the primary template's default argument. Keying the specialization
// directly on `decltype(T::PreshuffleQuant)` (e.g. `const bool`) can never
// match `is_quantpreshuffle_enabled<T>` and would make the trait report false
// for every T.
template <typename T>
struct is_quantpreshuffle_enabled<T, std::void_t<decltype(T::PreshuffleQuant)>>
{
    static constexpr bool value = T::PreshuffleQuant;
};
|
||||
} // namespace detail
|
||||
|
||||
|
||||
Reference in New Issue
Block a user