Simplify includes for CK builder reflection (#3357)

We only want to import enums and types into the builder reflection code. However, some of the enums are defined in much larger files, or even in large trees of include files, which leads to unintended mixing of code, confusing interactions, and symbol conflicts. We organize the includes and extract two new enum-only headers to help with decoupling in CK. This refactoring is critical if we want to include reflection in a device-operator "describe" method.

* Remove a few unnecessary includes from headers in builder/reflect/.
* Extract the scheduler and pipeline enums into their own headers so they can be used without importing other code.
* Order includes alphabetically for better organization.

The immediate goal is to unblock reflection integration, and this type of cleanup helps the flexibility and robustness of the CK header library.
This commit is contained in:
John Shumway
2025-12-05 07:44:10 -08:00
committed by GitHub
parent 35fc7c9e4f
commit f5b0af2272
11 changed files with 197 additions and 150 deletions

View File

@@ -3,52 +3,12 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/scheduler_enum.hpp"
namespace ck {
/// @brief Version tag for block GEMM pipelines.
/// @details Enumerator values are written out explicitly; they equal the implicit
///          declaration-order numbering (v1 == 0 ... v5 == 4).
enum struct BlockGemmPipelineVersion
{
    // For GEMM
    v1 = 0, ///< Naive
    v2 = 1, ///< Mem
    v3 = 2, ///< Comp
    v4 = 3, ///< Comp, double lds buffer
    v5 = 4, ///< Comp, double global prefetch register buffer

    // For GEMM with preshuffled weight:
    //   v1, single lds buffer
    //   v2, double lds buffer
};
/// @brief Scheduler tag for block GEMM pipelines.
/// @details Values are spelled out and match declaration order.
enum struct BlockGemmPipelineScheduler
{
    Intrawave = 0,
    Interwave = 1,
};
/// @brief Tail tag used by the pipelines.
/// @details Values are spelled out and match declaration order. Note that the numeric
///          value is a position, not a count: e.g. TailNumber::One has value 2.
enum struct TailNumber
{
    // Single / Double buffer pipeline
    Odd  = 0,
    Even = 1,

    // Long prefetch pipeline, up to 8
    One   = 2,
    Two   = 3,
    Three = 4,
    Four  = 5,
    Five  = 6,
    Six   = 7,
    Seven = 8,

    // Unroll stages > prefetch stages; number of loops is a multiple of the unroll stages
    Empty = 9,

    // Unroll stages <= prefetch stages; number of loops is a multiple of the unroll stages
    // plus the prefetch stages
    Full = 10,
};
enum SchedulerGroup : uint32_t
{
SCHED_GROUP_MFMA = 0x008, // Matrix FMA instructions

View File

@@ -3,40 +3,20 @@
#pragma once
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
#include <ostream>
#endif
#include "ck/utility/common_header.hpp"
#include "ck/utility/scheduler_enum.hpp"
namespace ck {
/// @brief Loop scheduler tag.
/// @details Values are spelled out and match declaration order.
enum struct LoopScheduler
{
    Default   = 0,
    Interwave = 1,
};
/// @brief Helper function to get the default loop scheduler.
/// @details The choice is made entirely by the preprocessor:
///          LoopScheduler::Interwave when CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
///          evaluates to non-zero, LoopScheduler::Default otherwise (including when the
///          macro is not defined at all, since an undefined identifier is 0 in `#if`).
/// @return The compile-time selected LoopScheduler value.
constexpr LoopScheduler make_default_loop_scheduler()
{
    // Exactly one #endif closes the conditional below; a second one would leave the
    // preprocessor conditionals of this header unbalanced.
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
    return LoopScheduler::Interwave;
#else
    return LoopScheduler::Default;
#endif // CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
}
} // namespace ck
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
/// @brief Writes the unqualified enumerator name of a ck::LoopScheduler to a stream.
/// @param os Destination stream.
/// @param s  Value to print.
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
{
    // Values outside the known enumerators keep the empty label, so every path performs
    // exactly one string insertion.
    const char* label = "";
    switch(s)
    {
    case ck::LoopScheduler::Default: label = "Default"; break;
    case ck::LoopScheduler::Interwave: label = "Interwave"; break;
    default: break;
    }
    os << label;
    return os;
}
#endif

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
#include <ostream>
#endif
namespace ck {
/// @brief Pipeline version tag for GEMM kernels.
/// @details Identifies the pipeline strategy (data movement / computation overlap) of a
///          GEMM kernel. This header deliberately holds only the enum; it was split out of
///          gridwise_gemm_pipeline_selector.hpp to keep dependencies minimal.
///          Values are spelled out and match declaration order. Because there is no `v3`
///          enumerator here, `v4` has the numeric value 2.
enum struct PipelineVersion
{
    v1 = 0, ///< Version 1 pipeline
    v2 = 1, ///< Version 2 pipeline
    // v3 is only used in the Stream-K implementation.
    v4          = 2, ///< Version 4 pipeline
    weight_only = 3, ///< Weight-only specialized pipeline
};
} // namespace ck
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
/// @brief Writes a ck::PipelineVersion to a stream as `PipelineVersion::<name>`.
/// @param os Destination stream.
/// @param p  Value to print.
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
{
    // Values outside the known enumerators keep the empty label, so every path performs
    // exactly one string insertion.
    const char* label = "";
    switch(p)
    {
    case ck::PipelineVersion::v1: label = "PipelineVersion::v1"; break;
    case ck::PipelineVersion::v2: label = "PipelineVersion::v2"; break;
    case ck::PipelineVersion::v4: label = "PipelineVersion::v4"; break;
    case ck::PipelineVersion::weight_only: label = "PipelineVersion::weight_only"; break;
    default: break;
    }
    os << label;
    return os;
}
#endif

View File

@@ -0,0 +1,83 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
#include <ostream>
#endif
namespace ck {
/// @brief Version tag for block GEMM pipelines.
/// @details Names the available block GEMM pipeline strategies. Only enum definitions live
///          in this header; it was split out of blkgemmpipe_scheduler.hpp to keep
///          dependencies minimal. Values are spelled out and match declaration order.
enum struct BlockGemmPipelineVersion
{
    // For GEMM
    v1 = 0, ///< Naive pipeline
    v2 = 1, ///< Memory-optimized pipeline
    v3 = 2, ///< Compute-optimized pipeline
    v4 = 3, ///< Compute-optimized with double LDS buffer
    v5 = 4, ///< Compute-optimized with double global prefetch register buffer

    // For GEMM with preshuffled weight:
    //   v1, single lds buffer
    //   v2, double lds buffer
};
/// @brief Scheduler tag for block GEMM pipelines.
/// @details Names the scheduling strategies a block GEMM pipeline can use.
///          Values are spelled out and match declaration order.
enum struct BlockGemmPipelineScheduler
{
    Intrawave = 0, ///< Schedule within a single wavefront
    Interwave = 1, ///< Schedule across multiple wavefronts
};
/// @brief Loop scheduler tag.
/// @details Names the scheduling strategies for computational loops.
///          Values are spelled out and match declaration order.
enum struct LoopScheduler
{
    Default   = 0, ///< Default scheduling strategy
    Interwave = 1, ///< Cross-wavefront scheduling
};
/// @brief Tail tag for pipeline buffering.
/// @details Describes the tail iterations of a pipelined loop. Values are spelled out and
///          match declaration order. Note that the numeric value is a position, not a
///          count: e.g. TailNumber::One has value 2.
enum struct TailNumber
{
    // Single / Double buffer pipeline
    Odd  = 0, ///< Odd number of iterations
    Even = 1, ///< Even number of iterations

    // Long prefetch pipeline, up to 8
    One   = 2, ///< One tail iteration
    Two   = 3, ///< Two tail iterations
    Three = 4, ///< Three tail iterations
    Four  = 5, ///< Four tail iterations
    Five  = 6, ///< Five tail iterations
    Six   = 7, ///< Six tail iterations
    Seven = 8, ///< Seven tail iterations

    // Unroll stages > prefetch stages; number of loops is a multiple of the unroll stages
    Empty = 9, ///< No tail iterations

    // Unroll stages <= prefetch stages; number of loops is a multiple of the unroll stages
    // plus the prefetch stages
    Full = 10, ///< Full tail iterations
};
} // namespace ck
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
/// @brief Streams the unqualified enumerator name of a ck::LoopScheduler.
/// @param os Destination stream.
/// @param s  Value to print.
/// @return `os`, to allow chaining.
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
{
    // Start from the empty fallback text; known enumerators overwrite it. Either way a
    // single string insertion is performed.
    const char* text = "";
    switch(s)
    {
    case ck::LoopScheduler::Default: text = "Default"; break;
    case ck::LoopScheduler::Interwave: text = "Interwave"; break;
    default: break;
    }
    os << text;
    return os;
}
#endif