composable_kernel/include/ck/utility/index_expression.hpp
Márton Bidlek 0d18f4fc05 [rocm-libraries] ROCm/rocm-libraries#4798 (commit 0acaf5f)
Using named functors instead of lambdas

## Motivation

Currently, block-level GEMM pipelines contain significant code
repetition for prefetching and tail handling, where each lambda function
creates a unique instantiation at every call site. This includes repeated
static_for instantiations and large loops such as MRepeat. Each
repetition results in additional instantiations, which increases
compilation time and binary size.

## Technical Details

Refactor repeated code blocks into named functors so the compiler can
reuse already instantiated code instead of generating multiple copies.
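
As a minimal sketch of why this helps, the standalone C++17 example below (not taken from CK; `apply_n`, `AddOne`, `lambda_a`, and `lambda_b` are hypothetical names) shows that two textually identical lambdas have distinct closure types. A template that takes the callable as a type parameter is therefore instantiated once per lambda, while a named functor gives every call site the same type.

```cpp
#include <type_traits>

// Hypothetical stand-in for a helper that is templated on its callable,
// in the spirit of static_for: one instantiation per distinct type F.
template <typename F>
constexpr int apply_n(F f, int n)
{
    int acc = 0;
    for(int i = 0; i < n; ++i)
        acc += f(i);
    return acc;
}

// Named functor: every call site that uses it passes the same type,
// so apply_n<AddOne> is instantiated only once.
struct AddOne
{
    constexpr int operator()(int i) const { return i + 1; }
};

// Two textually identical lambdas still have two distinct closure types,
// so apply_n is instantiated separately for each of them.
inline constexpr auto lambda_a = [](int i) { return i + 1; };
inline constexpr auto lambda_b = [](int i) { return i + 1; };

static_assert(!std::is_same_v<decltype(lambda_a), decltype(lambda_b)>,
              "each lambda expression has its own closure type");

// All three callables compute the same result: 1 + 2 + 3 = 6.
static_assert(apply_n(lambda_a, 3) == 6, "lambda_a");
static_assert(apply_n(lambda_b, 3) == 6, "lambda_b");
static_assert(apply_n(AddOne{}, 3) == 6, "named functor");
```

The behavior is identical in all three cases; only the number of distinct types, and with it the number of template instantiations, differs.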

Scope of changes:

1. WMMAOPS pipeline internals:
projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp,
projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp,
projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp
2. XDLOPS and preshuffle pipeline variants across
projects/composablekernel/include/ck/tensor_operation/gpu/block
(v1/v2/v3/v4/v5, scale, dequant, gufusion, moe, mx, blockscale,
skip-b-lds, dpp, xdlops)

Shared functor file:
projects/composablekernel/include/ck/utility/vector_load_functor.hpp

## Test Plan

Note that the compilation traces produced by -ftime-trace do not report
unnamed lambda instantiations, so a clear baseline for instantiation
counts cannot be established. As a result, the impact of this change
is evaluated based on measured compile time rather than direct
instantiation-count comparisons.

## Test Result

The effect was measured by timing the compilation of a single HIP object
from an example (grouped_gemm_wmma_splitk_fp16.cpp). Average user time
over 100 compilations:
- Mean compile time before the changes: 37.734 s
- Mean compile time after the changes: 32.087 s
- Speedup: 17.6% (37.734 s / 32.087 s ≈ 1.176)

Ran a full CK compilation on Alola with the following results:

| Metric | Before (min) | After (min) | Absolute Reduction (min) | % Reduction |
| ------ | ------------ | ----------- | ------------------------ | ----------- |
2026-06-08 17:11:53 +00:00


// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/utility/common_header.hpp"
/**
* @file index_expression.hpp
* @brief Compile-time index expression evaluation utilities
* @details Provides types and operations for building and evaluating index expressions at compile
* time. Supports arithmetic operations (Add, Mult, Div, Mod) on compile-time constants and index
* variables. Expressions are built using template types and evaluated via the eval() function
* family.
*/
namespace ck::index_expression {
/**
* @brief Represents an index variable in a compile-time expression
* @details Used as a placeholder for index values that are determined at evaluation time.
* When an expression containing Ik is evaluated with eval<ik>(), the Ik type evaluates to the
* provided ik value.
*/
struct Ik
{
};
/**
* @brief Binary addition operation for compile-time index expressions
* @tparam L Left operand type (must be evaluable to index_t)
* @tparam R Right operand type (must be evaluable to index_t)
*/
template <typename L, typename R>
struct Add
{
};
/**
* @brief Binary multiplication operation for compile-time index expressions
* @tparam L Left operand type
* @tparam R Right operand type
*/
template <typename L, typename R>
struct Mult
{
};
/**
* @brief Binary division operation for compile-time index expressions
* @tparam L Left operand type
* @tparam R Right operand type
*/
template <typename L, typename R>
struct Div
{
};
/**
* @brief Binary modulo operation for compile-time index expressions
* @tparam L Left operand type
* @tparam R Right operand type
* @note Both operands must evaluate to integer types
*/
template <typename L, typename R>
struct Mod
{
};
/**
* @brief Evaluates a literal Number value
* @tparam ik The index variable value (unused for literal expressions)
* @tparam v The literal constant value
* @return The constant value v
*/
template <index_t ik, index_t v>
constexpr index_t eval(Number<v>)
{
    return v;
}
/**
* @brief Evaluates an index variable to its provided value
* @tparam ik The value to substitute for the index variable
* @return The index value ik
*/
template <index_t ik>
constexpr index_t eval(Ik)
{
    return ik;
}
/**
* @brief Evaluates an addition expression
* @tparam ik The index variable value
* @tparam L Type of left operand
* @tparam R Type of right operand
* @return Sum of eval(L) + eval(R)
*/
template <index_t ik, typename L, typename R>
constexpr index_t eval(Add<L, R>)
{
    return eval<ik>(L{}) + eval<ik>(R{});
}
/**
* @brief Evaluates a multiplication expression
* @tparam ik The index variable value
* @tparam L Type of left operand
* @tparam R Type of right operand
* @return Product of eval(L) * eval(R)
*/
template <index_t ik, typename L, typename R>
constexpr index_t eval(Mult<L, R>)
{
    return eval<ik>(L{}) * eval<ik>(R{});
}
/**
* @brief Evaluates a division expression
* @tparam ik The index variable value
* @tparam L Type of left operand
* @tparam R Type of right operand (must evaluate to non-zero)
* @return Quotient of eval(L) / eval(R)
*/
template <index_t ik, typename L, typename R>
constexpr index_t eval(Div<L, R>)
{
    constexpr index_t d = eval<ik>(R{});
    static_assert(d != 0,
                  "ck::index_expression::Div: division by zero in compile-time index expression");
    return eval<ik>(L{}) / d;
}
/**
* @brief Evaluates a modulo expression
* @tparam ik The index variable value
* @tparam L Type of left operand
* @tparam R Type of right operand
* @return Remainder of eval(L) % eval(R)
* @details Both operands must evaluate to integer index types
*/
template <index_t ik, typename L, typename R>
constexpr index_t eval(Mod<L, R>)
{
    using lhs_t = decltype(eval<ik>(L{}));
    using rhs_t = decltype(eval<ik>(R{}));
    static_assert(std::is_integral_v<lhs_t> && std::is_integral_v<rhs_t>,
                  "ck::index_expression::Mod: modulo requires integer operands");
    return eval<ik>(L{}) % eval<ik>(R{});
}
/**
* @brief Compile-time evaluation helper for index expressions
* @tparam Expr The expression to evaluate
* @tparam ik The index variable value
* @details Helper variable template that evaluates the given expression at compile time
* for the provided index value.
* @par Example
* @code
* using Expr = Add<Ik, Number<5>>;
* static_assert(eval_v<Expr, 3> == 8); // Evaluates to 3 + 5 = 8
* @endcode
*/
template <typename Expr, index_t ik>
inline constexpr index_t eval_v = eval<ik>(Expr{});
} // namespace ck::index_expression
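
The expression types in this header can be tried out without the CK headers. The following is a minimal standalone C++17 sketch, not the CK implementation: the `sketch` namespace, the `Offset` alias, and the local `index_t` and `Number` definitions are assumptions standing in for `ck::index_t` and `ck::Number`. It mirrors the `eval` overload set to show how a nested expression is evaluated for a given index value.

```cpp
#include <cstdint>

namespace sketch {

using index_t = std::int32_t; // assumption: stand-in for ck::index_t

template <index_t v>
struct Number {}; // assumption: stand-in for ck::Number<v>

// Expression node types, mirroring the header above.
struct Ik {};
template <typename L, typename R> struct Add {};
template <typename L, typename R> struct Mult {};
template <typename L, typename R> struct Div {};
template <typename L, typename R> struct Mod {};

// Leaves: a literal evaluates to itself, Ik evaluates to the supplied index.
template <index_t ik, index_t v>
constexpr index_t eval(Number<v>) { return v; }

template <index_t ik>
constexpr index_t eval(Ik) { return ik; }

// Binary nodes recurse into both operands with the same index value.
template <index_t ik, typename L, typename R>
constexpr index_t eval(Add<L, R>) { return eval<ik>(L{}) + eval<ik>(R{}); }

template <index_t ik, typename L, typename R>
constexpr index_t eval(Mult<L, R>) { return eval<ik>(L{}) * eval<ik>(R{}); }

template <index_t ik, typename L, typename R>
constexpr index_t eval(Div<L, R>)
{
    constexpr index_t d = eval<ik>(R{});
    static_assert(d != 0, "division by zero in compile-time index expression");
    return eval<ik>(L{}) / d;
}

template <index_t ik, typename L, typename R>
constexpr index_t eval(Mod<L, R>) { return eval<ik>(L{}) % eval<ik>(R{}); }

template <typename Expr, index_t ik>
inline constexpr index_t eval_v = eval<ik>(Expr{});

// Hypothetical expression: (ik / 4) * 16 + (ik % 4)
using Offset = Add<Mult<Div<Ik, Number<4>>, Number<16>>, Mod<Ik, Number<4>>>;

static_assert(eval_v<Offset, 10> == 34, "(10 / 4) * 16 + (10 % 4) = 2 * 16 + 2");
static_assert(eval_v<Add<Ik, Number<5>>, 3> == 8, "3 + 5");

} // namespace sketch
```

Because the expression is an ordinary named type, the same `Offset` alias can be reused wherever it is needed, and `eval_v<Offset, ik>` is instantiated once per distinct `ik`.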