mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
* Fix a typo
* Use std::variant to call run_gemm_example_with_layouts with the available layout variant combinations
* Use a unified run_gemm_example_prec_type for basic gemm and universal gemm
* Factor out run_gemm_example_prec_type
* Refactor argument parsing in gemm_splitk_two_stage_reduce.cpp
* Parse arguments outside of create_args
* Move the gemm operators to separate structs to facilitate their reuse
* Move the invokers to separate files to facilitate their reuse
* Rename the invoker files for consistency with the examples that use them
* Add fp32 support to the elementwise examples, and produce an error message for unsupported types
* Get rid of four unused variables
* Make two variables const
* Add support for different input-output type combinations in elementwise examples
* Test support for different input and output types in elementwise examples
* Add support for different operations in the elementwise unary tests
* Add support for UnaryConvert in the elementwise unary tests
* Add support for bf16 in elementwise examples, excluding unsupported type combinations
* Make some operator parameters const in ElementWiseKernel
* Remove some unnecessary include statements
* Implement a two-stage GEMM that does a type conversion in the second stage using the elementwise kernel
* Clear workspace instead of output when flushing the cache in SplitKTwoStageInvoker::gemm
* Fix formatting issues reported by clang
* Add back CK_TILE_USE_WMMA related changes
* Use the right prec type for bf16 in the universal GEMM and two stage split K examples
* Add some brackets
* Add some brackets
* Separate the clearing of the GEMM output memory from the cache flushing in the universal GEMM example
* Separate the clearing of the GEMM output memory from the cache flushing in the split K two stage example
* Fix formatting
* No need to call SetZero on ws_m_n_dev_buf here, as clear_gemm_output now does this as part of the kernel preprocessing
* Add fp16 data type to splitk two stage example
* Add preprocessing with optional cache flushing and clearing of output for k_batch > 1 to the basic GEMM example
126 lines
5.3 KiB
C++
126 lines
5.3 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/ops/common.hpp"
|
|
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
|
|
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
|
|
namespace ck_tile {
|
|
|
|
template <typename Problem_, typename Policy_>
|
|
struct ElementWiseKernel
|
|
{
|
|
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
|
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
|
|
|
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
|
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
|
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
|
using ElementWiseOperation = ck_tile::remove_cvref_t<typename Problem::ElementWiseOperation>;
|
|
|
|
static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize;
|
|
|
|
template <typename... XDataType, typename Dims>
|
|
CK_TILE_DEVICE void operator()(const Dims lens,
|
|
const Dims input_strides,
|
|
const Dims output_strides,
|
|
const tuple<XDataType...>& input_tensors,
|
|
YDataType* p_y) const
|
|
{
|
|
using S = typename Problem::BlockShape;
|
|
|
|
// Setup block-level coordinates and transforms
|
|
const index_t iM = get_block_id() * S::kBlockM;
|
|
const auto merge_transform = make_merge_transform(lens);
|
|
|
|
// Load all input tiles into registers.
|
|
// The lambda structure here is intended to minimize the lifetime
|
|
// of intermediate objects (views, windows) used for loading.
|
|
const auto x_tiles = ck_tile::generate_tuple(
|
|
[&](auto i) {
|
|
const auto tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
|
input_tensors.get(i), lens, input_strides, number<S::kVectorM>{}, number<1>{});
|
|
|
|
const auto transformed_tensor = pad_tensor_view(
|
|
transform_tensor_view(tensor_view,
|
|
ck_tile::make_tuple(merge_transform),
|
|
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
|
|
ck_tile::make_tuple(sequence<0>{})),
|
|
ck_tile::make_tuple(number<S::kBlockM>{}),
|
|
sequence<Problem::kPad>{});
|
|
|
|
const auto x_window =
|
|
make_tile_window(transformed_tensor,
|
|
ck_tile::make_tuple(number<S::kBlockM>{}),
|
|
{iM},
|
|
Policy::template MakeXBlockTileDistribution<Problem>());
|
|
|
|
return load_tile(x_window);
|
|
},
|
|
number<sizeof...(XDataType)>{});
|
|
|
|
// Setup output tile in registers.
|
|
const auto& x_tile0 = x_tiles.get(number<0>{});
|
|
auto y_tile = make_static_distributed_tensor<YDataType>(x_tile0.get_tile_distribution());
|
|
|
|
// Perform element-wise computation.
|
|
const auto spans = x_tile0.get_distributed_spans();
|
|
sweep_tile_span(spans[number<0>{}], [&](auto idx) {
|
|
const auto tile_idx = make_tuple(idx);
|
|
apply(
|
|
[&](auto&&... tiles) {
|
|
ElementWiseOperation{}(y_tile(tile_idx),
|
|
type_convert<ComputeDataType>(tiles[tile_idx])...);
|
|
},
|
|
x_tiles);
|
|
});
|
|
|
|
// Setup output window and store the result tile.
|
|
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
|
|
p_y, lens, output_strides, number<S::kVectorM>{});
|
|
|
|
const auto transformed_y_m_n = pad_tensor_view(
|
|
transform_tensor_view(y_m_n,
|
|
ck_tile::make_tuple(merge_transform),
|
|
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
|
|
ck_tile::make_tuple(sequence<0>{})),
|
|
ck_tile::make_tuple(number<S::kBlockM>{}),
|
|
sequence<Problem::kPad>{});
|
|
|
|
auto y_window = make_tile_window(transformed_y_m_n,
|
|
make_tuple(number<S::kBlockM>{}),
|
|
{iM},
|
|
y_tile.get_tile_distribution());
|
|
|
|
store_tile(y_window, cast_tile<YDataType>(y_tile));
|
|
}
|
|
|
|
template <typename... Ints>
|
|
CK_TILE_HOST static bool IsSupportedArgument(const ck_tile::tuple<Ints...>& input_sizes)
|
|
{
|
|
int total_elements = 1;
|
|
const auto kVectorM = Problem_::BlockShape::kVectorM;
|
|
|
|
apply([&](auto&&... args) { ((total_elements *= args), ...); }, input_sizes);
|
|
|
|
if((total_elements % kVectorM) != 0)
|
|
{
|
|
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
|
{
|
|
CK_TILE_ERROR("Conditions not met: total number of input elements (",
|
|
total_elements,
|
|
") should be multiple of the vectorization size (",
|
|
kVectorM,
|
|
")");
|
|
}
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
};
|
|
|
|
} // namespace ck_tile
|