mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Add a sample to do element_wise add for 3D inputs
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
set(EXAMPLE_REDUCE "add")
|
||||
set(EXAMPLE_REDUCE "add_3D")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding example ${EXAMPLE_REDUCE}")
|
||||
|
||||
add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL add.cpp)
|
||||
add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL add_3D.cpp)
|
||||
target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
set(EXAMPLE_REDUCE_COMPILE_OPTIONS)
|
||||
|
||||
|
||||
117
example/ck_tile/99_toy_example/01_add/add_3D.cpp
Normal file
117
example/ck_tile/99_toy_example/01_add/add_3D.cpp
Normal file
@@ -0,0 +1,117 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "reference_add_3D.hpp"
|
||||
#include "add_3D.hpp"
|
||||
#include <cstring>
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("b", "4", "b dimension")
|
||||
.insert("m", "10240", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "200", "cold iter")
|
||||
.insert("repeat", "1000", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using XDataType = DataType;
|
||||
using ComputeDataType = float;
|
||||
using YDataType = DataType;
|
||||
|
||||
ck_tile::index_t b = arg_parser.get_int("b");
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host_a({b, m, n});
|
||||
ck_tile::HostTensor<XDataType> x_host_b({b, m, n});
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({b, m, n});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({b, m, n});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host_a);
|
||||
ck_tile::FillUniformDistribution<XDataType>{-5.f, 5.f}(x_host_b);
|
||||
|
||||
ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf_a.ToDevice(x_host_a.data());
|
||||
x_buf_b.ToDevice(x_host_b.data());
|
||||
|
||||
using BlockWarps =
|
||||
ck_tile::sequence<1, 8>; // number of concurrent warps in one block (if 8 warps * 64 threads
|
||||
// per warp, 512 threads in one block are NEEDED)
|
||||
using BlockTile =
|
||||
ck_tile::sequence<1, 4096>; // shape of one blockTile (elements covered by one block)
|
||||
using WarpTile = ck_tile::sequence<1, 512>; // shape of one warpTile (elements covered by one
|
||||
// warp (64 threads))
|
||||
using Vector = ck_tile::sequence<1, 8>; // shape of one vector (elements covered by one thread)
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize =
|
||||
512; // number of blockWarps * number of threads per warp
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
ck_tile::index_t kGridSize = (b * m / BlockTile::at(ck_tile::number<0>{}));
|
||||
std::cout << "block x-size = " << BlockTile::at(ck_tile::number<0>{}) << std::endl;
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
using Shape = ck_tile::AddShape<BlockWarps, BlockTile, WarpTile, Vector>;
|
||||
using Porblem = ck_tile::AddProblem<XDataType, ComputeDataType, YDataType, Shape>;
|
||||
|
||||
using Kernel = ck_tile::Add<Porblem>;
|
||||
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()),
|
||||
static_cast<XDataType*>(x_buf_b.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
b,
|
||||
m,
|
||||
n));
|
||||
|
||||
std::size_t num_btype = 2 * sizeof(XDataType) * b * m * n + sizeof(YDataType) * b * m * n;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_add<XDataType, YDataType>(x_host_a, x_host_b, y_host_ref);
|
||||
y_buf.FromDevice(y_host_dev.mData.data());
|
||||
pass = ck_tile::check_err(y_host_dev, y_host_ref);
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
177
example/ck_tile/99_toy_example/01_add/add_3D.hpp
Normal file
177
example/ck_tile/99_toy_example/01_add/add_3D.hpp
Normal file
@@ -0,0 +1,177 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockWarps, // num warps along seq<M, N>
|
||||
typename BlockTile, // block size, seq<M, N>
|
||||
typename WarpTile, // warp size, seq<M, N>
|
||||
typename Vector> // contiguous pixels(vector size) along seq<M, N>
|
||||
struct AddShape
|
||||
{
|
||||
static constexpr index_t Block_M = BlockTile::at(number<0>{}); // elements along M in one Block
|
||||
static constexpr index_t Block_N = BlockTile::at(number<1>{}); // elements along N in one Block
|
||||
|
||||
static constexpr index_t Warp_M = WarpTile::at(number<0>{}); // elements along M in one Warp
|
||||
static constexpr index_t Warp_N = WarpTile::at(number<1>{}); // elements along N in one Warp
|
||||
|
||||
static constexpr index_t Vector_M = Vector::at(number<0>{}); // elements along M in one Vector
|
||||
static constexpr index_t Vector_N = Vector::at(number<1>{}); // elements along N in one Vector
|
||||
|
||||
static constexpr index_t WarpPerBlock_M =
|
||||
BlockWarps::at(number<0>{}); // num concurrent warps along M
|
||||
static constexpr index_t WarpPerBlock_N =
|
||||
BlockWarps::at(number<1>{}); // num concurrent warps along N
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M =
|
||||
Warp_M /
|
||||
Vector_M; // num threads along M in one Warp (ThreadPerWarp_M * ThreadPerWarp_N must be 64)
|
||||
static constexpr index_t ThreadPerWarp_N =
|
||||
Warp_N /
|
||||
Vector_N; // num threads along N in one Warp (ThreadPerWarp_M * ThreadPerWarp_N must be 64)
|
||||
|
||||
static constexpr index_t Repeat_M =
|
||||
Block_M /
|
||||
(WarpPerBlock_M *
|
||||
Warp_M); // num of time a warp iterates along M to ensure the entire block is covered
|
||||
static constexpr index_t Repeat_N =
|
||||
Block_N /
|
||||
(WarpPerBlock_N *
|
||||
Warp_N); // num of time a warp iterates along N to ensure the entire block is covered
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
warpSize *
|
||||
reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); // num of threads in one block
|
||||
};
|
||||
|
||||
template <typename XDataType_, typename ComputeDataType_, typename YDataType_, typename BlockShape_>
|
||||
struct AddProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>; // data type of input tensor
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>; // data type of compute tensor
|
||||
using YDataType = remove_cvref_t<YDataType_>; // data type of output tensor
|
||||
using BlockShape = remove_cvref_t<BlockShape_>; // block shapes and sizes
|
||||
};
|
||||
|
||||
struct AddDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Repeat_M,
|
||||
S::WarpPerBlock_M,
|
||||
S::ThreadPerWarp_M,
|
||||
S::Vector_M>, // how many sub division is a block divided in
|
||||
sequence<S::Repeat_N,
|
||||
S::WarpPerBlock_N,
|
||||
S::ThreadPerWarp_N,
|
||||
S::Vector_N>>, // how many sub division is a block divided in
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>, // What are the shapes of those sub divisions
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>, // What are the shapes of those sub divisions
|
||||
sequence<1, 1, 2, 2>, // How much data does a thread work on and how many iterations
|
||||
// of warps are there
|
||||
sequence<0, 3, 0, 3>>{}); // How much data does a thread work on and how many
|
||||
// iterations of warps are there
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = AddDefaultPolicy>
|
||||
struct Add
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(
|
||||
const XDataType* p_x_a, const XDataType* p_x_b, YDataType* p_y, index_t B, index_t M, index_t N) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Create flattened 2D view by combining B and M dimensions
|
||||
const index_t M_flattened = B * M;
|
||||
|
||||
const auto x_m_n_a = make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum::slc>(
|
||||
p_x_a,
|
||||
make_tuple(M_flattened, N),
|
||||
make_tuple(N, 1),
|
||||
number<S::Vector_N>{},
|
||||
number<1>{}); // raw data, shape of tensor, stride of tensor, lastGarunteedVectorLength,
|
||||
// lastGarunteedVectorStride
|
||||
|
||||
const auto x_m_n_b = make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum::slc>(
|
||||
p_x_b,
|
||||
make_tuple(M_flattened, N),
|
||||
make_tuple(N, 1),
|
||||
number<S::Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
const auto y_m_n = make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum::slc>(
|
||||
p_y,
|
||||
make_tuple(M_flattened, N),
|
||||
make_tuple(N, 1),
|
||||
number<S::Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
const auto iM = get_block_id() * S::Block_M; // origin of the block along
|
||||
|
||||
auto x_window_a = make_tile_window(x_m_n_a,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto x_window_b = make_tile_window(x_m_n_b,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto y_window = make_tile_window(y_m_n,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N));
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto xa = load_tile(x_window_a);
|
||||
const auto xb = load_tile(x_window_b);
|
||||
auto y_compute = load_tile(y_window);
|
||||
|
||||
constexpr auto spans = decltype(xa)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
const auto x = ck_tile::type_convert<ComputeDataType>(xa[i_j_idx]);
|
||||
const auto y = ck_tile::type_convert<ComputeDataType>(xb[i_j_idx]);
|
||||
y_compute(i_j_idx) = x + y;
|
||||
});
|
||||
});
|
||||
|
||||
store_tile(y_window, cast_tile<YDataType>(y_compute));
|
||||
move_tile_window(x_window_a, {0, S::Block_N});
|
||||
move_tile_window(x_window_b, {0, S::Block_N});
|
||||
move_tile_window(y_window, {0, S::Block_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,206 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockWarps, // num warps along seq<M, N>
|
||||
typename BlockTile, // block size, seq<M, N>
|
||||
typename WarpTile, // warp size, seq<M, N>
|
||||
typename Vector> // contiguous pixels(vector size) along seq<M, N>
|
||||
struct AddShape
|
||||
{
|
||||
static constexpr index_t Block_M = BlockTile::at(number<0>{}); // elements along M in one Block
|
||||
static constexpr index_t Block_N = BlockTile::at(number<1>{}); // elements along N in one Block
|
||||
|
||||
static constexpr index_t Warp_M = WarpTile::at(number<0>{}); // elements along M in one Warp
|
||||
static constexpr index_t Warp_N = WarpTile::at(number<1>{}); // elements along N in one Warp
|
||||
|
||||
static constexpr index_t Vector_M = Vector::at(number<0>{}); // elements along M in one Vector
|
||||
static constexpr index_t Vector_N = Vector::at(number<1>{}); // elements along N in one Vector
|
||||
|
||||
static constexpr index_t WarpPerBlock_M =
|
||||
BlockWarps::at(number<0>{}); // num concurrent warps along M
|
||||
static constexpr index_t WarpPerBlock_N =
|
||||
BlockWarps::at(number<1>{}); // num concurrent warps along N
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M =
|
||||
Warp_M /
|
||||
Vector_M; // num threads along M in one Warp (ThreadPerWarp_M * ThreadPerWarp_N must be 64)
|
||||
static constexpr index_t ThreadPerWarp_N =
|
||||
Warp_N /
|
||||
Vector_N; // num threads along N in one Warp (ThreadPerWarp_M * ThreadPerWarp_N must be 64)
|
||||
|
||||
static constexpr index_t Repeat_M =
|
||||
Block_M /
|
||||
(WarpPerBlock_M *
|
||||
Warp_M); // num of time a warp iterates along M to ensure the entire block is covered
|
||||
static constexpr index_t Repeat_N =
|
||||
Block_N /
|
||||
(WarpPerBlock_N *
|
||||
Warp_N); // num of time a warp iterates along N to ensure the entire block is covered
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
warpSize *
|
||||
reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); // num of threads in one block
|
||||
};
|
||||
|
||||
template <typename XDataType_, typename ComputeDataType_, typename YDataType_, typename BlockShape_>
|
||||
struct AddProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>; // data type of input tensor
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>; // data type of compute tensor
|
||||
using YDataType = remove_cvref_t<YDataType_>; // data type of output tensor
|
||||
using BlockShape = remove_cvref_t<BlockShape_>; // block shapes and sizes
|
||||
};
|
||||
|
||||
struct AddDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<S::Repeat_M,
|
||||
S::WarpPerBlock_M,
|
||||
S::ThreadPerWarp_M,
|
||||
S::Vector_M>, // how many sub division is a block divided in
|
||||
sequence<S::Repeat_N,
|
||||
S::WarpPerBlock_N,
|
||||
S::ThreadPerWarp_N,
|
||||
S::Vector_N>>, // how many sub division is a block divided in
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>, // What are the shapes of those sub divisions
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>, // What are the shapes of those sub divisions
|
||||
sequence<1, 1, 2, 2>, // How much data does a thread work on and how many iterations
|
||||
// of warps are there
|
||||
sequence<0, 3, 0, 3>>{}); // How much data does a thread work on and how many
|
||||
// iterations of warps are there
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = AddDefaultPolicy>
|
||||
struct Add
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(
|
||||
const XDataType* p_x_a, const XDataType* p_x_b, YDataType* p_y, index_t B, index_t M, index_t N) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Create 3D tensor views first
|
||||
const auto x_b_m_n_a = make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum::slc>(
|
||||
p_x_a,
|
||||
make_tuple(B, M, N),
|
||||
make_tuple(M * N, N, 1),
|
||||
number<S::Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
const auto x_b_m_n_b = make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum::slc>(
|
||||
p_x_b,
|
||||
make_tuple(B, M, N),
|
||||
make_tuple(M * N, N, 1),
|
||||
number<S::Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
const auto y_b_m_n = make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
amd_buffer_coherence_enum::slc>(
|
||||
p_y,
|
||||
make_tuple(B, M, N),
|
||||
make_tuple(M * N, N, 1),
|
||||
number<S::Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
// Now transform the 3D tensor views to 2D using make_merge_transform
|
||||
// This merges the B and M dimensions
|
||||
const auto x_m_n_a = transform_tensor_descriptor(
|
||||
x_b_m_n_a,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<B>{}, number<M>{})),
|
||||
make_pass_through_transform(number<N>{})
|
||||
),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto x_m_n_b = transform_tensor_descriptor(
|
||||
x_b_m_n_b,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<B>{}, number<M>{})),
|
||||
make_pass_through_transform(number<N>{})
|
||||
),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto y_m_n = transform_tensor_descriptor(
|
||||
y_b_m_n,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<B>{}, number<M>{})),
|
||||
make_pass_through_transform(number<N>{})
|
||||
),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
|
||||
|
||||
const auto iM = get_block_id() * S::Block_M; // origin of the block along
|
||||
|
||||
auto x_window_a = make_tile_window(x_m_n_a,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto x_window_b = make_tile_window(x_m_n_b,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
auto y_window = make_tile_window(y_m_n,
|
||||
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
|
||||
{iM, 0},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N));
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
|
||||
{
|
||||
const auto xa = load_tile(x_window_a);
|
||||
const auto xb = load_tile(x_window_b);
|
||||
auto y_compute = load_tile(y_window);
|
||||
|
||||
constexpr auto spans = decltype(xa)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
|
||||
const auto x = ck_tile::type_convert<ComputeDataType>(xa[i_j_idx]);
|
||||
const auto y = ck_tile::type_convert<ComputeDataType>(xb[i_j_idx]);
|
||||
y_compute(i_j_idx) = x + y;
|
||||
});
|
||||
});
|
||||
|
||||
store_tile(y_window, cast_tile<YDataType>(y_compute));
|
||||
move_tile_window(x_window_a, {0, S::Block_N});
|
||||
move_tile_window(x_window_b, {0, S::Block_N});
|
||||
move_tile_window(y_window, {0, S::Block_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
47
example/ck_tile/99_toy_example/01_add/reference_add_3D.hpp
Normal file
47
example/ck_tile/99_toy_example/01_add/reference_add_3D.hpp
Normal file
@@ -0,0 +1,47 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType, typename YDataType>
|
||||
CK_TILE_HOST void reference_add(const HostTensor<XDataType>& xa_b_m_n,
|
||||
const HostTensor<XDataType>& xb_b_m_n,
|
||||
HostTensor<YDataType>& y_b_m_n)
|
||||
{
|
||||
auto f = [&](auto bm_idx) {
|
||||
|
||||
// Calculate batch and m indices
|
||||
// const int B = xa_b_m_n.mDesc.get_lengths()[0];
|
||||
const int M = xa_b_m_n.mDesc.get_lengths()[1];
|
||||
const int N = xa_b_m_n.mDesc.get_lengths()[2];
|
||||
|
||||
// Convert flat bm_idx to separate b and m indices
|
||||
const int b = bm_idx / M;
|
||||
const int m = bm_idx % M;
|
||||
|
||||
// Process each element in the N dimension
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
y_b_m_n(b, m, n) = ck_tile::type_convert<YDataType>(xa_b_m_n(b, m, n)) +
|
||||
ck_tile::type_convert<YDataType>(xb_b_m_n(b, m, n));
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
|
||||
// Get total elements to process in the B and M dimensions
|
||||
const int total_bm = y_b_m_n.mDesc.get_lengths()[0] * y_b_m_n.mDesc.get_lengths()[1];
|
||||
|
||||
// Parallelize computation across the flattened B×M space
|
||||
make_ParallelTensorFunctor(f, total_bm)(std::thread::hardware_concurrency());
|
||||
|
||||
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user