Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-05 22:22:27 +00:00)

275 include/ck_tile/host/reference/reference_batched_contraction.hpp (new file)
@@ -0,0 +1,275 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <cstdlib>
#include <iostream>
#include <thread>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

// Helper to apply elementwise operation with variable number of D tensors
template <typename EDataType, typename AccDataType, typename CDEElementWise>
struct ApplyCDEElementWise
{
    template <typename... DValues>
    CK_TILE_HOST_DEVICE static void apply(EDataType& result,
                                          AccDataType sum,
                                          const CDEElementWise& cde_elementwise,
                                          DValues... d_vals)
    {
        if constexpr(sizeof...(DValues) == 0)
        {
            result = static_cast<EDataType>(sum);
        }
        else
        {
            cde_elementwise(
                result, ck_tile::type_convert<float>(sum), ck_tile::type_convert<float>(d_vals)...);
        }
    }
};

// Helper to extract D values at a given offset using index sequence
template <typename DDataType,
          ck_tile::index_t NumDTensor,
          typename Indices = std::make_index_sequence<NumDTensor>>
struct ExtractDValues;

template <typename DDataType, ck_tile::index_t NumDTensor, std::size_t... Is>
struct ExtractDValues<DDataType, NumDTensor, std::index_sequence<Is...>>
{
    template <typename EDataType, typename AccDataType, typename CDEElementWise>
    CK_TILE_HOST static void
    apply_at_offsets(EDataType& result,
                     AccDataType sum,
                     const CDEElementWise& cde_elementwise,
                     const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_tensors,
                     const std::array<std::size_t, NumDTensor>& d_offsets)
    {
        ApplyCDEElementWise<EDataType, AccDataType, CDEElementWise>::apply(
            result, sum, cde_elementwise, ds_tensors[Is].mData[d_offsets[Is]]...);
    }
};

template <typename ADataType,
          typename BDataType,
          typename DDataType,
          typename EDataType,
          typename AccDataType,
          typename CDEElementWise,
          ck_tile::index_t NumDTensor>
void compute_reference_batched_contraction(
    const ck_tile::HostTensor<ADataType>& a_full_dims,
    const ck_tile::HostTensor<BDataType>& b_full_dims,
    const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_full_dims_host,
    ck_tile::HostTensor<EDataType>& e_full_dims_host_ref,
    ck_tile::index_t G_total,
    ck_tile::index_t M_total,
    ck_tile::index_t N_total,
    ck_tile::index_t K_total,
    const CDEElementWise& cde_elementwise,
    const std::vector<ck_tile::index_t>& G_dims,
    const std::vector<ck_tile::index_t>& M_dims,
    const std::vector<ck_tile::index_t>& N_dims,
    const std::vector<ck_tile::index_t>& K_dims)
{
    std::cout << "Calculating reference using stride-aware indexing with parallel processing..."
              << std::endl;

    // Extract stride information from tensor descriptors
    const auto a_strides = a_full_dims.get_strides();
    const auto b_strides = b_full_dims.get_strides();
    const auto e_strides = e_full_dims_host_ref.get_strides();

    // Extract D tensor strides
    std::array<std::vector<std::size_t>, NumDTensor> ds_strides;
    for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
    {
        ds_strides[d] = ds_full_dims_host[d].get_strides();
    }

    const ck_tile::index_t num_g_dims = G_dims.size();
    const ck_tile::index_t num_m_dims = M_dims.size();
    const ck_tile::index_t num_n_dims = N_dims.size();
    const ck_tile::index_t num_k_dims = K_dims.size();

    // Helper lambda to compute linear index from flat indices using strides
    auto compute_a_offset = [&](ck_tile::index_t g_flat,
                                ck_tile::index_t m_flat,
                                ck_tile::index_t k_flat) -> std::size_t {
        std::size_t offset = 0;

        // Decode G dimensions
        ck_tile::index_t temp = g_flat;
        for(int i = num_g_dims - 1; i >= 0; --i)
        {
            offset += (temp % G_dims[i]) * a_strides[i];
            temp /= G_dims[i];
        }

        // Decode M dimensions
        temp = m_flat;
        for(int i = num_m_dims - 1; i >= 0; --i)
        {
            offset += (temp % M_dims[i]) * a_strides[num_g_dims + i];
            temp /= M_dims[i];
        }

        // Decode K dimensions
        temp = k_flat;
        for(int i = num_k_dims - 1; i >= 0; --i)
        {
            offset += (temp % K_dims[i]) * a_strides[num_g_dims + num_m_dims + i];
            temp /= K_dims[i];
        }

        return offset;
    };

    auto compute_b_offset = [&](ck_tile::index_t g_flat,
                                ck_tile::index_t n_flat,
                                ck_tile::index_t k_flat) -> std::size_t {
        std::size_t offset = 0;

        // Decode G dimensions
        ck_tile::index_t temp = g_flat;
        for(int i = num_g_dims - 1; i >= 0; --i)
        {
            offset += (temp % G_dims[i]) * b_strides[i];
            temp /= G_dims[i];
        }

        // Decode N dimensions
        temp = n_flat;
        for(int i = num_n_dims - 1; i >= 0; --i)
        {
            offset += (temp % N_dims[i]) * b_strides[num_g_dims + i];
            temp /= N_dims[i];
        }

        // Decode K dimensions
        temp = k_flat;
        for(int i = num_k_dims - 1; i >= 0; --i)
        {
            offset += (temp % K_dims[i]) * b_strides[num_g_dims + num_n_dims + i];
            temp /= K_dims[i];
        }

        return offset;
    };

    auto compute_e_offset = [&](ck_tile::index_t g_flat,
                                ck_tile::index_t m_flat,
                                ck_tile::index_t n_flat) -> std::size_t {
        std::size_t offset = 0;

        // Decode G dimensions
        ck_tile::index_t temp = g_flat;
        for(int i = num_g_dims - 1; i >= 0; --i)
        {
            offset += (temp % G_dims[i]) * e_strides[i];
            temp /= G_dims[i];
        }

        // Decode M dimensions
        temp = m_flat;
        for(int i = num_m_dims - 1; i >= 0; --i)
        {
            offset += (temp % M_dims[i]) * e_strides[num_g_dims + i];
            temp /= M_dims[i];
        }

        // Decode N dimensions
        temp = n_flat;
        for(int i = num_n_dims - 1; i >= 0; --i)
        {
            offset += (temp % N_dims[i]) * e_strides[num_g_dims + num_m_dims + i];
            temp /= N_dims[i];
        }

        return offset;
    };

    // Helper to compute D tensor offset (D tensors have same shape as E: [G, M, N])
    auto compute_d_offset = [&](ck_tile::index_t g_flat,
                                ck_tile::index_t m_flat,
                                ck_tile::index_t n_flat,
                                ck_tile::index_t d_idx) -> std::size_t {
        std::size_t offset   = 0;
        const auto& d_strides = ds_strides[d_idx];

        // Decode G dimensions
        ck_tile::index_t temp = g_flat;
        for(int i = num_g_dims - 1; i >= 0; --i)
        {
            offset += (temp % G_dims[i]) * d_strides[i];
            temp /= G_dims[i];
        }

        // Decode M dimensions
        temp = m_flat;
        for(int i = num_m_dims - 1; i >= 0; --i)
        {
            offset += (temp % M_dims[i]) * d_strides[num_g_dims + i];
            temp /= M_dims[i];
        }

        // Decode N dimensions
        temp = n_flat;
        for(int i = num_n_dims - 1; i >= 0; --i)
        {
            offset += (temp % N_dims[i]) * d_strides[num_g_dims + num_m_dims + i];
            temp /= N_dims[i];
        }

        return offset;
    };

    // Parallel computation over G and M dimensions
    auto f_gm = [&](auto g_flat, auto m_flat) {
        for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
        {
            AccDataType sum = 0;

            // Compute dot product over K dimension using stride-aware indexing
            for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat)
            {
                const std::size_t a_offset = compute_a_offset(g_flat, m_flat, k_flat);
                const std::size_t b_offset = compute_b_offset(g_flat, n_flat, k_flat);

                auto a_val = a_full_dims.mData[a_offset];
                auto b_val = b_full_dims.mData[b_offset];
                sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
            }

            // Compute output offset using strides
            const std::size_t e_offset = compute_e_offset(g_flat, m_flat, n_flat);

            // Compute individual D tensor offsets using their respective strides
            std::array<std::size_t, NumDTensor> d_offsets;
            for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
            {
                d_offsets[d] = compute_d_offset(g_flat, m_flat, n_flat, d);
            }

            // Apply elementwise operation with D tensors using compile-time dispatch
            EDataType result = static_cast<EDataType>(sum);
            ExtractDValues<DDataType, NumDTensor>::apply_at_offsets(
                result, sum, cde_elementwise, ds_full_dims_host, d_offsets);

            // Store result using stride-aware indexing
            e_full_dims_host_ref.mData[e_offset] = static_cast<EDataType>(result);
        }
    };

    // Execute parallel computation using hardware concurrency
    // Parallelize over G_total and M_total dimensions for optimal CPU utilization
    make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency());
}

} // namespace ck_tile
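A minimal usage sketch for the contraction reference above (not part of the diff; the shapes, the AddD functor, and the HostTensor length-list construction are illustrative assumptions): a rank-(1,1,1,1) contraction with a single D tensor fused via C + D.

// Hedged sketch: exercises compute_reference_batched_contraction with
// NumDTensor == 1; shapes and the AddD functor are hypothetical.
struct AddD
{
    template <typename E>
    void operator()(E& e, float c, float d) const
    {
        e = static_cast<E>(c + d);
    }
};

inline void example_batched_contraction()
{
    using namespace ck_tile;
    HostTensor<float> a({2, 4, 16}); // [G..., M..., K...]
    HostTensor<float> b({2, 8, 16}); // [G..., N..., K...]
    HostTensor<float> e({2, 4, 8});  // [G..., M..., N...]
    std::array<HostTensor<float>, 1> ds = {HostTensor<float>({2, 4, 8})};

    compute_reference_batched_contraction<float, float, float, float, float, AddD, 1>(
        a, b, ds, e,
        /*G_total=*/2, /*M_total=*/4, /*N_total=*/8, /*K_total=*/16,
        AddD{}, {2}, {4}, {8}, {16});
}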
33 include/ck_tile/host/reference/reference_batched_dropout.hpp (new file)
@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {

template <typename DataType, typename RandValOutputDataType>
CK_TILE_HOST void reference_batched_dropout(HostTensor<DataType>& in_out_b_m_n,
                                            const HostTensor<RandValOutputDataType>& randval_b_m_n,
                                            const uint8_t& p_undrop_in_uint8_t,
                                            const float scale)
{
    const int N = in_out_b_m_n.mDesc.get_lengths()[2];
    auto f = [&](auto batch, auto m) {
        for(int n = 0; n < N; ++n)
        {
            float tmp = ck_tile::type_convert<float>(in_out_b_m_n(batch, m, n)) * scale;
            in_out_b_m_n(batch, m, n) = randval_b_m_n(batch, m, n) <= p_undrop_in_uint8_t
                                            ? ck_tile::type_convert<DataType>(tmp)
                                            : DataType(0);
        }
    };

    make_ParallelTensorFunctor(
        f, randval_b_m_n.mDesc.get_lengths()[0], randval_b_m_n.mDesc.get_lengths()[1])(
        std::thread::hardware_concurrency());
}
} // namespace ck_tile
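A quick usage sketch for the dropout reference above (illustrative shapes; the keep-probability encoding follows the uint8 threshold comparison in the code): keep roughly 80% of elements and rescale survivors by 1/0.8.

// Hedged sketch: randval_b_m_n would normally come from
// reference_batched_dropout_randval; here it is just a placeholder tensor.
inline void example_batched_dropout()
{
    using namespace ck_tile;
    HostTensor<float> attn({1, 4, 8});      // [B, M, N], modified in place
    HostTensor<uint8_t> randval({1, 4, 8}); // per-element random bytes
    const uint8_t p_undrop_u8 = static_cast<uint8_t>(0.8f * 255.f);
    reference_batched_dropout(attn, randval, p_undrop_u8, 1.0f / 0.8f);
}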
@@ -0,0 +1,74 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {

template <typename RandValOutputDataType>
CK_TILE_HOST void
reference_batched_dropout_randval(HostTensor<RandValOutputDataType>& randval_b_m_n,
                                  index_t batch,
                                  uint64_t drop_seed,
                                  uint64_t drop_offset)
{
    const index_t nhead         = randval_b_m_n.mDesc.get_lengths()[0];
    const index_t real_seqlen_q = randval_b_m_n.mDesc.get_lengths()[1];
    const index_t real_seqlen_k = randval_b_m_n.mDesc.get_lengths()[2];

    static_assert(std::is_same_v<RandValOutputDataType, uint8_t>);

    // BlockDropout generates random numbers by 32x32 tiles. Even when warp gemm 16x16 is used, the
    // order of values in the bigger 32x32 tile must be the same because fwd and bwd may use
    // different warp gemms (16x16 or 32x32).
    // To compute 32x32 tiles, WarpGemmMfmaF16F16F32M32N32K16SwizzleA is used. It is
    // WarpGemmAttributeMfmaImplF16F16F32M32N32K8 with SFactor = 2 (swizzling factor).
    // Matrix element to register mapping for WarpGemmAttributeMfmaImplF16F16F32M32N32K8:
    // C i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4)
    // C j: (lane % 32)
    // With SFactor = 2 it becomes:
    // C i: (16 * floor(GPR_num / 8) % 32) + 8 * floor(lane / 32) + (GPR_num % 8)
    // C j: (lane % 32)
    // See ck_tile/ops/fmha/block/block_dropout.hpp for more details.

    // The number of Philox 4x32 results required to fill a 32x32 tile of 8-bit values
    constexpr index_t philox_per_tile = 64;
    constexpr index_t warp_gemm_mn    = 32;

    const index_t rows = integer_divide_ceil(real_seqlen_q, warp_gemm_mn);
    const index_t cols = integer_divide_ceil(real_seqlen_k, warp_gemm_mn);

    auto f = [&](index_t i_h, index_t row, index_t col) {
        uint2 rowcol = make_uint2(row, col);
        for(index_t lane = 0; lane < philox_per_tile; lane++)
        {
            const uint64_t ph_head_offset = drop_offset + (batch * nhead + i_h) * philox_per_tile;
            const index_t ph_offset       = lane;
            philox ph(drop_seed, ph_head_offset + ph_offset);

            uint8_t random_uint8_t[16];
            ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));

            for(auto r = 0; r < 16; r++)
            {
                index_t i = (16 * (r / 8) % 32) + 8 * (lane / 32) + (r % 8);
                index_t j = (lane % 32);
                index_t m = row * warp_gemm_mn + i;
                index_t n = col * warp_gemm_mn + j;

                if(m < real_seqlen_q && n < real_seqlen_k)
                {
                    randval_b_m_n(i_h, m, n) = random_uint8_t[r];
                }
            }
        }
    };

    make_ParallelTensorFunctor(f, nhead, rows, cols)(std::thread::hardware_concurrency());
}

} // namespace ck_tile
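To make the swizzled mapping in the comment above concrete, a small sketch of the lane/GPR-to-tile-coordinate arithmetic (a restatement of the formulas, with one worked case):

// Hedged sketch: for lane in [0, 64) and GPR index r in [0, 16), the byte lands
// at (i, j) inside the 32x32 tile. E.g. lane = 40, r = 9:
// i = (16 * (9 / 8)) % 32 + 8 * (40 / 32) + 9 % 8 = 16 + 8 + 1 = 25, j = 40 % 32 = 8.
inline int randval_tile_i(int lane, int r) { return (16 * (r / 8) % 32) + 8 * (lane / 32) + (r % 8); }
inline int randval_tile_j(int lane, int /*r*/) { return lane % 32; }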
@@ -0,0 +1,64 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp      = ck_tile::identity,
          typename BElementOp      = ck_tile::identity,
          typename BinaryElementOp = ck_tile::plus<AccDataType>>
CK_TILE_HOST void reference_batched_elementwise(const HostTensor<ADataType>& a_b_m_n,
                                                const HostTensor<BDataType>& b_b_m_n,
                                                HostTensor<CDataType>& c_b_m_n,
                                                const AElementOp& a_element_op           = {},
                                                const BElementOp& b_element_op           = {},
                                                const BinaryElementOp& binary_element_op = {})
{
    const ck_tile::index_t N = c_b_m_n.mDesc.get_lengths()[2];

    const bool broadcast_a_dim_b = (a_b_m_n.get_lengths()[0] == 1);
    const bool broadcast_a_dim_m = (a_b_m_n.get_lengths()[1] == 1);
    const bool broadcast_a_dim_n = (a_b_m_n.get_lengths()[2] == 1);

    const bool broadcast_b_dim_b = (b_b_m_n.get_lengths()[0] == 1);
    const bool broadcast_b_dim_m = (b_b_m_n.get_lengths()[1] == 1);
    const bool broadcast_b_dim_n = (b_b_m_n.get_lengths()[2] == 1);

    auto f = [&](auto batch, auto m) {
        for(ck_tile::index_t n = 0; n < N; ++n)
        {
            AccDataType v_a{};
            {
                ck_tile::index_t i_b = (broadcast_a_dim_b ? 0 : batch);
                ck_tile::index_t i_m = (broadcast_a_dim_m ? 0 : m);
                ck_tile::index_t i_n = (broadcast_a_dim_n ? 0 : n);

                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_b_m_n(i_b, i_m, i_n)));
            }

            AccDataType v_b{};
            {
                ck_tile::index_t i_b = (broadcast_b_dim_b ? 0 : batch);
                ck_tile::index_t i_m = (broadcast_b_dim_m ? 0 : m);
                ck_tile::index_t i_n = (broadcast_b_dim_n ? 0 : n);

                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_b_m_n(i_b, i_m, i_n)));
            }

            c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(binary_element_op(v_a, v_b));
        }
    };

    make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
        std::thread::hardware_concurrency());
}
} // namespace ck_tile
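A usage sketch for the broadcasting rules above (illustrative shapes): any length-1 dimension of A or B is broadcast against C's shape.

// Hedged sketch: C = A + B with B broadcast along batch and N; the defaults
// (identity element ops, plus<Acc>) come from the template above.
inline void example_batched_elementwise()
{
    using namespace ck_tile;
    HostTensor<float> a({2, 4, 8}); // [B, M, N]
    HostTensor<float> b({1, 4, 1}); // broadcast over batch and n
    HostTensor<float> c({2, 4, 8});
    reference_batched_elementwise<float, float, float, float>(a, b, c);
}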
90 include/ck_tile/host/reference/reference_batched_gemm.hpp (new file)
@@ -0,0 +1,90 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
#include <tuple>

namespace ck_tile {

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp   = ck_tile::identity,
          typename BElementOp   = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_batched_gemm(const HostTensor<ADataType>& a_b_m_k,
                                         const HostTensor<BDataType>& b_b_n_k,
                                         HostTensor<CDataType>& c_b_m_n,
                                         const AElementOp& a_element_op     = {},
                                         const BElementOp& b_element_op     = {},
                                         const ACCElementOp& acc_element_op = {})
{
    const int N = b_b_n_k.mDesc.get_lengths()[1];
    const int K = b_b_n_k.mDesc.get_lengths()[2];

    auto f = [&](auto batch, auto m) {
        for(int n = 0; n < N; ++n)
        {
            AccDataType v_acc = 0;

            for(int k = 0; k < K; ++k)
            {
                ADataType v_a = a_element_op(a_b_m_k(batch, m, k));
                BDataType v_b = b_element_op(b_b_n_k(batch, n, k));

                v_acc += ck_tile::type_convert<AccDataType>(v_a) *
                         ck_tile::type_convert<AccDataType>(v_b);
            }

            c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
        }
    };

    make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
        std::thread::hardware_concurrency());
}

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp   = ck_tile::idx_identity,
          typename BElementOp   = ck_tile::idx_identity,
          typename ACCElementOp = ck_tile::idx_identity>
CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor<ADataType>& a_b_m_k,
                                               const HostTensor<BDataType>& b_b_n_k,
                                               HostTensor<CDataType>& c_b_m_n,
                                               const AElementOp& a_element_op     = {},
                                               const BElementOp& b_element_op     = {},
                                               const ACCElementOp& acc_element_op = {})
{
    const int N = b_b_n_k.mDesc.get_lengths()[1];
    const int K = b_b_n_k.mDesc.get_lengths()[2];

    auto f = [&](auto batch, auto m) {
        for(int n = 0; n < N; ++n)
        {
            AccDataType v_acc = 0;

            for(int k = 0; k < K; ++k)
            {
                AccDataType v_a = ck_tile::type_convert<AccDataType>(
                    a_element_op(std::make_tuple(batch, m, k), a_b_m_k(batch, m, k)));
                AccDataType v_b = ck_tile::type_convert<AccDataType>(
                    b_element_op(std::make_tuple(batch, n, k), b_b_n_k(batch, n, k)));

                v_acc += v_a * v_b;
            }

            c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(
                acc_element_op(std::make_tuple(batch, m, n), v_acc));
        }
    };

    make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
        std::thread::hardware_concurrency());
}
} // namespace ck_tile
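A usage sketch for the GEMM reference above (illustrative shapes): note that B is laid out [B, N, K], so K is the contiguous axis for both operands.

// Hedged sketch: C[b,m,n] = sum_k A[b,m,k] * B[b,n,k], fp16 in, fp32 accumulate.
inline void example_batched_gemm()
{
    using namespace ck_tile;
    HostTensor<half_t> a({2, 16, 32}); // [B, M, K]
    HostTensor<half_t> b({2, 8, 32});  // [B, N, K]
    HostTensor<float> c({2, 16, 8});   // [B, M, N]
    reference_batched_gemm<half_t, half_t, float, float>(a, b, c);
}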
32 include/ck_tile/host/reference/reference_batched_masking.hpp (new file)
@@ -0,0 +1,32 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {

template <typename CDataType, typename MaskingType>
CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, const MaskingType& mask)
{
    const int M = c_b_m_n.mDesc.get_lengths()[1];
    const int N = c_b_m_n.mDesc.get_lengths()[2];

    auto f = [&](auto batch) {
        for(int n = 0; n < N; ++n)
        {
            for(int m = 0; m < M; ++m)
            {
                if(mask.IsOutOfSinkBound(m, n))
                    c_b_m_n(batch, m, n) = -ck_tile::numeric<CDataType>::infinity();
            }
        }
    };

    make_ParallelTensorFunctor(f,
                               c_b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
} // namespace ck_tile
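MaskingType only needs to expose IsOutOfSinkBound(m, n); a hedged sketch with a hypothetical causal mask (this struct is illustrative, not a ck_tile type):

// Hedged sketch: masks the strictly upper triangle to -inf.
struct CausalMaskSketch
{
    bool IsOutOfSinkBound(int m, int n) const { return n > m; }
};
// usage: reference_batched_masking(s_b_m_n, CausalMaskSketch{});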
@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {

template <typename InDataType,
          typename ScaleDataType,
          typename OutDataType,
          typename ComputeDataType>
CK_TILE_HOST HostTensor<OutDataType>
reference_batched_mx_descale(const HostTensor<InDataType>& a_b_m_k,
                             const HostTensor<ScaleDataType>& scales_b_m_ks,
                             const std::size_t scale_granularity)
{
    const std::size_t B = a_b_m_k.get_length(0);
    const std::size_t M = a_b_m_k.get_length(1);
    const std::size_t K = a_b_m_k.get_length(2);

    HostTensor<ComputeDataType> a_b_m_k_scaled(a_b_m_k.get_lengths());

    auto f = [&](auto batch) {
        constexpr index_t packed_size = ck_tile::numeric_traits<InDataType>::PackedSize;

        for(std::size_t m = 0; m < M; ++m)
        {
            for(std::size_t k = 0; k < K; k += packed_size)
            {
                const auto scale = ck_tile::type_convert<ComputeDataType>(
                    scales_b_m_ks(batch, m, k / scale_granularity));

                if constexpr(std::is_same_v<InDataType, pk_fp4_t>)
                {
                    auto a_f4x2  = a_b_m_k(batch, m, k);
                    auto a_f4_lo = ck_tile::type_convert<ComputeDataType>(
                        a_f4x2.template unpack<>(number<0>{}));
                    auto a_f4_hi = ck_tile::type_convert<ComputeDataType>(
                        a_f4x2.template unpack<>(number<1>{}));

                    a_b_m_k_scaled(batch, m, k)     = a_f4_lo * scale;
                    a_b_m_k_scaled(batch, m, k + 1) = a_f4_hi * scale;
                }
                else
                {
                    a_b_m_k_scaled(batch, m, k) =
                        ck_tile::type_convert<ComputeDataType>(a_b_m_k(batch, m, k)) * scale;
                }
            }
        }
    };
    make_ParallelTensorFunctor(f, B)(std::thread::hardware_concurrency());

    return a_b_m_k_scaled;
}

} // namespace ck_tile
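A usage sketch for the descale reference above (illustrative types and shapes; fp8 has PackedSize 1, so the packed fp4 branch is not taken): one scale per 32 contiguous K elements.

// Hedged sketch: scales are indexed [b, m, k / scale_granularity] as in the code.
inline void example_mx_descale()
{
    using namespace ck_tile;
    HostTensor<fp8_t> a({1, 4, 64});     // [B, M, K]
    HostTensor<float> scales({1, 4, 2}); // [B, M, K/32]
    auto a_scaled = reference_batched_mx_descale<fp8_t, float, float, float>(a, scales, 32);
}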
@@ -0,0 +1,73 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <cassert>
#include <thread>

namespace ck_tile {

template <typename DataType, typename ComputeDataType = float>
CK_TILE_HOST void reference_batched_rotary_position_embedding(const HostTensor<DataType>& input_bsd,
                                                              const HostTensor<DataType>& cos_sd,
                                                              const HostTensor<DataType>& sin_sd,
                                                              bool interleaved,
                                                              HostTensor<DataType>& output_bsd,
                                                              bool use_1_row_sin_cos = false)
{
    assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2);
    assert(cos_sd.get_length(0) == sin_sd.get_length(0) &&
           cos_sd.get_length(1) == sin_sd.get_length(1));

    const index_t rotary_dim = cos_sd.get_length(1) * 2;
    assert(static_cast<std::size_t>(rotary_dim) <= input_bsd.get_length(2));

    output_bsd.ForEach([&](auto& self, auto i) {
        const index_t i_d = i[2];
        if(rotary_dim <= i_d)
        {
            self(i) = input_bsd(i);
            return;
        }
        assert(i_d < rotary_dim);

        const index_t i_s         = i[1];
        const index_t i_s_cos_sin = (use_1_row_sin_cos ? 0 : i_s);

        const ComputeDataType cos = type_convert<ComputeDataType>(
            interleaved ? cos_sd(i_s_cos_sin, i_d / 2)
                        : cos_sd(i_s_cos_sin, i_d % cos_sd.get_length(1)));
        const ComputeDataType sin = type_convert<ComputeDataType>(
            interleaved ? sin_sd(i_s_cos_sin, i_d / 2)
                        : sin_sd(i_s_cos_sin, i_d % sin_sd.get_length(1)));

        const ComputeDataType half_rotated_input = [&] {
            const index_t i_b = i[0];

            if(interleaved)
            {
                const bool is_even         = (i_d % 2 == 0);
                const index_t pos          = i_d + (is_even ? 1 : -1);
                const ComputeDataType sign = (is_even ? -1 : 1);
                return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
            }
            else
            {
                const index_t half_rdim    = (rotary_dim / 2);
                const index_t pos          = (i_d + half_rdim) % rotary_dim;
                const ComputeDataType sign = (pos < half_rdim ? 1 : -1);
                return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
            }
        }();
        ComputeDataType result =
            type_convert<ComputeDataType>(input_bsd(i)) * cos + half_rotated_input * sin;

        self(i) = type_convert<DataType>(result);
    });
}

} // namespace ck_tile
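A usage sketch for the RoPE reference above (illustrative shapes): non-interleaved (rotate-half) mode, with rotary_dim = 2 * cos_sd.get_length(1) = 8 covering the leading 8 of 16 channels and the rest passed through unchanged.

// Hedged sketch: cos/sin tables have one row per sequence position.
inline void example_rope()
{
    using namespace ck_tile;
    HostTensor<float> x({2, 6, 16}), y({2, 6, 16});   // [B, S, D]
    HostTensor<float> cos_sd({6, 4}), sin_sd({6, 4}); // rotary_dim = 8
    reference_batched_rotary_position_embedding(x, cos_sd, sin_sd,
                                                /*interleaved=*/false, y);
}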
71 include/ck_tile/host/reference/reference_batched_softmax.hpp (new file)
@@ -0,0 +1,71 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cmath>
#include <functional>
#include <optional>
#include <thread>

namespace ck_tile {

template <typename ADataType,
          typename CompDataType,
          typename BDataType,
          typename CompElementOp = ck_tile::identity>
CK_TILE_HOST void reference_batched_softmax(
    const HostTensor<ADataType>& a_b_m_n,
    HostTensor<BDataType>& b_b_m_n,
    const CompElementOp& comp_element_op = {},
    std::optional<std::reference_wrapper<HostTensor<CompDataType>>> lse_b_m = std::nullopt)
{
    const int N = a_b_m_n.mDesc.get_lengths()[2];

    auto f = [&](auto batch, auto m) {
        CompDataType v_max = -ck_tile::numeric<CompDataType>::infinity();

        // max
        for(int n = 0; n < N; ++n)
        {
            const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));

            v_max = v_max < v_a ? v_a : v_max;
        }

        CompDataType v_exp_sum = 0;
        // fall back to 0 when every element in the row is -INF
        if(std::isinf(v_max) && v_max < 0)
        {
            v_max = ck_tile::type_convert<CompDataType>(0.f);
        }

        // sum
        for(int n = 0; n < N; ++n)
        {
            const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));

            v_exp_sum += ck_tile::exp(v_a - v_max);
        }

        // if the sum is zero (masked) or nan/inf (other computation error), skip the divide
        CompDataType inv_sum = (v_exp_sum == 0.f ? 1.f : 1.f / v_exp_sum);

        // elementwise
        for(int n = 0; n < N; ++n)
        {
            const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));
            const CompDataType v_b = ck_tile::exp(v_a - v_max) * inv_sum;

            b_b_m_n(batch, m, n) = ck_tile::type_convert<BDataType>(comp_element_op(v_b));
        }
        // lse
        if(lse_b_m)
        {
            lse_b_m->get()(batch, m) = v_max + ck_tile::log(v_exp_sum);
        }
    };

    make_ParallelTensorFunctor(f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])(
        std::thread::hardware_concurrency());
}
} // namespace ck_tile
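A usage sketch for the softmax reference above (illustrative shapes), also capturing the per-row log-sum-exp through the optional argument:

// Hedged sketch: lse(b, m) = max + log(sum exp), as computed above.
inline void example_batched_softmax()
{
    using namespace ck_tile;
    HostTensor<float> s({2, 4, 8}), p({2, 4, 8}); // [B, M, N]
    HostTensor<float> lse({2, 4});                // [B, M]
    reference_batched_softmax<float, float, float>(s, p, identity{}, std::ref(lse));
}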
@@ -0,0 +1,59 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <string>
#include <thread>

namespace ck_tile {

template <typename Type>
CK_TILE_HOST void reference_batched_transpose(const HostTensor<Type>& x,
                                              HostTensor<Type>& y,
                                              std::string layout_in  = "NCHW",
                                              std::string layout_out = "NHWC")
{
    const int N = x.mDesc.get_lengths()[0];

    auto f = [&](auto batch) {
        if(layout_in == "NCHW" && layout_out == "NHWC")
        {
            const int C = x.mDesc.get_lengths()[1];
            const int H = x.mDesc.get_lengths()[2];
            const int W = x.mDesc.get_lengths()[3];
            for(int c = 0; c < C; ++c)
            {
                for(int h = 0; h < H; ++h)
                {
                    for(int w = 0; w < W; ++w)
                    {
                        Type v_x          = x(batch, c, h, w);
                        y(batch, h, w, c) = v_x;
                    }
                }
            }
        }
        else if(layout_in == "NHWC" && layout_out == "NCHW")
        {
            const int H = x.mDesc.get_lengths()[1];
            const int W = x.mDesc.get_lengths()[2];
            const int C = x.mDesc.get_lengths()[3];
            for(int h = 0; h < H; ++h)
            {
                for(int w = 0; w < W; ++w)
                {
                    for(int c = 0; c < C; ++c)
                    {
                        Type v_x          = x(batch, h, w, c);
                        y(batch, c, h, w) = v_x;
                    }
                }
            }
        }
    };

    make_ParallelTensorFunctor(f, N)(std::thread::hardware_concurrency());
}
} // namespace ck_tile
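A usage sketch for the transpose reference above: round-tripping NCHW -> NHWC -> NCHW recovers the original layout.

// Hedged sketch with illustrative shapes.
inline void example_batched_transpose()
{
    using namespace ck_tile;
    HostTensor<float> x({2, 3, 4, 5}); // NCHW
    HostTensor<float> y({2, 4, 5, 3}); // NHWC
    HostTensor<float> z({2, 3, 4, 5}); // NCHW again
    reference_batched_transpose(x, y); // defaults: "NCHW" -> "NHWC"
    reference_batched_transpose(y, z, "NHWC", "NCHW");
}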
156 include/ck_tile/host/reference/reference_blocked_attention.hpp (new file)
@@ -0,0 +1,156 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

#include "ck_tile/core.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

template <typename AccT, typename T>
CK_TILE_HOST_DEVICE constexpr AccT to_acc(T value)
{
    if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
    {
#if CK_TILE_USE_CUSTOM_DATA_TYPE
        return static_cast<AccT>(value);
#else
        return static_cast<AccT>(
            ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value)));
#endif
    }
    else
    {
        return static_cast<AccT>(value);
    }
}

// Reference implementation: blocked attention (for sparse attention tests).
template <typename T, typename MaskT, typename AccT = float>
void reference_blocked_attention(
    const HostTensor<T>& q,                  // [B, H, S_q, D]
    const HostTensor<T>& k,                  // [B, H, S_k, D]
    const HostTensor<T>& v,                  // [B, H, S_k, D_v]
    const HostTensor<MaskT>& block_relation, // [B, H, Q_blocks, K_blocks]
    HostTensor<T>& output,                   // [B, H, S_q, D_v]
    index_t BLKQ,
    index_t BLKK,
    AccT scale)
{
    auto q_lengths   = q.get_lengths();
    index_t batch    = q_lengths[0];
    index_t nhead    = q_lengths[1];
    index_t seqlen_q = q_lengths[2];
    index_t hdim     = q_lengths[3];

    auto v_lengths   = v.get_lengths();
    index_t seqlen_k = v_lengths[2];
    index_t hdim_v   = v_lengths[3];

    index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
    index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;

    for(index_t b = 0; b < batch; ++b)
    {
        for(index_t h = 0; h < nhead; ++h)
        {
            for(index_t qb = 0; qb < num_q_blocks; ++qb)
            {
                index_t q_start = qb * BLKQ;
                if(q_start >= seqlen_q)
                {
                    continue;
                }
                index_t q_end = std::min<index_t>(q_start + BLKQ, seqlen_q);

                std::vector<index_t> relevant_k_indices;
                for(index_t kb = 0; kb < num_k_blocks; ++kb)
                {
                    // Treat block_relation as boolean; >0.5 marks an active block.
                    if(static_cast<float>(block_relation(b, h, qb, kb)) > 0.5f)
                    {
                        relevant_k_indices.push_back(kb);
                    }
                }

                if(relevant_k_indices.empty())
                {
                    continue;
                }

                for(index_t sq = q_start; sq < q_end; ++sq)
                {
                    std::vector<AccT> scores;
                    AccT max_score = -std::numeric_limits<AccT>::infinity();

                    for(auto kb : relevant_k_indices)
                    {
                        index_t k_start = kb * BLKK;
                        if(k_start >= seqlen_k)
                        {
                            continue;
                        }
                        index_t k_end = std::min<index_t>(k_start + BLKK, seqlen_k);

                        for(index_t sk = k_start; sk < k_end; ++sk)
                        {
                            AccT score = 0.0f;
                            for(index_t d = 0; d < hdim; ++d)
                            {
                                score +=
                                    to_acc<AccT>(q(b, h, sq, d)) * to_acc<AccT>(k(b, h, sk, d));
                            }
                            score = score * scale;
                            scores.push_back(score);
                            max_score = std::max(max_score, score);
                        }
                    }

                    AccT sum_exp = 0.0f;
                    for(auto& s : scores)
                    {
                        s = std::exp(s - max_score);
                        sum_exp += s;
                    }
                    for(auto& s : scores)
                    {
                        s /= sum_exp;
                    }

                    for(index_t dv = 0; dv < hdim_v; ++dv)
                    {
                        AccT out_val     = 0.0f;
                        size_t score_idx = 0;

                        for(auto kb : relevant_k_indices)
                        {
                            index_t k_start = kb * BLKK;
                            if(k_start >= seqlen_k)
                            {
                                continue;
                            }
                            index_t k_end = std::min<index_t>(k_start + BLKK, seqlen_k);

                            for(index_t sk = k_start; sk < k_end; ++sk)
                            {
                                out_val += scores[score_idx] * to_acc<AccT>(v(b, h, sk, dv));
                                score_idx++;
                            }
                        }

                        output(b, h, sq, dv) = static_cast<T>(out_val);
                    }
                }
            }
        }
    }
}

} // namespace ck_tile
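A usage sketch for the blocked-attention reference above (illustrative shapes): 2 query blocks by 2 key blocks with a block-causal relation; entries > 0.5 mark active blocks, and fully masked query blocks are skipped.

// Hedged sketch: BLKQ = BLKK = 4 over seqlen 8, scale = 1/sqrt(hdim) = 0.25.
inline void example_blocked_attention()
{
    using namespace ck_tile;
    HostTensor<float> q({1, 1, 8, 16}), k({1, 1, 8, 16}), v({1, 1, 8, 16});
    HostTensor<float> out({1, 1, 8, 16});
    HostTensor<float> block_relation({1, 1, 2, 2}); // [B, H, Q_blocks, K_blocks]
    block_relation(0, 0, 0, 0) = 1.f; // q-block 0 sees k-block 0 only
    block_relation(0, 0, 0, 1) = 0.f;
    block_relation(0, 0, 1, 0) = 1.f; // q-block 1 sees k-blocks 0 and 1
    block_relation(0, 0, 1, 1) = 1.f;
    reference_blocked_attention(q, k, v, block_relation, out,
                                /*BLKQ=*/4, /*BLKK=*/4, /*scale=*/0.25f);
}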
47 include/ck_tile/host/reference/reference_elementwise.hpp (new file)
@@ -0,0 +1,47 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {
template <typename ADataType, typename BDataType, typename ComputeDataType, typename ElementOp>
CK_TILE_HOST void reference_unary_elementwise(const HostTensor<ADataType>& a,
                                              HostTensor<BDataType>& b,
                                              ElementOp element_op)
{
    // TODO: implement a GPU version of this reference function
    auto f = [&](auto i) {
        auto v_a   = type_convert<ComputeDataType>(a.mData[i]);
        auto v_b   = element_op(v_a);
        b.mData[i] = ck_tile::type_convert<BDataType>(v_b);
    };

    make_ParallelTensorFunctor(f, b.get_element_space_size())(std::thread::hardware_concurrency());
}

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ComputeDataType,
          typename ElementOp>
CK_TILE_HOST void reference_binary_elementwise(const HostTensor<ADataType>& a,
                                               const HostTensor<BDataType>& b,
                                               HostTensor<CDataType>& c,
                                               ElementOp element_op)
{
    // TODO: implement a GPU version of this reference function
    auto f = [&](auto i) {
        auto v_a   = type_convert<ComputeDataType>(a.mData[i]);
        auto v_b   = type_convert<ComputeDataType>(b.mData[i]);
        auto v_c   = element_op(v_a, v_b);
        c.mData[i] = ck_tile::type_convert<CDataType>(v_c);
    };

    make_ParallelTensorFunctor(f, c.get_element_space_size())(std::thread::hardware_concurrency());
}

} // namespace ck_tile
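A usage sketch for the flat elementwise references above (illustrative shapes; any callable works as ElementOp):

// Hedged sketch: b = 2 * a, computed in float.
inline void example_unary_elementwise()
{
    using namespace ck_tile;
    HostTensor<float> a({4, 8}), b({4, 8});
    reference_unary_elementwise<float, float, float>(a, b,
                                                     [](float v) { return 2.f * v; });
}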
205 include/ck_tile/host/reference/reference_fused_moe.hpp (new file)
@@ -0,0 +1,205 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <cassert>
#include <stdexcept>
#include <string>
#include <thread>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice maps to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk = 3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
//                             tok-0      tok-1      tok-2      tok-3      tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
// numbers)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference)     exp-0  exp-1     exp-2    exp-3          exp-4  exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
//                        0, 1, 2, 5]
//                        |- exp-0 -|- exp-1 -|- exp-2 -|-    exp-3    -|-   exp-4   -|- exp-5 -|
// sorted_weight_ptr    : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *,
//                        c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]

template <typename AccDataType, // you only need to explicitly set this one
          typename Activation,  // e.g. ck_tile::element_wise::Gelu
          typename ADataType,
          typename GDataType,
          typename DDataType,
          typename ODataType,
          typename AScaleDataType,
          typename GScaleDataType,
          typename DScaleDataType,
          typename YSmoothScaleDataType,
          typename TopkWeightDataType,
          typename IndexDataType>
void reference_fused_moe(
    const ck_tile::HostTensor<ADataType>& a_host, // [tokens, hidden_size]
    const ck_tile::HostTensor<GDataType>& g_host, // [experts, interme_size_0, hidden_size]
    const ck_tile::HostTensor<DDataType>& d_host, // [experts, hidden_size, interme_size_1]
    const ck_tile::HostTensor<AScaleDataType>& sa_host,       // [tokens, 1]
    const ck_tile::HostTensor<GScaleDataType>& sg_host,       // [experts, 1, interme_size_0]
    const ck_tile::HostTensor<DScaleDataType>& sd_host,       // [experts, 1, hidden_size]
    const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size_0]
    ck_tile::HostTensor<ODataType>& o_host,                   // [tokens, hidden_size]
    const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host,   // [max_num_tokens_padded]
    const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
    const ck_tile::HostTensor<IndexDataType>&
        sorted_expert_ids_host, // [(max_num_tokens_padded + block_size - 1) / block_size]
    const ck_tile::HostTensor<IndexDataType>& num_sorted_tiles_host, // [1]

    const ck_tile::HostTensor<IndexDataType>&
        token_ids_host, // [tokens, topk] --> ugly!!! remove in the future

    ck_tile::index_t block_m,
    ck_tile::index_t tokens,
    ck_tile::index_t experts,
    ck_tile::index_t hidden_size,
    ck_tile::index_t intermediate_size, // this size is for gate/up/down
    ck_tile::index_t topk,
    ck_tile::index_t gate_only)
{
    assert(sorted_token_ids_host.get_num_of_dimension() == 1);
    assert(sorted_weight_host.get_num_of_dimension() == 1);
    assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
    assert(num_sorted_tiles_host.get_element_size() == 1);
    ck_tile::index_t num_sorted_tiles    = num_sorted_tiles_host.mData[0] / block_m;
    ck_tile::index_t intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2);
    ck_tile::index_t intermediate_size_1 = intermediate_size;

    ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});

    int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
    // assert();
    auto f = [&](auto i_flatten) {
        ck_tile::index_t i_tile = i_flatten / block_m;
        if(i_tile >= num_sorted_tiles)
            return;
        ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];

#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
        ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
        ck_tile::index_t i_topk  = i_token >> 24;
        i_token &= 0xffffff;
        if(i_token >= tokens)
            return;
        (void)token_ids_host;
#else
        // TODO: better remove this in the future, or modify the token_id value
        auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
            for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
            {
                if(token_ids_host(token_id_, i_) == expert_id_)
                    return i_;
            }
            throw std::runtime_error("not correct token/expert pair\n");
            return -1; // TODO: not correct!!
        };
        ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
        if(i_token >= tokens)
            return;
        ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
#endif
        auto weight = sorted_weight_host.mData[i_flatten];

        ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
        // first gemm
        for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
        {
            AccDataType acc = static_cast<AccDataType>(0);
            for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
            {
                acc += type_convert<AccDataType>(a_host(i_token, i_k)) *
                       type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
            }
            acc_0(0, i_n) = acc;
            // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
        }

        ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
        if(gate_only)
        {
            if(intermediate_size_1 != intermediate_size_0)
                throw std::runtime_error(
                    "intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
                    ", 1:" + std::to_string(intermediate_size_1));
            for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
            {
                Activation{}(y(0, i_n), acc_0(0, i_n));
                // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
            }
        }
        else
        {
            if(intermediate_size_1 * 2 != intermediate_size_0)
                throw std::runtime_error(
                    "intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
                    ", 1:" + std::to_string(intermediate_size_1));
            for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
            {
                AccDataType tmp;
                Activation{}(tmp, acc_0(0, i_n));
                y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
            }
        }

        // second gemm, loop along gemm-n
        ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            AccDataType acc = static_cast<AccDataType>(0);
            for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
            {
                acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
            }
            acc_1(0, i_n) = acc * weight; // multiply by the routing weight here
        }

        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            out_topk_tokens(i_token, i_topk, i_n) = acc_1(0, i_n);
        }
    };

    // make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency());
    make_ParallelTensorFunctor(f, max_num_tokens_padded)(1);

    // reduce
    auto r = [&](auto i_token) {
        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            AccDataType acc = type_convert<AccDataType>(0);
            for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
            {
                acc += out_topk_tokens(i_token, i_topk, i_n);
            }
            o_host(i_token, i_n) = type_convert<ODataType>(acc);
        }
    };
    make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());

    (void)num_sorted_tiles_host;
    (void)sa_host;
    (void)sg_host;
    (void)sd_host;
    (void)sy_host;
}
} // namespace ck_tile
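The padding bound in the indexing comment above can be checked directly against its worked example (num_experts = 6, topk = 3, M_a = block_m = 4, tokens = 5):

// Hedged sketch: restates the comment's formula; 36 upper-bounds the 28 entries
// actually used (num_tokens_post_padded).
constexpr int moe_max_num_tokens_padded(int topk, int tokens, int experts, int block_m)
{
    return topk * tokens + experts * block_m - topk;
}
static_assert(moe_max_num_tokens_padded(3, 5, 6, 4) == 36, "15 + 24 - 3, illustrative bound");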
1081 include/ck_tile/host/reference/reference_gemm.hpp (new file)
File diff suppressed because it is too large
@@ -0,0 +1,228 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <cinttypes>
#include <cstdlib>
#include <thread>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

template <ck_tile::index_t NDimSpatial,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor<InDataType>& input,
                                                  const HostTensor<WeiDataType>& weight,
                                                  const HostTensor<OutDataType>& output,
                                                  std::vector<ck_tile::long_index_t> conv_strides,
                                                  std::vector<ck_tile::long_index_t> conv_dilations,
                                                  std::vector<ck_tile::long_index_t> in_left_pads,
                                                  std::vector<ck_tile::long_index_t>)
{
    if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
         weight.get_num_of_dimension() == NDimSpatial + 3 &&
         output.get_num_of_dimension() == NDimSpatial + 3))
    {
        printf("%" PRIu64 " %" PRIu64 " %" PRIu64,
               input.get_num_of_dimension(),
               weight.get_num_of_dimension(),
               output.get_num_of_dimension());

        throw std::runtime_error("wrong! inconsistent dimension");
    }

    if constexpr(NDimSpatial == 1)
    {
        auto func = [&](auto g, auto n, auto c, auto wi) {
            std::size_t K = weight.get_lengths()[1];
            std::size_t X = weight.get_lengths()[3];

            std::size_t Wo = output.get_lengths()[3];
            float v_acc    = 0;

            for(std::size_t x = 0; x < X; ++x)
            {
                auto w_tmp = static_cast<ck_tile::long_index_t>(wi) +
                             static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
                             static_cast<ck_tile::long_index_t>(x * conv_dilations[0]);

                if(w_tmp % conv_strides[0] == 0)
                {
                    auto wo = static_cast<ck_tile::long_index_t>(w_tmp) /
                              static_cast<ck_tile::long_index_t>(conv_strides[0]);

                    if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
                    {
                        for(std::size_t k = 0; k < K; ++k)
                        {
                            OutDataType v_out = output(g, n, k, wo);
                            WeiDataType v_wei = weight(g, k, c, x);
                            v_acc += ck_tile::type_convert<float>(v_out) *
                                     ck_tile::type_convert<float>(v_wei);
                        }
                    }
                }
            }
            InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
            input(g, n, c, wi)         = v_acc_converted;
        };

        make_ParallelTensorFunctor(func,
                                   input.get_lengths()[0],
                                   input.get_lengths()[1],
                                   input.get_lengths()[2],
                                   input.get_lengths()[3])(std::thread::hardware_concurrency());
    }
    else if constexpr(NDimSpatial == 2)
    {
        auto func = [&](auto g, auto n, auto c, auto hi, auto wi) {
            std::size_t K = weight.get_lengths()[1];
            std::size_t Y = weight.get_lengths()[3];
            std::size_t X = weight.get_lengths()[4];

            std::size_t Ho = output.get_lengths()[3];
            std::size_t Wo = output.get_lengths()[4];

            float v_acc = 0;

            for(std::size_t y = 0; y < Y; ++y)
            {
                auto h_tmp = static_cast<ck_tile::long_index_t>(hi) +
                             static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
                             static_cast<ck_tile::long_index_t>(y * conv_dilations[0]);
                if(h_tmp % conv_strides[0] == 0)
                {
                    auto ho = static_cast<ck_tile::long_index_t>(h_tmp) /
                              static_cast<ck_tile::long_index_t>(conv_strides[0]);
                    if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
                    {
                        for(std::size_t x = 0; x < X; ++x)
                        {
                            auto w_tmp = static_cast<ck_tile::long_index_t>(wi) +
                                         static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
                                         static_cast<ck_tile::long_index_t>(x * conv_dilations[1]);
                            if(w_tmp % conv_strides[1] == 0)
                            {
                                auto wo = static_cast<ck_tile::long_index_t>(w_tmp) /
                                          static_cast<ck_tile::long_index_t>(conv_strides[1]);

                                if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
                                {
                                    for(std::size_t k = 0; k < K; ++k)
                                    {
                                        OutDataType v_out = output(g, n, k, ho, wo);
                                        WeiDataType v_wei = weight(g, k, c, y, x);
                                        v_acc += ck_tile::type_convert<float>(v_out) *
                                                 ck_tile::type_convert<float>(v_wei);
                                    }
                                }
                            }
                        }
                    }
                }
            }
            InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
            input(g, n, c, hi, wi)     = v_acc_converted;
        };

        make_ParallelTensorFunctor(func,
                                   input.get_lengths()[0],
                                   input.get_lengths()[1],
                                   input.get_lengths()[2],
                                   input.get_lengths()[3],
                                   input.get_lengths()[4])(std::thread::hardware_concurrency());
    }
    else if constexpr(NDimSpatial == 3)
    {
        auto func = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
            std::size_t K = weight.get_lengths()[1];
            std::size_t Z = weight.get_lengths()[3];
            std::size_t Y = weight.get_lengths()[4];
            std::size_t X = weight.get_lengths()[5];

            std::size_t Do = output.get_lengths()[3];
            std::size_t Ho = output.get_lengths()[4];
            std::size_t Wo = output.get_lengths()[5];

            float v_acc = 0;

            for(std::size_t z = 0; z < Z; ++z)
            {
                auto d_tmp = static_cast<ck_tile::long_index_t>(di) +
                             static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
                             static_cast<ck_tile::long_index_t>(z * conv_dilations[0]);
                if(d_tmp % conv_strides[0] == 0)
                {
                    auto do_ = static_cast<ck_tile::long_index_t>(d_tmp) /
                               static_cast<ck_tile::long_index_t>(conv_strides[0]);
                    if(do_ >= 0 && ck_tile::type_convert<std::size_t>(do_) < Do)
                    {
                        for(std::size_t y = 0; y < Y; ++y)
                        {
                            auto h_tmp = static_cast<ck_tile::long_index_t>(hi) +
                                         static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
                                         static_cast<ck_tile::long_index_t>(y * conv_dilations[1]);
                            if(h_tmp % conv_strides[1] == 0)
                            {
                                auto ho = static_cast<ck_tile::long_index_t>(h_tmp) /
                                          static_cast<ck_tile::long_index_t>(conv_strides[1]);
                                if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
                                {
                                    for(std::size_t x = 0; x < X; ++x)
                                    {
                                        auto w_tmp =
                                            static_cast<ck_tile::long_index_t>(wi) +
                                            static_cast<ck_tile::long_index_t>(in_left_pads[2]) -
                                            static_cast<ck_tile::long_index_t>(x *
                                                                               conv_dilations[2]);

                                        if(w_tmp % conv_strides[2] == 0)
                                        {
                                            auto wo =
                                                static_cast<ck_tile::long_index_t>(w_tmp) /
                                                static_cast<ck_tile::long_index_t>(conv_strides[2]);
                                            if(wo >= 0 &&
                                               ck_tile::type_convert<std::size_t>(wo) < Wo)
                                            {
                                                for(std::size_t k = 0; k < K; ++k)
                                                {
                                                    OutDataType v_out =
                                                        output(g, n, k, do_, ho, wo);
                                                    WeiDataType v_wei = weight(g, k, c, z, y, x);
                                                    v_acc += ck_tile::type_convert<float>(v_out) *
                                                             ck_tile::type_convert<float>(v_wei);
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
            InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
            input(g, n, c, di, hi, wi) = v_acc_converted;
        };

        make_ParallelTensorFunctor(func,
                                   input.get_lengths()[0],
                                   input.get_lengths()[1],
                                   input.get_lengths()[2],
                                   input.get_lengths()[3],
                                   input.get_lengths()[4],
                                   input.get_lengths()[5])(std::thread::hardware_concurrency());
    }
    else
    {
        throw std::runtime_error(
            "Ref_conv_bwd_data: number of dimensions must be between 1 and 3.");
    }
}
} // namespace ck_tile
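A usage sketch of the 1D backward-data path above (illustrative shapes): with stride 1, dilation 1, pad 1 and X = 3, the index math wo = (wi + pad - x*dilation)/stride stays inside Wo = 10.

// Hedged sketch: G=1, N=2, K=4, C=3, Wi=10 -> Wo=10.
inline void example_grouped_conv_bwd_data_1d()
{
    using namespace ck_tile;
    HostTensor<float> in({1, 2, 3, 10});  // [G, N, C, Wi], written by the reference
    HostTensor<float> wei({1, 4, 3, 3});  // [G, K, C, X]
    HostTensor<float> out({1, 2, 4, 10}); // [G, N, K, Wo]
    reference_grouped_conv_bwd_data<1>(in, wei, out, {1}, {1}, {1}, {1});
}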
@@ -0,0 +1,167 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
CK_TILE_HOST void
|
||||
reference_grouped_conv_bwd_weight(const HostTensor<InDataType>& input,
|
||||
HostTensor<WeiDataType>& weight,
|
||||
const HostTensor<OutDataType>& output,
|
||||
std::vector<ck_tile::long_index_t> conv_strides,
|
||||
std::vector<ck_tile::long_index_t> conv_dilations,
|
||||
std::vector<ck_tile::long_index_t> in_left_pads,
|
||||
std::vector<ck_tile::long_index_t>)
|
||||
{
|
||||
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
weight.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
output.get_num_of_dimension() == NDimSpatial + 3))
|
||||
{
|
||||
throw std::runtime_error("wrong! inconsistent dimension");
|
||||
}
|
||||
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
auto func = [&](auto g, auto k, auto c, auto x) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
|
||||
{
|
||||
for(std::size_t wo = 0; wo < output.get_lengths()[3]; ++wo)
|
||||
{
|
||||
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
|
||||
{
|
||||
InDataType v_in = input(g, n, c, wi);
|
||||
OutDataType v_out = output(g, n, k, wo);
|
||||
v_acc += ck_tile::type_convert<float>(v_out) *
|
||||
ck_tile::type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
OutDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
            weight(g, k, c, x) = v_acc_converted;
        };

        make_ParallelTensorFunctor(func,
                                   weight.get_lengths()[0],
                                   weight.get_lengths()[1],
                                   weight.get_lengths()[2],
                                   weight.get_lengths()[3])(std::thread::hardware_concurrency());
    }
    else if constexpr(NDimSpatial == 2)
    {
        auto func = [&](auto g, auto k, auto c, auto y, auto x) {
            float v_acc = 0;

            for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
            {
                for(std::size_t ho = 0; ho < output.get_lengths()[3]; ++ho)
                {
                    auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
                              static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
                              static_cast<ck_tile::long_index_t>(in_left_pads[0]);

                    for(std::size_t wo = 0; wo < output.get_lengths()[4]; ++wo)
                    {
                        auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
                                  static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
                                  static_cast<ck_tile::long_index_t>(in_left_pads[1]);

                        if(hi >= 0 &&
                           ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[3] &&
                           wi >= 0 &&
                           ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[4])
                        {
                            InDataType v_in   = input(g, n, c, hi, wi);
                            OutDataType v_out = output(g, n, k, ho, wo);

                            v_acc += ck_tile::type_convert<float>(v_out) *
                                     ck_tile::type_convert<float>(v_in);
                        }
                    }
                }
            }
            WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
            weight(g, k, c, y, x) = v_acc_converted;
        };

        make_ParallelTensorFunctor(func,
                                   weight.get_lengths()[0],
                                   weight.get_lengths()[1],
                                   weight.get_lengths()[2],
                                   weight.get_lengths()[3],
                                   weight.get_lengths()[4])(std::thread::hardware_concurrency());
    }
    else if constexpr(NDimSpatial == 3)
    {
        auto func = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
            float v_acc = 0;

            for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
            {
                for(std::size_t do_ = 0; do_ < output.get_lengths()[3]; ++do_)
                {
                    auto di = static_cast<ck_tile::long_index_t>(do_ * conv_strides[0]) +
                              static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
                              static_cast<ck_tile::long_index_t>(in_left_pads[0]);
                    for(std::size_t ho = 0; ho < output.get_lengths()[4]; ++ho)
                    {
                        auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
                                  static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
                                  static_cast<ck_tile::long_index_t>(in_left_pads[1]);
                        for(std::size_t wo = 0; wo < output.get_lengths()[5]; ++wo)
                        {
                            auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
                                      static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
                                      static_cast<ck_tile::long_index_t>(in_left_pads[2]);
                            if(di >= 0 &&
                               ck_tile::type_convert<std::size_t>(di) < input.get_lengths()[3] &&
                               hi >= 0 &&
                               ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[4] &&
                               wi >= 0 &&
                               ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[5])
                            {
                                InDataType v_in   = input(g, n, c, di, hi, wi);
                                OutDataType v_out = output(g, n, k, do_, ho, wo);

                                v_acc += ck_tile::type_convert<float>(v_out) *
                                         ck_tile::type_convert<float>(v_in);
                            }
                        }
                    }
                }
            }
            WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
            weight(g, k, c, z, y, x) = v_acc_converted;
        };

        make_ParallelTensorFunctor(func,
                                   weight.get_lengths()[0],
                                   weight.get_lengths()[1],
                                   weight.get_lengths()[2],
                                   weight.get_lengths()[3],
                                   weight.get_lengths()[4],
                                   weight.get_lengths()[5])(std::thread::hardware_concurrency());
    }
    else
    {
        throw std::runtime_error(
            "Ref_conv_bwd_weight: number of dimensions must be between 1 and 3.");
    }
}
} // namespace ck_tile
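// Usage sketch (illustrative only; shapes and data types below are
// assumptions, not part of the header above). Layouts follow the indexing
// used by the reference: input [G, N, C, Wi], weight [G, K, C, X],
// output [G, N, K, Wo], with Wo = Wi - X + 1 for unit stride/dilation and
// zero padding.
#include "ck_tile/host/host_tensor.hpp"

inline void example_conv1d_bwd_weight()
{
    ck_tile::HostTensor<float> in({2, 4, 8, 16});  // G, N, C, Wi
    ck_tile::HostTensor<float> wei({2, 8, 8, 3});  // G, K, C, X
    ck_tile::HostTensor<float> out({2, 4, 8, 14}); // G, N, K, Wo
    ck_tile::reference_grouped_conv_bwd_weight<1>(in, wei, out, {1}, {1}, {0}, {0});
}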
182
include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp
Normal file
@@ -0,0 +1,182 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <cstdlib>
#include <thread>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

template <ck_tile::index_t NDimSpatial,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename Elfunc = ck_tile::element_wise::PassThrough,
          typename Tuple  = ck_tile::tuple<>>
CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input,
                                             const HostTensor<WeiDataType>& weight,
                                             HostTensor<OutDataType>& output,
                                             std::vector<ck_tile::long_index_t> conv_strides,
                                             std::vector<ck_tile::long_index_t> conv_dilations,
                                             std::vector<ck_tile::long_index_t> in_left_pads,
                                             std::vector<ck_tile::long_index_t>,
                                             Elfunc elfunc = Elfunc{},
                                             Tuple ds      = {})
{
    if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
         weight.get_num_of_dimension() == NDimSpatial + 3 &&
         output.get_num_of_dimension() == NDimSpatial + 3))
    {
        throw std::runtime_error("wrong! inconsistent dimension");
    }

    if constexpr(NDimSpatial == 1)
    {
        auto func = [&](auto g, auto n, auto k, auto wo) {
            float v_acc = 0;

            for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
            {
                for(std::size_t x = 0; x < weight.get_lengths()[3]; ++x)
                {
                    auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
                              static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
                              static_cast<ck_tile::long_index_t>(in_left_pads[0]);

                    if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
                    {
                        InDataType v_in   = input(g, n, c, wi);
                        WeiDataType v_wei = weight(g, k, c, x);
                        v_acc += ck_tile::type_convert<float>(v_in) *
                                 ck_tile::type_convert<float>(v_wei);
                    }
                }
            }
            if constexpr(Tuple::size() > 0)
                elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, wo));
            else
                elfunc(v_acc, v_acc);
            OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
            output(g, n, k, wo)   = v_acc_out;
        };

        make_ParallelTensorFunctor(func,
                                   output.get_lengths()[0],
                                   output.get_lengths()[1],
                                   output.get_lengths()[2],
                                   output.get_lengths()[3])(std::thread::hardware_concurrency());
    }
    else if constexpr(NDimSpatial == 2)
    {
        auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
            float v_acc = 0;

            for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
            {
                for(std::size_t y = 0; y < weight.get_lengths()[3]; ++y)
                {
                    auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
                              static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
                              static_cast<ck_tile::long_index_t>(in_left_pads[0]);

                    for(std::size_t x = 0; x < weight.get_lengths()[4]; ++x)
                    {
                        auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
                                  static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
                                  static_cast<ck_tile::long_index_t>(in_left_pads[1]);

                        if(hi >= 0 &&
                           ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[3] &&
                           wi >= 0 &&
                           ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[4])
                        {
                            InDataType v_in   = input(g, n, c, hi, wi);
                            WeiDataType v_wei = weight(g, k, c, y, x);

                            v_acc += ck_tile::type_convert<float>(v_in) *
                                     ck_tile::type_convert<float>(v_wei);
                        }
                    }
                }
            }
            if constexpr(Tuple::size() > 0)
                elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, ho, wo));
            else
                elfunc(v_acc, v_acc);
            OutDataType v_acc_out   = ck_tile::type_convert<OutDataType>(v_acc);
            output(g, n, k, ho, wo) = v_acc_out;
        };

        make_ParallelTensorFunctor(func,
                                   output.get_lengths()[0],
                                   output.get_lengths()[1],
                                   output.get_lengths()[2],
                                   output.get_lengths()[3],
                                   output.get_lengths()[4])(std::thread::hardware_concurrency());
    }
    else if constexpr(NDimSpatial == 3)
    {
        auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
            float v_acc = 0;

            for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
            {
                for(std::size_t z = 0; z < weight.get_lengths()[3]; ++z)
                {
                    auto di = static_cast<ck_tile::long_index_t>(d_o * conv_strides[0]) +
                              static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
                              static_cast<ck_tile::long_index_t>(in_left_pads[0]);
                    for(std::size_t y = 0; y < weight.get_lengths()[4]; ++y)
                    {
                        auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
                                  static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
                                  static_cast<ck_tile::long_index_t>(in_left_pads[1]);
                        for(std::size_t x = 0; x < weight.get_lengths()[5]; ++x)
                        {
                            auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
                                      static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
                                      static_cast<ck_tile::long_index_t>(in_left_pads[2]);
                            if(di >= 0 &&
                               ck_tile::type_convert<std::size_t>(di) < input.get_lengths()[3] &&
                               hi >= 0 &&
                               ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[4] &&
                               wi >= 0 &&
                               ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[5])
                            {
                                InDataType v_in   = input(g, n, c, di, hi, wi);
                                WeiDataType v_wei = weight(g, k, c, z, y, x);

                                v_acc += ck_tile::type_convert<float>(v_in) *
                                         ck_tile::type_convert<float>(v_wei);
                            }
                        }
                    }
                }
            }
            if constexpr(Tuple::size() > 0)
                elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, d_o, ho, wo));
            else
                elfunc(v_acc, v_acc);
            OutDataType v_acc_out        = ck_tile::type_convert<OutDataType>(v_acc);
            output(g, n, k, d_o, ho, wo) = v_acc_out;
        };

        make_ParallelTensorFunctor(func,
                                   output.get_lengths()[0],
                                   output.get_lengths()[1],
                                   output.get_lengths()[2],
                                   output.get_lengths()[3],
                                   output.get_lengths()[4],
                                   output.get_lengths()[5])(std::thread::hardware_concurrency());
    }
    else
    {
        throw std::runtime_error("Ref_Conv_fwd: number of dimensions must be between 1 and 3.");
    }
}
} // namespace ck_tile
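// Usage sketch (illustrative; the shapes, the lambda, and the use of
// ck_tile::make_tuple here are assumptions, not part of the header above).
// A 2D forward conv with a fused bias add supplied through the trailing
// elementwise hook: elfunc is invoked as elfunc(out, in, d), where d comes
// from the first tensor in `ds`.
#include "ck_tile/host/host_tensor.hpp"

inline void example_conv2d_fwd_bias()
{
    ck_tile::HostTensor<float> in({1, 2, 4, 8, 8});   // G, N, C, Hi, Wi
    ck_tile::HostTensor<float> wei({1, 4, 4, 3, 3});  // G, K, C, Y, X
    ck_tile::HostTensor<float> out({1, 2, 4, 6, 6});  // G, N, K, Ho, Wo
    ck_tile::HostTensor<float> bias({1, 2, 4, 6, 6}); // broadcast-expanded D tensor

    auto add = [](float& o, float a, float d) { o = a + d; };
    ck_tile::reference_grouped_conv_fwd<2>(
        in, wei, out, {1, 1}, {1, 1}, {0, 0}, {0, 0}, add, ck_tile::make_tuple(bias));
}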
133
include/ck_tile/host/reference/reference_im2col.hpp
Normal file
@@ -0,0 +1,133 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {

template <typename InDataType, typename OutDataType, index_t NDimSpatial>
CK_TILE_HOST void reference_im2col(const HostTensor<InDataType>& in_host,
                                   HostTensor<OutDataType>& out_host,
                                   const ck_tile::conv::ConvParam& conv_params)
{
    const long_index_t G = in_host.get_lengths()[0];
    const long_index_t N = in_host.get_lengths()[1];
    const long_index_t C = in_host.get_lengths()[2];

    if constexpr(NDimSpatial == 1)
    {
        const long_index_t Wo = conv_params.output_spatial_lengths_[0];
        auto func             = [&](auto g, auto n, auto wo) {
            long_index_t row    = n * Wo + wo;
            long_index_t column = 0;

            for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[0]; ++x)
            {
                auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[0]) +
                          static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[0]) -
                          static_cast<long_index_t>(conv_params.input_left_pads_[0]);

                for(long_index_t c = 0; c < C; ++c)
                {
                    if(wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[3])
                    {
                        InDataType v_in          = in_host(g, n, c, wi);
                        out_host(g, row, column) = type_convert<OutDataType>(v_in);
                    }
                    column++;
                }
            }
        };

        make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
    }
    else if constexpr(NDimSpatial == 2)
    {
        const long_index_t Ho = conv_params.output_spatial_lengths_[0];
        const long_index_t Wo = conv_params.output_spatial_lengths_[1];

        auto func = [&](auto g, auto n, auto ho, auto wo) {
            long_index_t row    = n * Ho * Wo + ho * Wo + wo;
            long_index_t column = 0;

            for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[0]; ++y)
            {
                auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[0]) +
                          static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[0]) -
                          static_cast<long_index_t>(conv_params.input_left_pads_[0]);

                for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[1]; ++x)
                {
                    auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[1]) +
                              static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[1]) -
                              static_cast<long_index_t>(conv_params.input_left_pads_[1]);

                    for(long_index_t c = 0; c < C; ++c)
                    {
                        if(hi >= 0 && type_convert<std::size_t>(hi) < in_host.get_lengths()[3] &&
                           wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[4])
                        {
                            InDataType v_in          = in_host(g, n, c, hi, wi);
                            out_host(g, row, column) = type_convert<OutDataType>(v_in);
                        }
                        column++;
                    }
                }
            }
        };

        make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
    }
    else if constexpr(NDimSpatial == 3)
    {
        const long_index_t Do = conv_params.output_spatial_lengths_[0];
        const long_index_t Ho = conv_params.output_spatial_lengths_[1];
        const long_index_t Wo = conv_params.output_spatial_lengths_[2];

        auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
            long_index_t row    = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
            long_index_t column = 0;

            for(long_index_t z = 0; z < conv_params.filter_spatial_lengths_[0]; ++z)
            {
                auto di = static_cast<long_index_t>(d_o * conv_params.conv_filter_strides_[0]) +
                          static_cast<long_index_t>(z * conv_params.conv_filter_dilations_[0]) -
                          static_cast<long_index_t>(conv_params.input_left_pads_[0]);
                for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[1]; ++y)
                {
                    auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[1]) +
                              static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[1]) -
                              static_cast<long_index_t>(conv_params.input_left_pads_[1]);
                    for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[2]; ++x)
                    {
                        auto wi =
                            static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[2]) +
                            static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[2]) -
                            static_cast<long_index_t>(conv_params.input_left_pads_[2]);
                        for(long_index_t c = 0; c < C; ++c)
                        {
                            if(di >= 0 &&
                               type_convert<std::size_t>(di) < in_host.get_lengths()[3] &&
                               hi >= 0 &&
                               type_convert<std::size_t>(hi) < in_host.get_lengths()[4] &&
                               wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[5])
                            {
                                InDataType v_in          = in_host(g, n, c, di, hi, wi);
                                out_host(g, row, column) = type_convert<OutDataType>(v_in);
                            }
                            column++;
                        }
                    }
                }
            }
        };

        make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(std::thread::hardware_concurrency());
    }
}
} // namespace ck_tile
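// Shape sketch (illustrative; the numbers below are assumptions): for
// NDimSpatial == 2 the im2col output is a [G, N * Ho * Wo, Y * X * C] matrix,
// since `column` advances once per (y, x, c) and `row` enumerates (n, ho, wo).
//
//     G = 1, N = 2, C = 3, Hi = Wi = 5, Y = X = 3, stride 1, no padding
//     => Ho = Wo = 3, rows = 2 * 3 * 3 = 18, columns = 3 * 3 * 3 = 27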
96
include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp
Normal file
@@ -0,0 +1,96 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

// Note: for simplicity, each functor only cares about a single M
struct reference_layernorm2d_default_epilogue
{
    template <typename OutDataType, typename AccDataType>
    void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
    {
        const int N = acc.mDesc.get_lengths()[1];
        for(int n = 0; n < N; ++n)
        {
            o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
        }
    }

    template <typename OutDataType, typename AccDataType>
    auto operator()(int m, const HostTensor<AccDataType>& acc)
    {
        HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
        operator()(m, o, acc);
        return o;
    }
};

template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename ComputeDataType,
          typename YDataType,
          typename MeanDataType,
          typename InvStdDataType,
          typename Epilogue = reference_layernorm2d_default_epilogue>
void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
                               const HostTensor<GammaDataType>& gamma_n,
                               const HostTensor<BetaDataType>& beta_n,
                               HostTensor<YDataType>& y_m_n,
                               HostTensor<MeanDataType>& mean_m,
                               HostTensor<InvStdDataType>& invStd_m,
                               ComputeDataType epsilon,
                               Epilogue epilogue_functor = {})
{
    auto layernorm2d_fwd_func = [&](auto m) {
        const int N = x_m_n.mDesc.get_lengths()[1];

        int count                = 0;
        ComputeDataType mean     = 0;
        ComputeDataType variance = 0;
        ComputeDataType divisor  = 0;

        for(int n = 0; n < N; ++n)
        {
            ++count;
            ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
            ComputeDataType delta = x - mean;
            mean += delta / count;
            ComputeDataType delta2 = x - mean;
            variance += delta * delta2;
        }

        // actual variance
        variance = variance / count;
        divisor  = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(variance + epsilon);

        if constexpr(!std::is_same_v<MeanDataType, ck_tile::null_type>)
            mean_m(m) = ck_tile::type_convert<MeanDataType>(mean);

        if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
            invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);

        HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
        for(int n = 0; n < N; ++n)
        {
            ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
            ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
            ComputeDataType beta  = ck_tile::type_convert<ComputeDataType>(beta_n(n));
            auto a_               = (x - mean) * divisor;
            a_                    = a_ * gamma + beta;

            acc(m, n) = a_;
        }

        epilogue_functor(m, y_m_n, acc);
    };

    make_ParallelTensorFunctor(layernorm2d_fwd_func,
                               mean_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
} // namespace ck_tile
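// Note on the statistics loop above (illustrative): it is Welford's one-pass
// update,
//   mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k
//   M2_k   = M2_{k-1} + (x_k - mean_{k-1}) * (x_k - mean_k)
// so after N elements variance = M2_N / N. E.g. for x = {1, 2, 3}:
// mean -> 1, 1.5, 2 and M2 -> 0, 0.5, 2, giving variance 2/3.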
318
include/ck_tile/host/reference/reference_moe_gemm.hpp
Normal file
@@ -0,0 +1,318 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <cstdlib>
#include <thread>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC,
          int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2, 3: gemm1_split_k
          typename ActivationOp = identity>
__global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
                                const ck_tile::index_t* p_sorted_expert_ids_,
                                const ck_tile::index_t* p_max_token_id_,
                                const ADataType* A,
                                const BDataType* B,
                                CDataType* C,
                                const AccDataType* expert_weight_ptr,
                                ck_tile::index_t Num_tokens,
                                ck_tile::index_t TokensPerBlock,
                                ck_tile::index_t TopK,
                                ck_tile::index_t M,
                                ck_tile::index_t N,
                                ck_tile::index_t K,
                                ck_tile::index_t strideA,
                                ck_tile::index_t strideB,
                                ck_tile::index_t strideC,
                                index_t scale_granularity_m,
                                index_t scale_granularity_n,
                                index_t scale_granularity_k,
                                float* scale_A_ptr,
                                float* scale_B_ptr,
                                float* expert_bias_ptr)
{
    constexpr auto is_split_k = MoeGemmKind == 3;
    int idx       = blockIdx.x * blockDim.x + threadIdx.x;
    int problem_N = MoeGemmKind == 1 ? N / 2 : N;
    int row       = idx / problem_N; // Compute row index
    int col       = idx % problem_N; // Compute column index

    index_t gather_token_id  = 0;
    index_t scatter_token_id = 0;
    index_t expert_id        = 0;

    if(row < p_max_token_id_[0])
    {
        expert_id        = p_sorted_expert_ids_[row / TokensPerBlock];
        gather_token_id  = p_sorted_token_ids_[row] & 0xff'ffff;
        scatter_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
        if(gather_token_id >= Num_tokens)
        {
            return;
        }
        if(MoeGemmKind == 2)
        {
            gather_token_id = gather_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
        }
        else
        {
            scatter_token_id = scatter_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
        }
    }
    else
    {
        return;
    }

    if(row < M)
    {
        AccDataType acc    = 0.0;
        AccDataType acc_up = 0.0;

        AccDataType acc_temp    = 0.0;
        AccDataType acc_up_temp = 0.0;

        float scale_A    = 0;
        float scale_B    = 0;
        float scale_B_up = 0;

        index_t scale_A_stride        = (M + scale_granularity_m - 1) / scale_granularity_m;
        index_t scale_B_stride        = (N + scale_granularity_n - 1) / scale_granularity_n;
        index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;

        for(int k = 0; k < K; ++k)
        {
            if(k % scale_granularity_k == 0)
            {
                // update acc
                acc += acc_temp * scale_A * scale_B;
                acc_up += acc_up_temp * scale_A * scale_B_up;
                // reset acc temp
                acc_temp    = 0.0;
                acc_up_temp = 0.0;
                // update scale factors
                scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
                                      (k / scale_granularity_k) * scale_A_stride];
                scale_B =
                    scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
                                (k / scale_granularity_k) * scale_B_stride];
                if constexpr(MoeGemmKind == 1)
                    scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
                                             (col + problem_N) / scale_granularity_n +
                                             (k / scale_granularity_k) * scale_B_stride];
            }

            constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
            constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
            // Adjust indexing based on matrix layout
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? gather_token_id * strideA + k
                              : k * strideA + gather_token_id;

            long b_index =
                long(expert_id) * N * K +
                ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
                                                                             : k * strideB + col);
            long b_index_up;
            if constexpr(MoeGemmKind == 1)
                b_index_up = long(expert_id) * N * K +
                             ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                                  ? (col + problem_N) * strideB + k
                                  : k * strideB + col + problem_N);

            AccDataType v_a;
            AccDataType v_b;
            AccDataType v_b_up;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
            {
                const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
                if constexpr(MoeGemmKind == 1)
                {
                    const fp32x2_t fp32_val_up =
                        pk_int4_t_to_fp32x2_t(B[b_index_up / packed_size_b]);
                    if(k % 2 == 1)
                        v_b_up = fp32_val_up.hi;
                    else
                        v_b_up = fp32_val_up.lo;
                }
            }
            else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
            {
                const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
                if constexpr(MoeGemmKind == 1)
                {
                    const fp32x2_t fp32_val_up =
                        pk_fp4_to_fp32x2(B[b_index_up / packed_size_b], 1.0f);
                    if(k % 2 == 1)
                        v_b_up = fp32_val_up.hi;
                    else
                        v_b_up = fp32_val_up.lo;
                }
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
                if constexpr(MoeGemmKind == 1)
                    v_b_up = ck_tile::type_convert<AccDataType>(B[b_index_up]);
            }
            acc_temp += v_a * v_b;
            if constexpr(MoeGemmKind == 1)
                acc_up_temp += v_a * v_b_up;
        }

        acc += acc_temp * scale_A * scale_B;
        acc_up += acc_up_temp * scale_A * scale_B_up;

        float bias = 0.f, bias_up = 0.f;
        if(expert_bias_ptr != nullptr && !is_split_k)
        {
            bias = expert_bias_ptr[expert_id * N + col];
            if constexpr(MoeGemmKind == 1)
                bias_up = expert_bias_ptr[expert_id * N + col + problem_N];
        }

        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? scatter_token_id * strideC + col
                          : col * strideC + scatter_token_id;
        if constexpr(MoeGemmKind < 2)
        {
            C[c_index] = ck_tile::type_convert<CDataType>(
                ActivationOp{}(acc + bias, MoeGemmKind == 1 ? acc_up + bias_up : 1));
        }
        else
        {
            // moe gemm2 doesn't use an activation.
            auto weight =
                is_split_k ? ck_tile::type_convert<AccDataType>(1.0f) : expert_weight_ptr[row];
            CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * weight);

            thread_buffer<CDataType, 2> add_v = 0;
            if(c_index % 2)
            {
                // result is the second value of the fp16 pair.
                add_v.template get_as<CDataType>()[1] = res;
            }
            else
            {
                // result is the first value of the fp16 pair.
                add_v.template get_as<CDataType>()[0] = res;
            }
            // mask the last bit to make sure the atomicAdd pointer is DWORD-aligned.
            atomic_add_g<CDataType, 2>(reinterpret_cast<CDataType*>(C + (c_index & 0xffff'fffe)),
                                       add_v);
        }
    }
}

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC,
          int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2, 3: gemm1_split_k
          typename ActivationOp = identity>
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
                            const index_t* p_sorted_expert_ids_,
                            const index_t* p_max_token_id_,
                            const ADataType* a_ptr,
                            const BDataType* b_ptr,
                            CDataType* c_ptr,
                            const AccDataType* expert_weight_ptr,
                            index_t Num_tokens,
                            index_t TokensPerBlock,
                            index_t TopK,
                            index_t M,
                            index_t N,
                            index_t K,
                            index_t stride_a,
                            index_t stride_b,
                            index_t stride_c,
                            index_t scale_granularity_m,
                            index_t scale_granularity_n,
                            index_t scale_granularity_k,
                            float* scale_A_ptr,
                            float* scale_B_ptr,
                            float* exp_bias = nullptr)
{
    int problem_N          = MoeGemmKind == 1 ? N / 2 : N;
    int totalElements      = M * problem_N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    moe_gemm_kernel<ADataType,
                    BDataType,
                    AccDataType,
                    CDataType,
                    LayoutA,
                    LayoutB,
                    LayoutC,
                    MoeGemmKind,
                    ActivationOp><<<numBlocks, numThreadsPerBlock>>>(p_sorted_token_ids_,
                                                                     p_sorted_expert_ids_,
                                                                     p_max_token_id_,
                                                                     a_ptr,
                                                                     b_ptr,
                                                                     c_ptr,
                                                                     expert_weight_ptr,
                                                                     Num_tokens,
                                                                     TokensPerBlock,
                                                                     TopK,
                                                                     M,
                                                                     N,
                                                                     K,
                                                                     stride_a,
                                                                     stride_b,
                                                                     stride_c,
                                                                     scale_granularity_m,
                                                                     scale_granularity_n,
                                                                     scale_granularity_k,
                                                                     scale_A_ptr,
                                                                     scale_B_ptr,
                                                                     exp_bias);

    return;
}

} // namespace ck_tile
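// Decoding sketch (illustrative; helper names are hypothetical): entries of
// p_sorted_token_ids_ pack the token id in the low 24 bits and the top-k slot
// in the high 8 bits, which is why the kernel above masks with 0xff'ffff and
// shifts by 24.
#include <cstdint>

inline uint32_t example_pack(uint32_t token_id, uint32_t topk_slot)
{
    return (token_id & 0x00ffffffu) | ((topk_slot & 0xffu) << 24);
}
inline uint32_t example_token(uint32_t packed) { return packed & 0x00ffffffu; }
inline uint32_t example_slot(uint32_t packed) { return packed >> 24; }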
121
include/ck_tile/host/reference/reference_moe_sorting.hpp
Normal file
@@ -0,0 +1,121 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
    static_cast<uint32_t>(((token_id_) & 0x00ffffff) | (((topk_id_) & 0xff) << 24))

template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
                                        const HostTensor<WeightType>& weights,
                                        const HostTensor<IndexType>& local_expert_mask,
                                        HostTensor<IndexType>& p_sorted_token_ids,
                                        HostTensor<WeightType>& sorted_weight,
                                        HostTensor<IndexType>& sorted_expert_ids,
                                        index_t& unit_cnt,
                                        const index_t experts,
                                        const index_t unit_size,
                                        const index_t tokens,
                                        bool local_expert_masking,
                                        bool skip_experts_with_zero_token = true)
{
    // note: if tokens is smaller than topk_ids.mDesc.get_lengths()[0], this indicates the
    // local_token case
    const index_t num_token = tokens; // topk_ids.mDesc.get_lengths()[0];
    const index_t topk      = topk_ids.mDesc.get_lengths()[1];
    // allocate a temp buffer, and fill the value with [number_token|topk]
    std::vector<std::vector<IndexType>> expert_tokens(
        experts,
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
        std::vector<IndexType>(unit_size, MOE_SORTING_MOCK_ID(num_token, topk)));
#else
        std::vector<IndexType>(unit_size, num_token));
#endif
    std::vector<std::vector<WeightType>> expert_token_weights(
        experts, std::vector<WeightType>(unit_size, 0));
    // count number of unit-size slices in this expert
    std::vector<IndexType> expert_slices(experts, 1);
    // count the tokens used in this expert
    std::vector<IndexType> expert_slice_idxs(experts, 0);
    // TODO: the above 2 buffers seem duplicated

    for(index_t t = 0; t < num_token; t++)
    {
        for(index_t k = 0; k < topk; k++)
        {
            IndexType e  = topk_ids(t, k);
            WeightType w = weights(t, k);
            index_t idx  = expert_slice_idxs[e];
            if(idx > expert_slices[e] * unit_size - 1)
            {
                expert_slices[e]++;
                index_t new_size = expert_slices[e] * unit_size;
                expert_tokens[e].resize(new_size);
                expert_token_weights[e].resize(new_size);
                for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
                {
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
                    expert_tokens[e][i] = MOE_SORTING_MOCK_ID(num_token, topk);
#else
                    expert_tokens[e][i] = num_token;
#endif
                    expert_token_weights[e][i] = 0;
                }
            }
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
            expert_tokens[e][idx] = MOE_SORTING_MOCK_ID(t, k);
#else
            expert_tokens[e][idx] = t;
#endif
            expert_token_weights[e][idx] = w;
            expert_slice_idxs[e]++;
        }
    }

    IndexType* out_tokens    = p_sorted_token_ids.data();
    WeightType* out_weights  = sorted_weight.data();
    IndexType* out_expert_id = sorted_expert_ids.data();
    int curr_expert_id       = 0;
    for(index_t e = 0; e < experts; e++)
    {
        if(local_expert_masking)
        {
            if(local_expert_mask(e) == 0)
                continue;
        }
        if(skip_experts_with_zero_token)
        {
            if(expert_slice_idxs[e] == 0)
            {
                curr_expert_id++;
                continue;
            }
        }

        memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
        out_tokens += expert_slices[e] * unit_size;
        memcpy(out_weights,
               expert_token_weights[e].data(),
               sizeof(WeightType) * expert_slices[e] * unit_size);
        out_weights += expert_slices[e] * unit_size;

        for(index_t s = 0; s < expert_slices[e]; s++)
        {
            out_expert_id[s] = curr_expert_id;
            unit_cnt++;
        }
        out_expert_id += expert_slices[e];
        curr_expert_id++;
    }
    unit_cnt *= unit_size;
    return;
}

#undef MOE_SORTING_MOCK_ID

} // namespace ck_tile
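// Sizing sketch (illustrative; the numbers are assumptions): each expert's
// buffer grows in unit-size slices, with the tail of the last slice padded by
// the [num_token|topk] mock id and zero weight. For example, with tokens = 5,
// topk = 2 and unit_size = 4, a single expert can receive all 10 assignments,
// i.e. 3 slices of 4 entries, of which 2 entries are padding. Callers should
// therefore over-allocate the sorted output buffers for this worst case.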
76
include/ck_tile/host/reference/reference_permute.hpp
Normal file
@@ -0,0 +1,76 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
#include <numeric>
#include <functional>

namespace ck_tile {

/*
   this implements permute + contiguous-like functionality, as in PyTorch
*/
template <typename DataType>
CK_TILE_HOST void
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
{
    const auto x_len = x.mDesc.get_lengths();
    const auto y_len = y.mDesc.get_lengths();
    assert(x_len.size() == y_len.size());
    index_t rank     = x_len.size();
    const auto x_elm = std::accumulate(x_len.begin(), x_len.end(), 1, std::multiplies<index_t>());
    const auto y_elm = std::accumulate(y_len.begin(), y_len.end(), 1, std::multiplies<index_t>());
    assert(x_elm == y_elm);
    (void)y_elm;

    auto f = [&](auto i_element) {
        std::vector<size_t> y_coord = [&]() {
            std::vector<size_t> tmp(rank, 0);
            size_t r = i_element;
            for(index_t i = rank - 1; i >= 0; i--)
            {
                tmp[i] = r % y_len[i];
                r      = r / y_len[i];
            }
            return tmp;
        }();

        std::vector<size_t> x_coord = [&]() {
            std::vector<size_t> tmp(rank, 0);
            for(index_t i = 0; i < rank; i++)
            {
                tmp[perm[i]] = y_coord[i];
            }
            return tmp;
        }();

        // do permute
        y(y_coord) = x(x_coord);
    };

    make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
}

template <typename DataType>
CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
{
    auto x_shape          = x.get_lengths();
    ck_tile::index_t rank = perm.size();
    std::vector<ck_tile::index_t> y_shape = [&]() {
        std::vector<ck_tile::index_t> tmp(rank, 0);
        for(int i = 0; i < static_cast<int>(rank); i++)
        {
            tmp[i] = x_shape[perm[i]];
        }
        return tmp;
    }();

    HostTensor<DataType> y(y_shape);
    reference_permute(x, y, perm);
    return y;
}
} // namespace ck_tile
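// Usage sketch (illustrative; the shape is an assumption): transposing the
// last two axes of a rank-3 tensor, the way torch's
// x.permute(0, 2, 1).contiguous() would.
#include "ck_tile/host/host_tensor.hpp"

inline void example_permute()
{
    ck_tile::HostTensor<float> x({2, 3, 4});
    auto y = ck_tile::reference_permute(x, {0, 2, 1}); // y has lengths {2, 4, 3}
}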
198
include/ck_tile/host/reference/reference_pool.hpp
Normal file
@@ -0,0 +1,198 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
#include <thread>
#include <cmath>

namespace ck_tile {

template <typename InDataType,
          typename ComputeDataType,
          typename OutDataType,
          typename IndexDataType,
          typename ReduceOp,
          typename TensorShape,
          typename WindowShape,
          bool OutputIndex = false>
CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
                                   HostTensor<OutDataType>& output,
                                   HostTensor<IndexDataType>& output_index,
                                   PoolKernelArgs<TensorShape, WindowShape> kargs,
                                   ReduceOp reduce_op)
{
    const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
    const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<1>{});
    const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<2>{});
    const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<3>{});

    const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<1>{});
    const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<2>{});

    const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<0>{});
    const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<1>{});

    const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<0>{});
    const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<1>{});

    const ck_tile::index_t Dy = kargs.window_dilations.at(ck_tile::number<0>{});
    const ck_tile::index_t Dx = kargs.window_dilations.at(ck_tile::number<1>{});

    const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<0>{});
    const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<1>{});
    // Right padding is handled implicitly by bounds checking

    auto f = [&](auto n, auto ho, auto wo, auto c) {
        ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();

        IndexDataType current_index = 0; // Declare outside if constexpr for efficiency

        for(ck_tile::index_t y = 0; y < Y; ++y)
        {
            // Calculate input height index with stride, dilation, and padding
            ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;

            for(ck_tile::index_t x = 0; x < X; ++x)
            {
                // Calculate input width index with stride, dilation, and padding
                ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;

                if(hi >= 0 && hi < H && wi >= 0 && wi < W)
                {
                    const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));

                    if constexpr(OutputIndex)
                    {
                        IndexDataType flat_index = input.GetOffsetFromMultiIndex(n, hi, wi, c);
                        bool changed             = false;
                        v_acc                    = reduce_op(v_acc, v_in, changed);
                        if(changed)
                        {
                            current_index = flat_index;
                        }
                    }
                    else
                    {
                        v_acc = reduce_op(v_acc, v_in);
                    }
                }
                // For positions outside bounds, we implicitly use identity value
            }
        }

        output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);

        if constexpr(OutputIndex)
        {
            output_index(n, ho, wo, c) = current_index;
        }
    };

    // Parallelize over all output dimensions
    make_ParallelTensorFunctor(f, N, Ho, Wo, C)(std::thread::hardware_concurrency());
}

template <typename InDataType,
          typename ComputeDataType,
          typename OutDataType,
          typename IndexDataType,
          typename ReduceOp,
          typename TensorShape,
          typename WindowShape,
          bool OutputIndex = false>
CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
                                   HostTensor<OutDataType>& output,
                                   HostTensor<IndexDataType>& output_index,
                                   PoolKernelArgs<TensorShape, WindowShape> kargs,
                                   ReduceOp reduce_op)
{
    const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
    const ck_tile::index_t D = kargs.input_shape.at(ck_tile::number<1>{});
    const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<2>{});
    const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<3>{});
    const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<4>{});

    const ck_tile::index_t Do = kargs.output_shape.at(ck_tile::number<1>{});
    const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<2>{});
    const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<3>{});

    const ck_tile::index_t Z = kargs.window_lengths.at(ck_tile::number<0>{});
    const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<1>{});
    const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<2>{});

    const ck_tile::index_t Sz = kargs.window_strides.at(ck_tile::number<0>{});
    const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<1>{});
    const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<2>{});

    const ck_tile::index_t Dz = kargs.window_dilations.at(ck_tile::number<0>{});
    const ck_tile::index_t Dy = kargs.window_dilations.at(ck_tile::number<1>{});
    const ck_tile::index_t Dx = kargs.window_dilations.at(ck_tile::number<2>{});

    const ck_tile::index_t LeftPz = kargs.input_left_pads.at(ck_tile::number<0>{});
    const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<1>{});
    const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<2>{});
    // Right padding is handled implicitly by bounds checking

    auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
        ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();

        IndexDataType current_index = 0; // Declare outside if constexpr for efficiency

        for(ck_tile::index_t z = 0; z < Z; ++z)
        {
            // Calculate input depth index with stride, dilation, and padding
            ck_tile::index_t di = do_ * Sz + z * Dz - LeftPz;

            for(ck_tile::index_t y = 0; y < Y; ++y)
            {
                // Calculate input height index with stride, dilation, and padding
                ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;

                for(ck_tile::index_t x = 0; x < X; ++x)
                {
                    // Calculate input width index with stride, dilation, and padding
                    ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;

                    if(di >= 0 && di < D && hi >= 0 && hi < H && wi >= 0 && wi < W)
                    {
                        const ComputeDataType v_in =
                            type_convert<ComputeDataType>(input(n, di, hi, wi, c));

                        if constexpr(OutputIndex)
                        {
                            IndexDataType flat_index =
                                input.GetOffsetFromMultiIndex(n, di, hi, wi, c);
                            bool changed = false;
                            v_acc        = reduce_op(v_acc, v_in, changed);
                            if(changed)
                            {
                                current_index = flat_index;
                            }
                        }
                        else
                        {
                            v_acc = reduce_op(v_acc, v_in);
                        }
                    }
                    // For positions outside bounds, we implicitly use identity value
                }
            }
        }

        output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);

        if constexpr(OutputIndex)
        {
            output_index(n, do_, ho, wo, c) = current_index;
        }
    };

    // Parallelize over all output dimensions
    make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
}
} // namespace ck_tile
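// Sizing note (illustrative; the numbers are assumptions): with the
// bounds-check-only right padding used above, the usual pooling output extent
// applies,
//   Ho = (H + LeftPy + RightPy - Dy * (Y - 1) - 1) / Sy + 1
// e.g. H = 7, Y = 3, Sy = 2, Dy = 1 and no padding gives Ho = 3.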
341
include/ck_tile/host/reference/reference_reduce.hpp
Normal file
@@ -0,0 +1,341 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <thread>

namespace ck_tile {

template <typename XDataType, typename ComputeDataType, typename YDataType, typename ReduceOp>
CK_TILE_HOST void
reference_reduce(const HostTensor<XDataType>& x_m_n, HostTensor<YDataType>& y_m, ReduceOp reduce_op)
{
    auto f = [&](auto m) {
        const int N = x_m_n.mDesc.get_lengths()[1];

        ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();

        for(int n = 0; n < N; ++n)
        {
            const ComputeDataType v_a = type_convert<ComputeDataType>(x_m_n(m, n));

            v_acc = reduce_op(v_acc, v_a);
        }

        y_m(m) = ck_tile::type_convert<YDataType>(v_acc);
    };

    make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}

// Generic reference reduce for arbitrary dimensions
template <
    typename XDataType,
    typename ComputeDataType,
    typename YDataType,
    typename ReduceOp,
    typename KeptDim, // Expected type: ck_tile::sequence<...> containing dimension indices to keep
    typename ReduceDims> // Expected type: ck_tile::sequence<...> containing dimension indices to
                         // reduce
CK_TILE_HOST void reference_reduce(const HostTensor<XDataType>& x_tensor,
                                   HostTensor<YDataType>& y_tensor,
                                   ReduceOp reduce_op,
                                   KeptDim kept_dim,
                                   ReduceDims reduce_dims)
{
    const auto& x_lengths = x_tensor.mDesc.get_lengths();

    // Calculate total kept elements (product of all kept dimension lengths)
    index_t total_kept_elements = 1;
    static_for<0, kept_dim.size(), 1>{}(
        [&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });

    // Calculate total reduce elements (product of all reduce dimension lengths)
    index_t total_reduce_elements = 1;
    static_for<0, reduce_dims.size(), 1>{}(
        [&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });

    auto f = [&](auto linear_kept_idx) {
        ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();

        // Convert linear kept index to multi-dimensional kept indices
        std::vector<index_t> kept_indices(kept_dim.size());
        index_t temp_kept = linear_kept_idx;
        static_for<0, kept_dim.size(), 1>{}([&](auto i) {
            constexpr auto dim_idx = kept_dim.size() - 1 - i;
            constexpr auto dim     = kept_dim.at(dim_idx);
            const auto len         = x_lengths[dim];
            kept_indices[dim_idx]  = temp_kept % len;
            temp_kept /= len;
        });

        for(index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
        {
            // Convert linear reduce index to multi-dimensional reduce indices
            std::vector<index_t> reduce_indices(reduce_dims.size());
            index_t temp_reduce = reduce_idx;
            static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
                constexpr auto dim_idx  = reduce_dims.size() - 1 - i;
                constexpr auto dim      = reduce_dims.at(dim_idx);
                const auto len          = x_lengths[dim];
                reduce_indices[dim_idx] = temp_reduce % len;
                temp_reduce /= len;
            });

            // Build full input tensor indices by combining kept and reduce indices
            std::vector<std::size_t> full_indices(x_lengths.size(), 0);
            static_for<0, kept_dim.size(), 1>{}(
                [&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
            static_for<0, reduce_dims.size(), 1>{}(
                [&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });

            // Access input tensor element
            const auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));

            v_acc = reduce_op(v_acc, v_a);
        }

        // Calculate output tensor index using kept indices
        // The output tensor has the same structure as the kept dimensions
        std::vector<std::size_t> y_indices(kept_dim.size());
        static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });

        y_tensor(y_indices) = type_convert<YDataType>(v_acc);
    };

    make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
}

template <typename XDataType,
          typename ComputeDataType,
          typename YDataType,
          typename YRefTuple,
          typename ReduceOps, // Expected type: ck_tile::tuple<...> containing reduce operations
          typename KeptDim,   // Expected type: ck_tile::sequence<...> containing dimension
                              // indices to keep
          typename ReduceDims, // Expected type: ck_tile::sequence<...> containing dimension
                               // indices to reduce
          typename ElementWiseOps,
          typename AccElementWiseOps>
CK_TILE_HOST void reference_multiple_reduce(const HostTensor<XDataType>& x_tensor,
                                            YRefTuple& y_tensor_tuple,
                                            ReduceOps reduce_ops,
                                            KeptDim kept_dim,
                                            ReduceDims reduce_dims,
                                            ElementWiseOps elementwise_ops,
                                            AccElementWiseOps accumulator_ops)
{
    const auto& x_lengths = x_tensor.mDesc.get_lengths();

    // Calculate total kept elements (product of all kept dimension lengths)
    index_t total_kept_elements = 1;
    static_for<0, kept_dim.size(), 1>{}(
        [&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });

    // Calculate total reduce elements (product of all reduce dimension lengths)
    index_t total_reduce_elements = 1;
    static_for<0, reduce_dims.size(), 1>{}(
        [&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });

    auto f = [&](auto linear_kept_idx) {
        // Initialize accumulators for each reduction operation
        auto v_acc_tuple = ck_tile::generate_tuple(
            [&](auto i) {
                return reduce_ops.template at<i>().template GetIdentityValue<ComputeDataType>();
            },
            number<reduce_ops.size()>{});

        // Convert linear kept index to multi-dimensional kept indices
        std::vector<index_t> kept_indices(kept_dim.size());
        index_t temp_kept = linear_kept_idx;
        static_for<0, kept_dim.size(), 1>{}([&](auto i) {
            constexpr auto dim_idx = kept_dim.size() - 1 - i;
            constexpr auto dim     = kept_dim.at(dim_idx);
            const auto len         = x_lengths[dim];
            kept_indices[dim_idx]  = temp_kept % len;
            temp_kept /= len;
        });

        for(index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
        {
            // Convert linear reduce index to multi-dimensional reduce indices
            std::vector<index_t> reduce_indices(reduce_dims.size());
            index_t temp_reduce = reduce_idx;
            static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
                constexpr auto dim_idx  = reduce_dims.size() - 1 - i;
                constexpr auto dim      = reduce_dims.at(dim_idx);
                const auto len          = x_lengths[dim];
                reduce_indices[dim_idx] = temp_reduce % len;
                temp_reduce /= len;
            });

            // Build full input tensor indices by combining kept and reduce indices
            std::vector<std::size_t> full_indices(x_lengths.size(), 0);
            static_for<0, kept_dim.size(), 1>{}(
                [&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
            static_for<0, reduce_dims.size(), 1>{}(
                [&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });

            // Access input tensor element
            auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));

            // Apply each reduction operation
            static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
                // Apply element-wise operation before reduction
                elementwise_ops.at(i)(v_a, v_a);

                v_acc_tuple.template at<i>() =
                    reduce_ops.template at<i>()(v_acc_tuple.template at<i>(), v_a);
            });
        }

        static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
            // Apply accumulator element-wise operation after reduction
            accumulator_ops.at(i)(v_acc_tuple.template at<i>(), v_acc_tuple.template at<i>());
        });

        // Calculate output tensor index using kept indices
        // The output tensor has the same structure as the kept dimensions
        std::vector<std::size_t> y_indices(kept_dim.size());
        static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });

        // Store results for each reduction operation in the output tensor
        static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
            y_tensor_tuple.template at<i>()(y_indices) =
                type_convert<YDataType>(v_acc_tuple.template at<i>());
        });
    };

    make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
}

template <typename XDataType,
          typename ComputeDataType,
          typename YDataType,
          typename YRefTuple,
          typename ReduceOps, // Expected type: ck_tile::tuple<...> containing reduce operations
          typename KeptDim,   // Expected type: ck_tile::sequence<...> containing dimension
                              // indices to keep
          typename ReduceDims, // Expected type: ck_tile::sequence<...> containing dimension
                               // indices to reduce
          typename ElementWiseOps,
          typename AccElementWiseOps,
          typename InterBlockReduceOps>
CK_TILE_HOST void reference_multiple_reduce_multiblock(const HostTensor<XDataType>& x_tensor,
                                                       YRefTuple& y_tensor_tuple,
                                                       ReduceOps reduce_ops,
                                                       KeptDim kept_dim,
                                                       ReduceDims reduce_dims,
                                                       ElementWiseOps elementwise_ops,
                                                       AccElementWiseOps accumulator_ops,
                                                       InterBlockReduceOps inter_block_reduce_ops,
                                                       ck_tile::index_t num_blocks)
{
    const auto& x_lengths = x_tensor.mDesc.get_lengths();

    // Calculate total kept elements (product of all kept dimension lengths)
    index_t total_kept_elements = 1;
    static_for<0, kept_dim.size(), 1>{}(
        [&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });

    // Calculate total reduce elements (product of all reduce dimension lengths)
    index_t total_reduce_elements = 1;
    static_for<0, reduce_dims.size(), 1>{}(
        [&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });

    // Initialize output tensors
    static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
        auto& y_tensor = y_tensor_tuple.template at<i>();
        for(auto& val : y_tensor.mData)
        {
            val = inter_block_reduce_ops.template at<i>().template GetIdentityValue<YDataType>();
        }
    });

    auto f = [&](auto linear_kept_idx) {
        // Convert linear kept index to multi-dimensional kept indices
        std::vector<index_t> kept_indices(kept_dim.size());
        index_t temp_kept = linear_kept_idx;
        static_for<0, kept_dim.size(), 1>{}([&](auto i) {
            constexpr auto dim_idx = kept_dim.size() - 1 - i;
            constexpr auto dim     = kept_dim.at(dim_idx);
            const auto len         = x_lengths[dim];
            kept_indices[dim_idx]  = temp_kept % len;
            temp_kept /= len;
        });

        // Calculate output tensor index using kept indices
        std::vector<std::size_t> y_indices(kept_dim.size());
        static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });

        const auto max_element_per_block = (total_reduce_elements + num_blocks - 1) / num_blocks;

        for(index_t block_id = 0; block_id < num_blocks; ++block_id)
        {
            // Initialize accumulators for each reduction operation for the current block
            auto v_acc_tuple = ck_tile::generate_tuple(
                [&](auto i) {
                    return reduce_ops.template at<i>().template GetIdentityValue<ComputeDataType>();
                },
                number<reduce_ops.size()>{});

            const index_t element_offset = block_id * max_element_per_block;
            const index_t element_end =
                std::min(element_offset + max_element_per_block, total_reduce_elements);

            for(index_t linear_reduce_idx = element_offset; linear_reduce_idx < element_end;
                ++linear_reduce_idx)
            {
                // Convert linear reduce index to multi-dimensional reduce indices
                std::vector<index_t> reduce_indices(reduce_dims.size());
                index_t temp_reduce = linear_reduce_idx;
                static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
                    constexpr auto dim_idx  = reduce_dims.size() - 1 - i;
                    constexpr auto dim      = reduce_dims.at(dim_idx);
                    const auto len          = x_lengths[dim];
                    reduce_indices[dim_idx] = temp_reduce % len;
                    temp_reduce /= len;
                });

                // Build full input tensor indices by combining kept and reduce indices
                std::vector<std::size_t> full_indices(x_lengths.size(), 0);
                static_for<0, kept_dim.size(), 1>{}(
                    [&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
                static_for<0, reduce_dims.size(), 1>{}(
                    [&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });

                // Access input tensor element
                const auto v_a_in = type_convert<ComputeDataType>(x_tensor(full_indices));

                // Apply each reduction operation
                static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
                    auto v_a = v_a_in;
                    // Apply element-wise operation before reduction
                    elementwise_ops.at(i)(v_a, v_a);

                    v_acc_tuple.template at<i>() =
                        reduce_ops.template at<i>()(v_acc_tuple.template at<i>(), v_a);
                });
            }

            static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
                // Apply accumulator element-wise operation after reduction
                accumulator_ops.at(i)(v_acc_tuple.template at<i>(), v_acc_tuple.template at<i>());

                // Update the output tensor with the partial result from this block
                auto& y_tensor = y_tensor_tuple.template at<i>();
                auto& y_val    = y_tensor(y_indices);
                y_val          = inter_block_reduce_ops.template at<i>()(
                    y_val, type_convert<YDataType>(v_acc_tuple.template at<i>()));
            });
        }
    };

    make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
}

} // namespace ck_tile
|
||||
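For orientation, a minimal, hypothetical driver for the routine above. SumOp and PassThrough are local stand-ins written here for the sketch (ck_tile ships its own reduce/elementwise functors; the names below are assumptions, not the library API), and the assumed header path in the include is illustrative. With total_reduce_elements = 8 and num_blocks = 2, the ceiling division gives max_element_per_block = 4, so the two blocks cover reduce-index ranges [0,4) and [4,8).

// Hypothetical usage sketch; SumOp/PassThrough are local stand-ins, not ck_tile ops.
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp" // assumed header name for the code above

struct SumOp
{
    template <typename T>
    T GetIdentityValue() const { return static_cast<T>(0); }
    template <typename T>
    T operator()(T acc, T v) const { return acc + v; }
};

struct PassThrough
{
    template <typename T>
    void operator()(T& y, const T& x) const { y = x; }
};

void reduce_example()
{
    using namespace ck_tile;
    HostTensor<float> x({4, 8}); // keep dim 0, reduce dim 1
    for(std::size_t i = 0; i < x.mData.size(); ++i)
        x.mData[i] = static_cast<float>(i % 5); // arbitrary test pattern

    auto y = make_tuple(HostTensor<float>({4})); // one output tensor per reduce op

    reference_multiple_reduce_multiblock<float, float, float>(
        x,
        y,
        make_tuple(SumOp{}),       // per-block reduction
        sequence<0>{},             // kept dims
        sequence<1>{},             // reduce dims
        make_tuple(PassThrough{}), // pre-reduce elementwise op
        make_tuple(PassThrough{}), // post-reduce accumulator op
        make_tuple(SumOp{}),       // inter-block combine
        /*num_blocks=*/2);
}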
114
include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp
Normal file
@@ -0,0 +1,114 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"

namespace ck_tile {

// Note: for simplicity, each functor only cares about a single M
struct reference_rmsnorm2d_default_epilogue
{
    template <typename OutDataType, typename AccDataType>
    void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
    {
        const int N = acc.mDesc.get_lengths()[1];
        for(int n = 0; n < N; ++n)
        {
            o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
        }
    }

    template <typename OutDataType, typename AccDataType>
    auto operator()(int m, const HostTensor<AccDataType>& acc)
    {
        HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
        operator()(m, o, acc);
        return o;
    }
};

template <typename XDataType,
          typename GammaDataType,
          typename ComputeDataType,
          typename YDataType,
          typename InvRmsDataType,
          typename UnquantYDataType,
          typename Epilogue = reference_rmsnorm2d_default_epilogue>
void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
                             const HostTensor<GammaDataType>& gamma_n,
                             HostTensor<YDataType>& y_m_n,
                             HostTensor<InvRmsDataType>& invRms_m,
                             HostTensor<UnquantYDataType>& unquant_y_m_n,
                             ComputeDataType epsilon,
                             Epilogue epilogue_functor = {},
                             const int use_model_sensitive_rmsnorm =
                                 static_cast<int>(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL))
{
    auto rmsnorm2d_fwd_func = [&](auto m) {
        const int N = x_m_n.mDesc.get_lengths()[1];

        ComputeDataType mean_square = 0;
        ComputeDataType divisor     = 0;

        for(int n = 0; n < N; ++n)
        {
            ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
            mean_square += x * x;
        }

        mean_square = mean_square / N;
        divisor = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(mean_square + epsilon);

        if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
            invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);

        HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
        for(int n = 0; n < N; ++n)
        {
            ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
            ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
            if(use_model_sensitive_rmsnorm ==
               static_cast<int>(
                   Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model
            {
                acc(m, n) = x * divisor * gamma;
            }
            else if(use_model_sensitive_rmsnorm ==
                    static_cast<int>(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model
            {
                // T5-like models round the normalized activation back to the storage type
                // before the gamma multiply; reproduce that intermediate rounding here.
                if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
                {
                    const auto tmp0 = float_to_bf16<bf16_rounding_mode::standard>(x * divisor);
                    const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
                        type_convert<ComputeDataType>(tmp0) * gamma);
                    const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
                    acc(m, n)        = rmsn_;
                }
                else
                {
                    const auto tmp   = type_convert<XDataType>(x * divisor);
                    const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma;
                    acc(m, n)        = rmsn_;
                }
            }
        }

        if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
        {
            epilogue_functor(m, unquant_y_m_n, y_m_n, acc);
        }
        else
        {
            epilogue_functor(m, y_m_n, acc);
        }
    };

    make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
        std::thread::hardware_concurrency());
}

} // namespace ck_tile
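In formula form, per row m over N columns, the default path above computes (this is just a restatement of the two loops):

\mathrm{invRms}_m = \frac{1}{\sqrt{\tfrac{1}{N}\sum_{n=0}^{N-1} x_{m,n}^{2} + \epsilon}},
\qquad
y_{m,n} = x_{m,n}\,\mathrm{invRms}_m\,\gamma_n

The T5-like branch differs only in rounding x_{m,n} * invRms_m back to the storage type before the gamma multiply.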
@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {
template <typename XDataType, typename ScaleDataType, typename QXDataType>
CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor<XDataType>& x_m_n,
                                                   const HostTensor<ScaleDataType>& scale_m,
                                                   HostTensor<QXDataType>& qx_m_n)
{
    auto f = [&](auto m) {
        const int N = x_m_n.mDesc.get_lengths()[1];

        for(int n = 0; n < N; ++n)
        {
            auto v_x = x_m_n(m, n);
            // scale = amax / 127 for int8
            auto v_scale = type_convert<XDataType>(scale_m(m));
            auto v_qx    = v_x / v_scale;
            qx_m_n(m, n) = type_convert<QXDataType>(saturates<QXDataType>{}(v_qx));
        }
    };

    make_ParallelTensorFunctor(
        f, scale_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}

} // namespace ck_tile
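Per row, the reference computes qx_{m,n} = saturate(x_{m,n} / scale_m). A minimal host-side sketch of preparing an int8-style per-row scale before calling the routine above; the fill pattern, shapes, and int8 target are illustrative assumptions:

#include <algorithm>
#include <cmath>
#include <cstdint>

void quantization_example()
{
    using namespace ck_tile;
    const int M = 2, N = 4;
    HostTensor<float> x({M, N});
    HostTensor<float> scale({M});
    HostTensor<int8_t> qx({M, N});

    for(std::size_t i = 0; i < x.mData.size(); ++i)
        x.mData[i] = 0.25f * static_cast<float>(i) - 0.5f; // arbitrary test pattern

    // scale_m = amax(row m) / 127, matching the "amax / 127 for int8" comment above
    for(int m = 0; m < M; ++m)
    {
        float amax = 0.f;
        for(int n = 0; n < N; ++n)
            amax = std::max(amax, std::abs(x(m, n)));
        scale(m) = amax / 127.f;
    }

    reference_rowwise_quantization2d(x, scale, qx);
}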
89
include/ck_tile/host/reference/reference_softmax.hpp
Normal file
@@ -0,0 +1,89 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {

template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST void
reference_softmax(const HostTensor<InputType>& x, HostTensor<OutputType>& y, index_t dim = -1)
{
    index_t rank = x.get_num_of_dimension();
    assert(static_cast<std::size_t>(rank) == y.get_num_of_dimension());
    assert(dim == -1 || dim < rank);

    index_t target_dim  = dim == -1 ? (rank - 1) : dim;
    index_t softmax_len = x.get_length(target_dim);
    index_t n_parallel  = x.get_element_size() / softmax_len;
    auto x_len          = x.get_lengths();

    auto f = [&](auto i_element) {
        std::vector<size_t> coord = [&]() {
            std::vector<size_t> t_(rank, 0);
            size_t r = i_element;
            for(index_t i = rank - 1; i >= 0; i--)
            {
                if(i == target_dim)
                    continue;
                t_[i] = r % x_len[i];
                r     = r / x_len[i];
            }
            return t_;
        }();

        ComputeType v_max = -ck_tile::numeric<ComputeType>::infinity();

        // compute max
        for(auto idx = 0; idx < softmax_len; idx++)
        {
            auto c_        = coord;
            c_[target_dim] = idx;
            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
            v_max = v_max < v_x ? v_x : v_max;
        }

        ComputeType v_exp_sum = static_cast<ComputeType>(0);

        // sum
        for(auto idx = 0; idx < softmax_len; idx++)
        {
            auto c_        = coord;
            c_[target_dim] = idx;

            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));

            v_exp_sum += ck_tile::exp(v_x - v_max);
        }

        // elementwise
        for(auto idx = 0; idx < softmax_len; idx++)
        {
            auto c_        = coord;
            c_[target_dim] = idx;

            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));

            auto out = ck_tile::exp(v_x - v_max) / v_exp_sum;

            y(c_) = ck_tile::type_convert<OutputType>(out);
        }
    };

    make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}

template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST auto reference_softmax(const HostTensor<InputType>& x, index_t dim = -1)
{
    HostTensor<OutputType> y(x.get_lengths(), x.get_strides());

    reference_softmax<InputType, ComputeType, OutputType>(x, y, dim);

    return y;
}
} // namespace ck_tile
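The three passes above (max, exp-sum, normalize) are the standard numerically stable evaluation along the target dimension: subtracting the running max before exponentiation bounds every exponent so the largest term is exp(0) = 1 and nothing overflows. With L = softmax_len and m = max_j x_j along that dimension:

y_i = \frac{e^{\,x_i - m}}{\sum_{j=0}^{L-1} e^{\,x_j - m}}

The second overload allocates the output with the same lengths and strides as x and returns it.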
125
include/ck_tile/host/reference/reference_topk.hpp
Normal file
@@ -0,0 +1,125 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
#include <numeric>
#include <functional>
#include <utility>
#include <algorithm>

namespace ck_tile {

/*
similar to torch.topk()
x (Tensor) – the input tensor.
k (int) – the k in “top-k”
dim (int, optional) – the dimension to sort along
largest (bool, optional) – largest or smallest elements
sorted (bool, optional) – elements in sorted order or not

output:
y_values
y_indices

https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TopKImpl.h
*/
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
                                 HostTensor<DataType>& y_values,
                                 HostTensor<IndexType>& y_indices,
                                 index_t k,
                                 index_t dim = -1,
                                 bool largest = true,
                                 bool sorted = true)
{
    // rank must be the same
    index_t rank = x.get_num_of_dimension();
    assert(static_cast<std::size_t>(rank) == y_values.get_num_of_dimension());
    assert(static_cast<size_t>(rank) == y_indices.get_num_of_dimension());
    assert(dim == -1 || dim < rank);

    index_t topk_dim     = dim == -1 ? (rank - 1) : dim;
    index_t topk_src_len = x.get_length(topk_dim);
    auto x_len           = x.get_lengths();

    assert(k <= topk_src_len);
    assert(static_cast<size_t>(k) == y_values.get_length(topk_dim) &&
           static_cast<size_t>(k) == y_indices.get_length(topk_dim));

    index_t n_parallel = x.get_element_size() / topk_src_len;

    // clang-format off
    auto f = [&](auto i_element) {
        std::vector<size_t> topk_coord = [&](){
            std::vector<size_t> t_(rank, 0);
            size_t r = i_element;
            for(index_t i = rank - 1; i >= 0; i--) {
                if(i == topk_dim) continue; // coordinate along the topk dim stays zero here
                t_[i] = r % x_len[i]; r = r / x_len[i];
            }
            return t_;
        }();

        using elem_t = std::pair<DataType, IndexType>;
        std::vector<elem_t> q = [&](){
            std::vector<elem_t> t_(topk_src_len);
            for(index_t i = 0; i < topk_src_len; i++) {
                auto c_ = topk_coord; c_[topk_dim] = i;
                t_[i].first = x(c_); t_[i].second = i;
            }
            return t_;
        }();

        // run topk
        if(largest) {
            std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
                [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
            if(sorted) {
                std::sort(q.begin(), q.begin() + k - 1,
                    [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
            }
        } else {
            std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
                [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
            if(sorted) {
                std::sort(q.begin(), q.begin() + k - 1,
                    [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
            }
        }

        // write out
        for(index_t i = 0; i < k; i++) {
            auto c_ = topk_coord; c_[topk_dim] = i;
            y_values(c_) = q[i].first; y_indices(c_) = q[i].second;
        }
    };
    // clang-format on

    make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}

// TODO: if using this method, the return tensor would be dense (no stride)
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST auto reference_topk(const HostTensor<DataType>& x,
                                 index_t k,
                                 index_t dim = -1,
                                 bool largest = true,
                                 bool sorted = true)
{
    auto lens = x.get_lengths();
    index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim;
    assert(target_dim < lens.size());
    assert(k <= lens[target_dim]);
    lens[target_dim] = k;
    HostTensor<DataType> y_values(lens);
    HostTensor<IndexType> y_indices(lens);

    reference_topk<DataType, IndexType>(x, y_values, y_indices, k, dim, largest, sorted);

    return ck_tile::make_tuple(y_values, y_indices);
}
} // namespace ck_tile
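A short sketch of the out-parameter overload above, taking the top-2 along dim 1 of a 2D tensor; the shapes and fill pattern are illustrative assumptions. Note that std::nth_element places the k-th element in its final sorted position, so sorting only the first k-1 elements afterwards yields a fully sorted top-k prefix:

void topk_example()
{
    using namespace ck_tile;
    HostTensor<float> x({3, 5});
    for(std::size_t i = 0; i < x.mData.size(); ++i)
        x.mData[i] = static_cast<float>((i * 7) % 11); // arbitrary test pattern

    // output shapes match x except along the topk dim, which shrinks to k
    auto lens = x.get_lengths();
    lens[1]   = 2; // k
    HostTensor<float> vals(lens);
    HostTensor<index_t> idxs(lens);

    reference_topk(x, vals, idxs, /*k=*/2, /*dim=*/1, /*largest=*/true, /*sorted=*/true);
}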
33
include/ck_tile/host/reference/reference_transpose.hpp
Normal file
@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {

template <typename ADataType, typename BDataType>
void reference_transpose_elementwise(const HostTensor<ADataType>& a, HostTensor<BDataType>& b)
{
    ck_tile::index_t M = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[0]);
    ck_tile::index_t N = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[1]);

    // Ensure the b tensor is sized correctly for N x M
    if(static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[0]) != N ||
       static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[1]) != M)
    {
        throw std::runtime_error("Output tensor b has incorrect dimensions for transpose.");
    }

    auto f = [&](auto i, auto j) {
        auto v_a = a(i, j);
        b(j, i)  = ck_tile::type_convert<BDataType>(v_a);
    };

    make_ParallelTensorFunctor(f, M, N)(std::thread::hardware_concurrency());
}

} // namespace ck_tile
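A minimal sketch of the transpose reference above; the 2x3 shape and fill values are illustrative assumptions, and b must be pre-allocated as N x M or the call throws:

void transpose_example()
{
    using namespace ck_tile;
    HostTensor<float> a({2, 3});
    HostTensor<float> b({3, 2});
    for(std::size_t i = 0; i < a.mData.size(); ++i)
        a.mData[i] = static_cast<float>(i);

    reference_transpose_elementwise(a, b); // afterwards b(j, i) == a(i, j)
}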