composable_kernel/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp
Ville Pietilä 5d6e69194d [CK_TILE, CK_BUILDER] Add two-stage bwd weight kernels to CK Tile profiler (#5237)
## Motivation

PR #4797 added CK Tile bwd weight kernels to the CK Profiler, but the
two-stage kernels were not supported in that initial PR. This PR adds
the missing bwd weight two-stage kernels to the CK Profiler.

## Technical Details

Extended the CK Tile conv builder factory to also build the elementwise
ops required for the two-stage kernels. Extended the CK Builder for CK
Tile instances to accept the two-stage flag as part of the algorithm
configuration.

## Test Plan

Added unit tests for CK Builder that verify the two-stage kernel
construction.
  
## Test Result

If CI passes, the added unit tests are passing.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Ville Pietilä <>
2026-03-12 19:20:15 -06:00


// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <string>

#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/host/concat.hpp"

namespace ck_tile {

template <typename BlockWarps, typename BlockTile, typename WarpTile, typename ComputeDataType>
struct ElementWiseShape
{
    // Tile extents along M, taken from the first entry of each sequence.
    static constexpr index_t kBlockM = BlockTile::at(number<0>{});
    static constexpr index_t kWarpM  = WarpTile::at(number<0>{});

    // Elements per vector: at most 16 bytes' worth of ComputeDataType,
    // and at most kWarpM divided by the warp size.
    static constexpr index_t kVectorM =
        min(static_cast<index_t>(16 / sizeof(ComputeDataType)), kWarpM / get_warp_size());

    static constexpr index_t kWarpPerBlockM  = BlockWarps::at(number<0>{});
    static constexpr index_t kThreadPerWarpM = get_warp_size();

    // kBlockM divided by the number of elements covered by
    // kWarpPerBlockM * kVectorM * kThreadPerWarpM.
    static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kVectorM * kThreadPerWarpM);

    // Threads per block: warp size times the product of all BlockWarps entries.
    static constexpr index_t kBlockSize =
        ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "shape",
                      kBlockM, kWarpM, kVectorM, kWarpPerBlockM, kThreadPerWarpM, kRepeatM, kBlockSize
        );
        // clang-format on
    }
};

} // namespace ck_tile
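
The shape arithmetic above can be checked without the CK Tile headers. The following is a minimal standalone sketch, not part of the library: `ElementWiseShapeSketch`, `min_of`, and `kAssumedWarpSize` are hypothetical names, the warp size of 64 is an assumption standing in for `get_warp_size()`, and plain `int` template parameters replace the `BlockWarps`, `BlockTile`, and `WarpTile` sequences (so only a single M dimension is modeled).

```cpp
#include <cstdint>

// Assumed stand-in for ck_tile::get_warp_size(); the real value depends on the target.
constexpr int kAssumedWarpSize = 64;

// Small by-value helper so the sketch does not depend on ck_tile::min.
constexpr int min_of(int a, int b) { return a < b ? a : b; }

// Hypothetical mirror of ElementWiseShape for a single M dimension.
template <int BlockWarpsM, int BlockTileM, int WarpTileM, typename ComputeDataType>
struct ElementWiseShapeSketch
{
    static constexpr int kBlockM = BlockTileM;
    static constexpr int kWarpM  = WarpTileM;

    // At most 16 bytes' worth of elements, and at most kWarpM / warp size.
    static constexpr int kVectorM =
        min_of(static_cast<int>(16 / sizeof(ComputeDataType)), WarpTileM / kAssumedWarpSize);

    static constexpr int kWarpPerBlockM  = BlockWarpsM;
    static constexpr int kThreadPerWarpM = kAssumedWarpSize;

    static constexpr int kRepeatM =
        BlockTileM / (BlockWarpsM * kVectorM * kAssumedWarpSize);

    // With a one-dimensional BlockWarps, the product of its entries is BlockWarpsM.
    static constexpr int kBlockSize = kAssumedWarpSize * BlockWarpsM;
};

// 4 warps, block tile of 2048, warp tile of 256, 4-byte compute type.
using FloatShape = ElementWiseShapeSketch<4, 2048, 256, float>;
static_assert(FloatShape::kVectorM == 4, "min(16 / 4, 256 / 64)");
static_assert(FloatShape::kRepeatM == 2, "2048 / (4 * 4 * 64)");
static_assert(FloatShape::kBlockSize == 256, "64 * 4");
```

Under these assumed values, a 4-byte compute type gives a vector width of 4, each block repeats twice to cover its 2048-element tile, and the block runs 256 threads. Switching to a 2-byte type only widens the vector if the warp tile is also large enough, because `kVectorM` takes the smaller of the two bounds.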