mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Using named functors instead of lambdas ## Motivation Currently, in block-level GEMM pipelines, there is significant code repetition for prefetching and tail handling, where lambda functions create unique instantiations at each call. This includes repeated static_for instantiations and large loops such as MRepeat. Each repetition results in additional instantiations, which increases compilation time and binary bloat. ## Technical Details Refactor repeated code blocks into named functors so the compiler can reuse already instantiated code instead of generating multiple copies. Scope of changes: 1. WMMAOPS pipeline internals: projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp, projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp, projects/composablekernel/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp 2. XDLOPS and preshuffle pipeline variants across projects/composablekernel/include/ck/tensor_operation/gpu/block (v1/v2/v3/v4/v5, scale, dequant, gufusion, moe, mx, blockscale, skip-b-lds, dpp, xdlops) Shared functor file: projects/composablekernel/include/ck/utility/vector_load_functor.hpp ## Test Plan Note that the compilation traces provided by -ftime-trace do not report unnamed lambda instantiations, so a clear baseline for instantiation counts cannot be established. As a result, the impact of this change will be evaluated based on runtime performance rather than direct instantiation-count comparisons. ## Test Result The effects of this were timed by the compilation of a single HIP object through an example (grouped_gemm_wmma_splitk_fp16.cpp).
The average user time and speedup of this using the average of 100 compilations is: - Mean compile time before the changes: 37.734 s - Mean compile time after: 32.087 s - Speedup: 17.6% Ran a full CK compilation on Alola with the following results: | Metric | Before (min) | After (min) | Absolute Reduction (min) | % Reduction | | ------ | ------------ | ----------- | ------------------------ |
54 lines
1.8 KiB
C++
54 lines
1.8 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
#include "ck/utility/index_expression.hpp"
|
|
|
|
using namespace ck;
|
|
|
|
/**
|
|
* Test basic evaluation of literal values and index variables
|
|
* - Number<7> should evaluate to the literal constant 7 regardless of index value
|
|
* - Ik (index variable) should evaluate to the provided index value
|
|
*/
|
|
TEST(IndexExpression, EvalLiteralAndIk)
|
|
{
|
|
EXPECT_EQ((index_expression::eval_v<Number<7>, 3>), 7);
|
|
EXPECT_EQ((index_expression::eval_v<Number<7>, 5>), 7);
|
|
|
|
EXPECT_EQ((index_expression::eval_v<index_expression::Ik, 3>), 3);
|
|
EXPECT_EQ((index_expression::eval_v<index_expression::Ik, 7>), 7);
|
|
}
|
|
|
|
/**
|
|
* Test arithmetic operations with index expressions
|
|
*/
|
|
TEST(IndexExpression, EvalAddMultDivMod)
|
|
{
|
|
|
|
using ExprAdd = index_expression::Add<index_expression::Ik, Number<5>>;
|
|
using ExprMult = index_expression::Mult<ExprAdd, Number<2>>;
|
|
using ExprDiv = index_expression::Div<ExprMult, Number<4>>;
|
|
using ExprMod = index_expression::Mod<ExprMult, Number<3>>;
|
|
|
|
EXPECT_EQ((index_expression::eval_v<ExprAdd, 3>), 8);
|
|
EXPECT_EQ((index_expression::eval_v<ExprMult, 3>), 16);
|
|
EXPECT_EQ((index_expression::eval_v<ExprDiv, 3>), 4);
|
|
EXPECT_EQ((index_expression::eval_v<ExprMod, 3>), 1);
|
|
}
|
|
|
|
/**
|
|
* Test nested compound expressions to verify proper precedence and composition
|
|
*/
|
|
TEST(IndexExpression, EvalNestedExpression)
|
|
{
|
|
// Build nested expression: (ik + (2 * 5)) / 2
|
|
using InnerMult = index_expression::Mult<Number<2>, Number<5>>;
|
|
using InnerAdd = index_expression::Add<index_expression::Ik, InnerMult>;
|
|
using Expr = index_expression::Div<InnerAdd, Number<2>>;
|
|
|
|
// With ik=6: ((6 + (2*5)) / 2) = 8
|
|
EXPECT_EQ((index_expression::eval_v<Expr, 6>), 8);
|
|
}
|