mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 05:19:20 +00:00
## Proposed changes TF32 is added in CK on gfx942 and gfx950. This PR is to initiate tf32 in CK_TILE on gfx942 and gfx950. ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [x] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion --- 🔁 Imported from [ROCm/composable_kernel#3538](https://github.com/ROCm/composable_kernel/pull/3538) 🧑💻 Originally authored by @yingluAMD --------- Co-authored-by: yingluAMD <Yingmao.Lu@amd.com> Co-authored-by: assistant-librarian[bot] <assistant-librarian[bot]@users.noreply.github.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
107 lines
2.2 KiB
C++
107 lines
2.2 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include <string>
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/host.hpp"
|
|
#include "ck_tile/core/numeric/integer.hpp"
|
|
#include "ck_tile/core/numeric/pk_int4.hpp"
|
|
|
|
//[TODO] This can be moved to commons
|
|
// Compile-time traits keyed on an element data type.
// Only the declaration lives here: the primary template is never defined, and
// each supported type supplies a full specialization exposing a `name` string.
// Asking for DataTypeTraits<T>::name on a type without a specialization is
// therefore a compile error.
template <class T>
struct DataTypeTraits;
|
|
|
|
template <>
|
|
struct DataTypeTraits<float>
|
|
{
|
|
static constexpr const char* name = "fp32";
|
|
};
|
|
|
|
template <>
|
|
struct DataTypeTraits<ck_tile::tf32_t>
|
|
{
|
|
static constexpr const char* name = "tf32";
|
|
};
|
|
|
|
template <>
|
|
struct DataTypeTraits<double>
|
|
{
|
|
static constexpr const char* name = "fp64";
|
|
};
|
|
|
|
template <>
|
|
struct DataTypeTraits<ck_tile::half_t>
|
|
{
|
|
static constexpr const char* name = "fp16";
|
|
};
|
|
|
|
template <>
|
|
struct DataTypeTraits<ck_tile::bf16_t>
|
|
{
|
|
static constexpr const char* name = "bf16";
|
|
};
|
|
|
|
template <>
|
|
struct DataTypeTraits<ck_tile::fp8_t>
|
|
{
|
|
static constexpr const char* name = "fp8";
|
|
};
|
|
|
|
template <>
|
|
struct DataTypeTraits<ck_tile::bf8_t>
|
|
{
|
|
static constexpr const char* name = "bf8";
|
|
};
|
|
|
|
template <>
|
|
struct DataTypeTraits<ck_tile::int8_t>
|
|
{
|
|
static constexpr const char* name = "int8";
|
|
};
|
|
|
|
template <>
|
|
struct DataTypeTraits<ck_tile::int32_t>
|
|
{
|
|
static constexpr const char* name = "int32";
|
|
};
|
|
|
|
template <>
|
|
struct DataTypeTraits<ck_tile::pk_int4_t>
|
|
{
|
|
static constexpr const char* name = "pk_int4_t";
|
|
};
|
|
|
|
// Helper function to determine if a layout is row-major
|
|
template <typename Layout>
|
|
constexpr auto is_row_major(Layout)
|
|
{
|
|
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
|
}
|
|
|
|
// Bundle of kernel configuration options for the dispatcher.
// A default-constructed instance selects the "compv3" pipeline, the
// "intrawave" scheduler and the "cshuffle" epilogue, with all padding flags
// and the persistent flag turned off.
struct KernelTraits
{
    std::string pipeline  = "compv3";    // compv3, compv4, mem
    std::string scheduler = "intrawave"; // intrawave, interwave
    std::string epilogue  = "cshuffle";  // cshuffle, default
    bool pad_m            = false;
    bool pad_n            = false;
    bool pad_k            = false;
    bool persistent       = false;

    // All defaults come from the member initializers above. The constructor
    // is deliberately user-provided (not `= default`) so the type stays a
    // non-aggregate.
    KernelTraits() {}
};
|