// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" //[TODO] This can be moved to commons // DataTypeTraits for all supported types template struct DataTypeTraits; template <> struct DataTypeTraits { static constexpr const char* name = "fp32"; }; template <> struct DataTypeTraits { static constexpr const char* name = "fp64"; }; template <> struct DataTypeTraits { static constexpr const char* name = "fp16"; }; template <> struct DataTypeTraits { static constexpr const char* name = "bf16"; }; template <> struct DataTypeTraits { static constexpr const char* name = "fp8"; }; template <> struct DataTypeTraits { static constexpr const char* name = "bf8"; }; template <> struct DataTypeTraits { static constexpr const char* name = "int8"; }; template <> struct DataTypeTraits { static constexpr const char* name = "int32"; }; template <> struct DataTypeTraits { static constexpr const char* name = "pk_int4_t"; }; // Helper function to determine if a layout is row-major template constexpr auto is_row_major(Layout) { return ck_tile::bool_constant>{}; } // Structure to hold kernel traits for dispatcher struct KernelTraits { std::string pipeline; // compv3, compv4, mem std::string scheduler; // intrawave, interwave std::string epilogue; // cshuffle, default bool pad_m; bool pad_n; bool pad_k; bool persistent; // Constructor with defaults KernelTraits() : pipeline("compv3"), scheduler("intrawave"), epilogue("cshuffle"), pad_m(false), pad_n(false), pad_k(false), persistent(false) { } };