diff --git a/example/ck_tile/21_elementwise/elementwise_example.cpp b/example/ck_tile/21_elementwise/elementwise_example.cpp index 94d3e70be1..e9fbeafde1 100644 --- a/example/ck_tile/21_elementwise/elementwise_example.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example.cpp @@ -211,7 +211,9 @@ bool run(const ck_tile::ArgParser& arg_parser) int main(int argc, char* argv[]) { - auto [result, arg_parser] = create_args(argc, argv); + bool result = true; + ck_tile::ArgParser arg_parser; + std::tie(result, arg_parser) = create_args(argc, argv); if(!result) return -1; diff --git a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp index ff7ec1517e..1b101c2e5f 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp @@ -157,7 +157,9 @@ bool run(const ck_tile::ArgParser& arg_parser) int main(int argc, char* argv[]) { - auto [result, arg_parser] = create_args(argc, argv); + bool result = true; + ck_tile::ArgParser arg_parser; + std::tie(result, arg_parser) = create_args(argc, argv); if(!result) return -1; diff --git a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp index 16e9832c07..7cdb5cc0d1 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp @@ -156,7 +156,9 @@ bool run(const ck_tile::ArgParser& arg_parser) int main(int argc, char* argv[]) { - auto [result, arg_parser] = create_args(argc, argv); + bool result = true; + ck_tile::ArgParser arg_parser; + std::tie(result, arg_parser) = create_args(argc, argv); if(!result) return -1; diff --git a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp index c5a08d910e..4e19cfd688 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp @@ -193,7 +193,9 @@ auto string_to_op(const std::string& op) int main(int argc, char* argv[]) { - auto [result, arg_parser] = create_args(argc, argv); + bool result = true; + ck_tile::ArgParser arg_parser; + std::tie(result, arg_parser) = create_args(argc, argv); if(!result) return -1; diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index 2c00f4f42f..c5525d5ff8 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -34,8 +34,8 @@ namespace ck { struct f8_fnuz_t { - using data_type = unsigned char; - data_type m_data; + using data_type = unsigned char; + data_type m_data = data_type{}; __host__ __device__ explicit constexpr f8_fnuz_t(data_type in_data) : m_data(in_data) {} __host__ __device__ explicit constexpr f8_fnuz_t() = default; __host__ __device__ bool constexpr operator==(f8_fnuz_t other) const @@ -47,8 +47,8 @@ struct f8_fnuz_t struct bf8_fnuz_t { - using data_type = unsigned char; - data_type m_data; + using data_type = unsigned char; + data_type m_data = data_type{}; __host__ __device__ explicit constexpr bf8_fnuz_t(data_type in_data) : m_data(in_data) {} __host__ __device__ explicit constexpr bf8_fnuz_t() = default; __host__ __device__ bool constexpr operator==(bf8_fnuz_t other) const diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 585a5f5b42..e0a39a5aea 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -9,25 +9,9 @@ #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include +#include namespace ck_tile { - -template -concept HasDataType = requires { typename T::DataType; }; - -template -struct GetDataType -{ - using type = float; -}; - -template - requires HasDataType -struct GetDataType -{ - using type = typename T::DataType; // Use T::ScaleN::DataType -}; - template + template CK_TILE_DEVICE void scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window) { @@ -334,7 +318,7 @@ struct CShuffleEpilogue constexpr index_t num_access = SFC::get_num_of_access(); if constexpr(iAccess != num_access - 1) { - constexpr auto step = SFC::get_forward_step(iAccess); + constexpr auto step = SFC::get_forward_step(number{}); move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})}); move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})}); @@ -342,10 +326,10 @@ struct CShuffleEpilogue } } - template + template CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile) { - constexpr auto idx_y_start = SFC::get_index(iAccess); + constexpr auto idx_y_start = SFC::get_index(number{}); constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; @@ -400,13 +384,13 @@ struct CShuffleEpilogue /** * @brief Move both the output and D tensors windows for the next access. */ - template + template CK_TILE_DEVICE void move_windows(OutDramWindow& out_dram_window, DDramWindows& d_dram_windows) { constexpr index_t num_access = SFC::get_num_of_access(); if constexpr(iAccess != num_access - 1) { - constexpr auto step = SFC::get_forward_step(iAccess); + constexpr auto step = SFC::get_forward_step(number{}); // move the output dram window move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); @@ -423,6 +407,18 @@ struct CShuffleEpilogue { }; + template + struct ScaleDataType + { + using DataType = float; + }; + + template + struct ScaleDataType> + { + using DataType = typename T::DataType; + }; + template && std::is_same_v; // Tiles to hold row/col scales when present - using SMType = typename GetDataType>::type; - using SNType = typename GetDataType>::type; + using SMType = typename ScaleDataType::DataType; + using SNType = typename ScaleDataType::DataType; auto sm_tile = make_static_distributed_tensor(dram_tile_distribution); auto sn_tile = make_static_distributed_tensor(dram_tile_distribution); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index bcd0fd9dac..0c9c816672 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -18,73 +18,64 @@ namespace ck_tile { namespace detail { // Helper templates for safe type extraction -template +template struct get_aq_layout_or { using type = Default; }; template - requires requires { typename T::AQLayout; } -struct get_aq_layout_or +struct get_aq_layout_or> { using type = typename T::AQLayout; }; -template +template struct get_bq_layout_or { using type = Default; }; template - requires requires { typename T::BQLayout; } -struct get_bq_layout_or +struct get_bq_layout_or> { using type = typename T::BQLayout; }; -template +template struct get_aq_data_type_or { using type = Default; }; template - requires requires { typename T::AQDataType; } -struct get_aq_data_type_or +struct get_aq_data_type_or> { using type = typename T::AQDataType; }; -template +template struct get_bq_data_type_or { using type = Default; }; template - requires requires { typename T::BQDataType; } -struct get_bq_data_type_or +struct get_bq_data_type_or> { using type = typename T::BQDataType; }; -template -concept HasStaticPreshuffleQuant = requires { - { T::PreshuffleQuant } -> std::convertible_to; -}; - -template +template struct is_quantpreshuffle_enabled { static constexpr bool value = false; }; -template -struct is_quantpreshuffle_enabled +template +struct is_quantpreshuffle_enabled { - static constexpr auto value = T::PreshuffleQuant; + static constexpr bool value = T::PreshuffleQuant; }; } // namespace detail