mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-23 16:47:40 +00:00
[CK_TILE, CK_BUILDER] Add two-stage bwd weight kernels to CK Tile profiler (#5237) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation PR #4797 added CK Tile bwd weight kernels to the CK Profiler. The two-stage kernels were not supported in the initial PR. This PR adds the missing bwd weight two-stage kernels to the CK Profiler. ## Technical Details Extended the CK Tile conv builder factory to also build the elementwise ops required for the two-stage kernels. Extended the CK Builder for CK Tile instance to accept the two-stage flag as part of the algorithm configuration. ## Test Plan Added unit tests for CK Builder that verify the two-stage kernel construction. ## Test Result If CI passes, the added unit tests are passing. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
41 lines
1.3 KiB
C++
41 lines
1.3 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include "ck_tile/core/utility/type_traits.hpp"
|
|
#include "ck_tile/host/concat.hpp"
|
|
|
|
namespace ck_tile {
|
|
|
|
template <typename BlockWarps, typename BlockTile, typename WarpTile, typename ComputeDataType>
|
|
struct ElementWiseShape
|
|
{
|
|
static constexpr index_t kBlockM = BlockTile::at(number<0>{});
|
|
|
|
static constexpr index_t kWarpM = WarpTile::at(number<0>{});
|
|
|
|
static constexpr index_t kVectorM =
|
|
min(static_cast<index_t>(16 / sizeof(ComputeDataType)), kWarpM / get_warp_size());
|
|
|
|
static constexpr index_t kWarpPerBlockM = BlockWarps::at(number<0>{});
|
|
|
|
static constexpr index_t kThreadPerWarpM = get_warp_size();
|
|
|
|
static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kVectorM * kThreadPerWarpM);
|
|
|
|
static constexpr index_t kBlockSize =
|
|
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
|
|
|
|
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
|
{
|
|
// clang-format off
|
|
return concat('_', "shape",
|
|
kBlockM, kWarpM, kVectorM, kWarpPerBlockM, kThreadPerWarpM, kRepeatM, kBlockSize
|
|
);
|
|
// clang-format on
|
|
}
|
|
};
|
|
|
|
} // namespace ck_tile
|