This commit is contained in:
Ding, Yi
2026-03-11 23:03:20 -04:00
commit e6cd3f1e3f
6330 changed files with 1132789 additions and 0 deletions

View File

@@ -0,0 +1,275 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdlib>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
// Helper to apply elementwise operation with variable number of D tensors
template <typename EDataType, typename AccDataType, typename CDEElementWise>
struct ApplyCDEElementWise
{
template <typename... DValues>
CK_TILE_HOST_DEVICE static void apply(EDataType& result,
AccDataType sum,
const CDEElementWise& cde_elementwise,
DValues... d_vals)
{
if constexpr(sizeof...(DValues) == 0)
{
result = static_cast<EDataType>(sum);
}
else
{
cde_elementwise(
result, ck_tile::type_convert<float>(sum), ck_tile::type_convert<float>(d_vals)...);
}
}
};
// Helper to extract D values at a given offset using index sequence
template <typename DDataType,
ck_tile::index_t NumDTensor,
typename Indices = std::make_index_sequence<NumDTensor>>
struct ExtractDValues;
template <typename DDataType, ck_tile::index_t NumDTensor, std::size_t... Is>
struct ExtractDValues<DDataType, NumDTensor, std::index_sequence<Is...>>
{
template <typename EDataType, typename AccDataType, typename CDEElementWise>
CK_TILE_HOST static void
apply_at_offsets(EDataType& result,
AccDataType sum,
const CDEElementWise& cde_elementwise,
const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_tensors,
const std::array<std::size_t, NumDTensor>& d_offsets)
{
ApplyCDEElementWise<EDataType, AccDataType, CDEElementWise>::apply(
result, sum, cde_elementwise, ds_tensors[Is].mData[d_offsets[Is]]...);
}
};
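// Illustrative expansion (for clarity only): with NumDTensor == 2,
//   ExtractDValues<DDataType, 2>::apply_at_offsets(result, sum, op, ds_tensors, d_offsets)
// forwards ds_tensors[0].mData[d_offsets[0]] and ds_tensors[1].mData[d_offsets[1]] to
// ApplyCDEElementWise<...>::apply, which calls op(result, float(sum), float(d0), float(d1));
// with zero D tensors it simply writes result = static_cast<EDataType>(sum).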
template <typename ADataType,
typename BDataType,
typename DDataType,
typename EDataType,
typename AccDataType,
typename CDEElementWise,
ck_tile::index_t NumDTensor>
void compute_reference_batched_contraction(
const ck_tile::HostTensor<ADataType>& a_full_dims,
const ck_tile::HostTensor<BDataType>& b_full_dims,
const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_full_dims_host,
ck_tile::HostTensor<EDataType>& e_full_dims_host_ref,
ck_tile::index_t G_total,
ck_tile::index_t M_total,
ck_tile::index_t N_total,
ck_tile::index_t K_total,
const CDEElementWise& cde_elementwise,
const std::vector<ck_tile::index_t>& G_dims,
const std::vector<ck_tile::index_t>& M_dims,
const std::vector<ck_tile::index_t>& N_dims,
const std::vector<ck_tile::index_t>& K_dims)
{
std::cout << "Calculating reference using stride-aware indexing with parallel processing..."
<< std::endl;
// Extract stride information from tensor descriptors
const auto a_strides = a_full_dims.get_strides();
const auto b_strides = b_full_dims.get_strides();
const auto e_strides = e_full_dims_host_ref.get_strides();
// Extract D tensor strides
std::array<std::vector<std::size_t>, NumDTensor> ds_strides;
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
{
ds_strides[d] = ds_full_dims_host[d].get_strides();
}
const ck_tile::index_t num_g_dims = G_dims.size();
const ck_tile::index_t num_m_dims = M_dims.size();
const ck_tile::index_t num_n_dims = N_dims.size();
const ck_tile::index_t num_k_dims = K_dims.size();
// Helper lambda to compute linear index from flat indices using strides
auto compute_a_offset = [&](ck_tile::index_t g_flat,
ck_tile::index_t m_flat,
ck_tile::index_t k_flat) -> std::size_t {
std::size_t offset = 0;
// Decode G dimensions
ck_tile::index_t temp = g_flat;
for(int i = num_g_dims - 1; i >= 0; --i)
{
offset += (temp % G_dims[i]) * a_strides[i];
temp /= G_dims[i];
}
// Decode M dimensions
temp = m_flat;
for(int i = num_m_dims - 1; i >= 0; --i)
{
offset += (temp % M_dims[i]) * a_strides[num_g_dims + i];
temp /= M_dims[i];
}
// Decode K dimensions
temp = k_flat;
for(int i = num_k_dims - 1; i >= 0; --i)
{
offset += (temp % K_dims[i]) * a_strides[num_g_dims + num_m_dims + i];
temp /= K_dims[i];
}
return offset;
};
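// Worked example (illustrative numbers): with G_dims = {2, 3}, g_flat = 4 decodes
// innermost-first as 4 % 3 = 1 (stride a_strides[1]), temp = 4 / 3 = 1, then 1 % 2 = 1
// (stride a_strides[0]), i.e. multi-index (1, 1); the M and K groups are decoded the same
// way against their own slices of the stride array.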
auto compute_b_offset = [&](ck_tile::index_t g_flat,
ck_tile::index_t n_flat,
ck_tile::index_t k_flat) -> std::size_t {
std::size_t offset = 0;
// Decode G dimensions
ck_tile::index_t temp = g_flat;
for(int i = num_g_dims - 1; i >= 0; --i)
{
offset += (temp % G_dims[i]) * b_strides[i];
temp /= G_dims[i];
}
// Decode N dimensions
temp = n_flat;
for(int i = num_n_dims - 1; i >= 0; --i)
{
offset += (temp % N_dims[i]) * b_strides[num_g_dims + i];
temp /= N_dims[i];
}
// Decode K dimensions
temp = k_flat;
for(int i = num_k_dims - 1; i >= 0; --i)
{
offset += (temp % K_dims[i]) * b_strides[num_g_dims + num_n_dims + i];
temp /= K_dims[i];
}
return offset;
};
auto compute_e_offset = [&](ck_tile::index_t g_flat,
ck_tile::index_t m_flat,
ck_tile::index_t n_flat) -> std::size_t {
std::size_t offset = 0;
// Decode G dimensions
ck_tile::index_t temp = g_flat;
for(int i = num_g_dims - 1; i >= 0; --i)
{
offset += (temp % G_dims[i]) * e_strides[i];
temp /= G_dims[i];
}
// Decode M dimensions
temp = m_flat;
for(int i = num_m_dims - 1; i >= 0; --i)
{
offset += (temp % M_dims[i]) * e_strides[num_g_dims + i];
temp /= M_dims[i];
}
// Decode N dimensions
temp = n_flat;
for(int i = num_n_dims - 1; i >= 0; --i)
{
offset += (temp % N_dims[i]) * e_strides[num_g_dims + num_m_dims + i];
temp /= N_dims[i];
}
return offset;
};
// Helper to compute D tensor offset (D tensors have same shape as E: [G, M, N])
auto compute_d_offset = [&](ck_tile::index_t g_flat,
ck_tile::index_t m_flat,
ck_tile::index_t n_flat,
ck_tile::index_t d_idx) -> std::size_t {
std::size_t offset = 0;
const auto& d_strides = ds_strides[d_idx];
// Decode G dimensions
ck_tile::index_t temp = g_flat;
for(int i = num_g_dims - 1; i >= 0; --i)
{
offset += (temp % G_dims[i]) * d_strides[i];
temp /= G_dims[i];
}
// Decode M dimensions
temp = m_flat;
for(int i = num_m_dims - 1; i >= 0; --i)
{
offset += (temp % M_dims[i]) * d_strides[num_g_dims + i];
temp /= M_dims[i];
}
// Decode N dimensions
temp = n_flat;
for(int i = num_n_dims - 1; i >= 0; --i)
{
offset += (temp % N_dims[i]) * d_strides[num_g_dims + num_m_dims + i];
temp /= N_dims[i];
}
return offset;
};
// Parallel computation over G and M dimensions
auto f_gm = [&](auto g_flat, auto m_flat) {
for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
{
AccDataType sum = 0;
// Compute dot product over K dimension using stride-aware indexing
for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat)
{
const std::size_t a_offset = compute_a_offset(g_flat, m_flat, k_flat);
const std::size_t b_offset = compute_b_offset(g_flat, n_flat, k_flat);
auto a_val = a_full_dims.mData[a_offset];
auto b_val = b_full_dims.mData[b_offset];
sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
}
// Compute output offset using strides
const std::size_t e_offset = compute_e_offset(g_flat, m_flat, n_flat);
// Compute individual D tensor offsets using their respective strides
std::array<std::size_t, NumDTensor> d_offsets;
for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
{
d_offsets[d] = compute_d_offset(g_flat, m_flat, n_flat, d);
}
// Apply elementwise operation with D tensors using compile-time dispatch
EDataType result = static_cast<EDataType>(sum);
ExtractDValues<DDataType, NumDTensor>::apply_at_offsets(
result, sum, cde_elementwise, ds_full_dims_host, d_offsets);
// Store result using stride-aware indexing
e_full_dims_host_ref.mData[e_offset] = static_cast<EDataType>(result);
}
};
// Execute parallel computation using hardware concurrency
// Parallelize over G_total and M_total dimensions for optimal CPU utilization
make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename DataType, typename RandValOutputDataType>
CK_TILE_HOST void reference_batched_dropout(HostTensor<DataType>& in_out_b_m_n,
const HostTensor<RandValOutputDataType>& randval_b_m_n,
const uint8_t& p_undrop_in_uint8_t,
const float scale)
{
const int N = in_out_b_m_n.mDesc.get_lengths()[2];
auto f = [&](auto batch, auto m) {
for(int n = 0; n < N; ++n)
{
float tmp = ck_tile::type_convert<float>(in_out_b_m_n(batch, m, n)) * scale;
in_out_b_m_n(batch, m, n) = randval_b_m_n(batch, m, n) <= p_undrop_in_uint8_t
? ck_tile::type_convert<DataType>(tmp)
: DataType(0);
}
};
make_ParallelTensorFunctor(
f, randval_b_m_n.mDesc.get_lengths()[0], randval_b_m_n.mDesc.get_lengths()[1])(
std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,74 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename RandValOutputDataType>
CK_TILE_HOST void
reference_batched_dropout_randval(HostTensor<RandValOutputDataType>& randval_b_m_n,
index_t batch,
uint64_t drop_seed,
uint64_t drop_offset)
{
const index_t nhead = randval_b_m_n.mDesc.get_lengths()[0];
const index_t real_seqlen_q = randval_b_m_n.mDesc.get_lengths()[1];
const index_t real_seqlen_k = randval_b_m_n.mDesc.get_lengths()[2];
static_assert(std::is_same_v<RandValOutputDataType, uint8_t>);
// BlockDropout generates random numbers in 32x32 tiles. Even when a 16x16 warp gemm is used,
// the order of values within the larger 32x32 tile must stay the same, because fwd and bwd
// may use different warp gemms (16x16 or 32x32).
// To compute 32x32 tiles, WarpGemmMfmaF16F16F32M32N32K16SwizzleA is used. It is
// WarpGemmAttributeMfmaImplF16F16F32M32N32K8 with SFactor = 2 (swizzling factor).
// Matrix element to register mapping for WarpGemmAttributeMfmaImplF16F16F32M32N32K8:
// C i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4)
// C j: (lane % 32)
// With SFactor = 2 it becomes:
// C i: (16 * floor(GPR_num / 8) % 32) + 8 * floor(lane / 32) + (GPR_num % 8)
// C j: (lane % 32)
// See ck_tile/ops/fmha/block/block_dropout.hpp for more details.
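// Worked example (illustrative): with the SFactor = 2 mapping above, lane = 40 and
// GPR_num = 9 give i = (16 * (9 / 8) % 32) + 8 * (40 / 32) + (9 % 8) = 16 + 8 + 1 = 25 and
// j = 40 % 32 = 8, i.e. element (25, 8) of the 32x32 tile.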
// The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values
constexpr index_t philox_per_tile = 64;
constexpr index_t warp_gemm_mn = 32;
const index_t rows = integer_divide_ceil(real_seqlen_q, warp_gemm_mn);
const index_t cols = integer_divide_ceil(real_seqlen_k, warp_gemm_mn);
auto f = [&](index_t i_h, index_t row, index_t col) {
uint2 rowcol = make_uint2(row, col);
for(index_t lane = 0; lane < philox_per_tile; lane++)
{
const uint64_t ph_head_offset = drop_offset + (batch * nhead + i_h) * philox_per_tile;
const index_t ph_offset = lane;
philox ph(drop_seed, ph_head_offset + ph_offset);
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
for(auto r = 0; r < 16; r++)
{
index_t i = (16 * (r / 8) % 32) + 8 * (lane / 32) + (r % 8);
index_t j = (lane % 32);
index_t m = row * warp_gemm_mn + i;
index_t n = col * warp_gemm_mn + j;
if(m < real_seqlen_q && n < real_seqlen_k)
{
randval_b_m_n(i_h, m, n) = random_uint8_t[r];
}
}
}
};
make_ParallelTensorFunctor(f, nhead, rows, cols)(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,64 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename BinaryElementOp = ck_tile::plus<AccDataType>>
CK_TILE_HOST void reference_batched_elementwise(const HostTensor<ADataType>& a_b_m_n,
const HostTensor<BDataType>& b_b_m_n,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const BinaryElementOp& binary_element_op = {})
{
const ck_tile::index_t N = c_b_m_n.mDesc.get_lengths()[2];
const bool broadcast_a_dim_b = (a_b_m_n.get_lengths()[0] == 1);
const bool broadcast_a_dim_m = (a_b_m_n.get_lengths()[1] == 1);
const bool broadcast_a_dim_n = (a_b_m_n.get_lengths()[2] == 1);
const bool broadcast_b_dim_b = (b_b_m_n.get_lengths()[0] == 1);
const bool broadcast_b_dim_m = (b_b_m_n.get_lengths()[1] == 1);
const bool broadcast_b_dim_n = (b_b_m_n.get_lengths()[2] == 1);
auto f = [&](auto batch, auto m) {
for(ck_tile::index_t n = 0; n < N; ++n)
{
AccDataType v_a{};
{
ck_tile::index_t i_b = (broadcast_a_dim_b ? 0 : batch);
ck_tile::index_t i_m = (broadcast_a_dim_m ? 0 : m);
ck_tile::index_t i_n = (broadcast_a_dim_n ? 0 : n);
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_b_m_n(i_b, i_m, i_n)));
}
AccDataType v_b{};
{
ck_tile::index_t i_b = (broadcast_b_dim_b ? 0 : batch);
ck_tile::index_t i_m = (broadcast_b_dim_m ? 0 : m);
ck_tile::index_t i_n = (broadcast_b_dim_n ? 0 : n);
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_b_m_n(i_b, i_m, i_n)));
}
c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(binary_element_op(v_a, v_b));
}
};
make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
std::thread::hardware_concurrency());
}
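// Usage sketch (illustrative; data types and shapes are assumptions): adding a [1, M, N]
// bias to a [B, M, N] tensor broadcasts over the batch dimension:
//   HostTensor<float> a({B, M, N}), bias({1, M, N}), c({B, M, N});
//   reference_batched_elementwise<float, float, float, float>(a, bias, c);
// Any input dimension of length 1 is broadcast against the corresponding output dimension.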
} // namespace ck_tile

View File

@@ -0,0 +1,90 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_batched_gemm(const HostTensor<ADataType>& a_b_m_k,
const HostTensor<BDataType>& b_b_n_k,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const int N = b_b_n_k.mDesc.get_lengths()[1];
const int K = b_b_n_k.mDesc.get_lengths()[2];
auto f = [&](auto batch, auto m) {
for(int n = 0; n < N; ++n)
{
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
ADataType v_a = a_element_op(a_b_m_k(batch, m, k));
BDataType v_b = b_element_op(b_b_n_k(batch, n, k));
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b);
}
c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
}
};
make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
std::thread::hardware_concurrency());
}
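// Usage sketch (illustrative; data types are assumptions). The indexing above implies
// A is [B, M, K], B is [B, N, K] and C is [B, M, N]:
//   HostTensor<bf16_t> a({B, M, K}), b({B, N, K});
//   HostTensor<float> c({B, M, N});
//   reference_batched_gemm<bf16_t, bf16_t, float, float>(a, b, c);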
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename AElementOp = ck_tile::idx_identity,
typename BElementOp = ck_tile::idx_identity,
typename ACCElementOp = ck_tile::idx_identity>
CK_TILE_HOST void reference_batched_quant_gemm(const HostTensor<ADataType>& a_b_m_k,
const HostTensor<BDataType>& b_b_n_k,
HostTensor<CDataType>& c_b_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const int N = b_b_n_k.mDesc.get_lengths()[1];
const int K = b_b_n_k.mDesc.get_lengths()[2];
auto f = [&](auto batch, auto m) {
for(int n = 0; n < N; ++n)
{
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
{
AccDataType v_a = ck_tile::type_convert<AccDataType>(
a_element_op(std::make_tuple(batch, m, k), a_b_m_k(batch, m, k)));
AccDataType v_b = ck_tile::type_convert<AccDataType>(
b_element_op(std::make_tuple(batch, n, k), b_b_n_k(batch, n, k)));
v_acc += v_a * v_b;
}
c_b_m_n(batch, m, n) = ck_tile::type_convert<CDataType>(
acc_element_op(std::make_tuple(batch, m, n), v_acc));
}
};
make_ParallelTensorFunctor(f, c_b_m_n.mDesc.get_lengths()[0], c_b_m_n.mDesc.get_lengths()[1])(
std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,32 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename CDataType, typename MaskingType>
CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, const MaskingType& mask)
{
const int M = c_b_m_n.mDesc.get_lengths()[1];
const int N = c_b_m_n.mDesc.get_lengths()[2];
auto f = [&](auto batch) {
for(int n = 0; n < N; ++n)
{
for(int m = 0; m < M; ++m)
{
if(mask.IsOutOfSinkBound(m, n))
c_b_m_n(batch, m, n) = -ck_tile::numeric<CDataType>::infinity();
}
}
};
make_ParallelTensorFunctor(f,
c_b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename InDataType,
typename ScaleDataType,
typename OutDataType,
typename ComputeDataType>
CK_TILE_HOST HostTensor<OutDataType>
reference_batched_mx_descale(const HostTensor<InDataType>& a_b_m_k,
const HostTensor<ScaleDataType>& scales_b_m_ks,
const std::size_t scale_granularity)
{
const std::size_t B = a_b_m_k.get_length(0);
const std::size_t M = a_b_m_k.get_length(1);
const std::size_t K = a_b_m_k.get_length(2);
HostTensor<ComputeDataType> a_b_m_k_scaled(a_b_m_k.get_lengths());
auto f = [&](auto batch) {
constexpr index_t packed_size = ck_tile::numeric_traits<InDataType>::PackedSize;
for(std::size_t m = 0; m < M; ++m)
{
for(std::size_t k = 0; k < K; k += packed_size)
{
const auto scale = ck_tile::type_convert<ComputeDataType>(
scales_b_m_ks(batch, m, k / scale_granularity));
if constexpr(std::is_same_v<InDataType, pk_fp4_t>)
{
auto a_f4x2 = a_b_m_k(batch, m, k);
auto a_f4_lo = ck_tile::type_convert<ComputeDataType>(
a_f4x2.template unpack<>(number<0>{}));
auto a_f4_hi = ck_tile::type_convert<ComputeDataType>(
a_f4x2.template unpack<>(number<1>{}));
a_b_m_k_scaled(batch, m, k) = a_f4_lo * scale;
a_b_m_k_scaled(batch, m, k + 1) = a_f4_hi * scale;
}
else
{
a_b_m_k_scaled(batch, m, k) =
ck_tile::type_convert<ComputeDataType>(a_b_m_k(batch, m, k)) * scale;
}
}
}
};
make_ParallelTensorFunctor(f, B)(std::thread::hardware_concurrency());
return a_b_m_k_scaled;
}
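// Note (inferred from the indexing above): scale_granularity is the number of consecutive
// K elements that share one scale, so scales_b_m_ks is expected to provide
// K / scale_granularity scales per (batch, m) row, e.g. 4 scales for K = 128 with
// scale_granularity = 32.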
} // namespace ck_tile

View File

@@ -0,0 +1,73 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cassert>
#include <thread>
namespace ck_tile {
template <typename DataType, typename ComputeDataType = float>
CK_TILE_HOST void reference_batched_rotary_position_embedding(const HostTensor<DataType>& input_bsd,
const HostTensor<DataType>& cos_sd,
const HostTensor<DataType>& sin_sd,
bool interleaved,
HostTensor<DataType>& output_bsd,
bool use_1_row_sin_cos = false)
{
assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2);
assert(cos_sd.get_length(0) == sin_sd.get_length(0) &&
cos_sd.get_length(1) == sin_sd.get_length(1));
const index_t rotary_dim = cos_sd.get_length(1) * 2;
assert(static_cast<std::size_t>(rotary_dim) <= input_bsd.get_length(2));
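// Two standard RoPE layouts are handled below:
// - interleaved: dimensions are rotated in adjacent pairs (2i, 2i+1), so the rotated
//   partner of i_d is i_d +/- 1 with alternating sign;
// - non-interleaved (half-rotated): the first and second halves of rotary_dim are paired,
//   so the partner of i_d is (i_d + rotary_dim / 2) % rotary_dim.
// Dimensions at or beyond rotary_dim are copied through unchanged.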
output_bsd.ForEach([&](auto& self, auto i) {
const index_t i_d = i[2];
if(rotary_dim <= i_d)
{
self(i) = input_bsd(i);
return;
}
assert(i_d < rotary_dim);
const index_t i_s = i[1];
const index_t i_s_cos_sin = (use_1_row_sin_cos ? 0 : i_s);
const ComputeDataType cos = type_convert<ComputeDataType>(
interleaved ? cos_sd(i_s_cos_sin, i_d / 2)
: cos_sd(i_s_cos_sin, i_d % cos_sd.get_length(1)));
const ComputeDataType sin = type_convert<ComputeDataType>(
interleaved ? sin_sd(i_s_cos_sin, i_d / 2)
: sin_sd(i_s_cos_sin, i_d % sin_sd.get_length(1)));
const ComputeDataType half_rotated_input = [&] {
const index_t i_b = i[0];
if(interleaved)
{
const bool is_even = (i_d % 2 == 0);
const index_t pos = i_d + (is_even ? 1 : -1);
const ComputeDataType sign = (is_even ? -1 : 1);
return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
}
else
{
const index_t half_rdim = (rotary_dim / 2);
const index_t pos = (i_d + half_rdim) % rotary_dim;
const ComputeDataType sign = (pos < half_rdim ? 1 : -1);
return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
}
}();
ComputeDataType result =
type_convert<ComputeDataType>(input_bsd(i)) * cos + half_rotated_input * sin;
self(i) = type_convert<DataType>(result);
});
}
} // namespace ck_tile

View File

@@ -0,0 +1,71 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename ADataType,
typename CompDataType,
typename BDataType,
typename CompElementOp = ck_tile::identity>
CK_TILE_HOST void reference_batched_softmax(
const HostTensor<ADataType>& a_b_m_n,
HostTensor<BDataType>& b_b_m_n,
const CompElementOp& comp_element_op = {},
std::optional<std::reference_wrapper<HostTensor<CompDataType>>> lse_b_m = std::nullopt)
{
const int N = a_b_m_n.mDesc.get_lengths()[2];
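// Per row: b(n) = exp(a(n) - max) / sum_k exp(a(k) - max), with the row max subtracted for
// numerical stability; the optional output is lse = max + log(sum_k exp(a(k) - max)).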
auto f = [&](auto batch, auto m) {
CompDataType v_max = -ck_tile::numeric<CompDataType>::infinity();
// max
for(int n = 0; n < N; ++n)
{
const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));
v_max = v_max < v_a ? v_a : v_max;
}
CompDataType v_exp_sum = 0;
// guard against a fully masked row: if all elements are -INF, reset v_max to 0
if(std::isinf(v_max) && v_max < 0)
{
v_max = ck_tile::type_convert<CompDataType>(0.f);
}
// sum
for(int n = 0; n < N; ++n)
{
const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));
v_exp_sum += ck_tile::exp(v_a - v_max);
}
// if the sum is zero (fully masked row), or nan/inf (other computation error), skip the division
CompDataType inv_sum = (v_exp_sum == 0.f ? 1.f : 1.f / v_exp_sum);
// elementwise
for(int n = 0; n < N; ++n)
{
const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));
const CompDataType v_b = ck_tile::exp(v_a - v_max) * inv_sum;
b_b_m_n(batch, m, n) = ck_tile::type_convert<BDataType>(comp_element_op(v_b));
}
// lse
if(lse_b_m)
{
lse_b_m->get()(batch, m) = v_max + ck_tile::log(v_exp_sum);
}
};
make_ParallelTensorFunctor(f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])(
std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,59 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename Type>
CK_TILE_HOST void reference_batched_transpose(const HostTensor<Type>& x,
HostTensor<Type>& y,
std::string layout_in = "NCHW",
std::string layout_out = "NHWC")
{
const int N = x.mDesc.get_lengths()[0];
auto f = [&](auto batch) {
if(layout_in == "NCHW" && layout_out == "NHWC")
{
const int C = x.mDesc.get_lengths()[1];
const int H = x.mDesc.get_lengths()[2];
const int W = x.mDesc.get_lengths()[3];
for(int c = 0; c < C; ++c)
{
for(int h = 0; h < H; ++h)
{
for(int w = 0; w < W; ++w)
{
Type v_x = x(batch, c, h, w);
y(batch, h, w, c) = v_x;
}
}
}
}
else if(layout_in == "NHWC" && layout_out == "NCHW")
{
const int H = x.mDesc.get_lengths()[1];
const int W = x.mDesc.get_lengths()[2];
const int C = x.mDesc.get_lengths()[3];
for(int h = 0; h < H; ++h)
{
for(int w = 0; w < W; ++w)
{
for(int c = 0; c < C; ++c)
{
Type v_x = x(batch, h, w, c);
y(batch, c, h, w) = v_x;
}
}
}
}
};
make_ParallelTensorFunctor(f, N)(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,156 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
#include "ck_tile/core.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <typename AccT, typename T>
CK_TILE_HOST_DEVICE constexpr AccT to_acc(T value)
{
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
#if CK_TILE_USE_CUSTOM_DATA_TYPE
return static_cast<AccT>(value);
#else
return static_cast<AccT>(
ck_tile::bf16_to_float_raw(ck_tile::bit_cast<ck_tile::bf16_raw_t>(value)));
#endif
}
else
{
return static_cast<AccT>(value);
}
}
// Reference implementation: blocked attention (for sparse attention tests).
template <typename T, typename MaskT, typename AccT = float>
void reference_blocked_attention(
const HostTensor<T>& q, // [B, H, S_q, D]
const HostTensor<T>& k, // [B, H, S_k, D]
const HostTensor<T>& v, // [B, H, S_k, D_v]
const HostTensor<MaskT>& block_relation, // [B, H, Q_blocks, K_blocks]
HostTensor<T>& output, // [B, H, S_q, D_v]
index_t BLKQ,
index_t BLKK,
AccT scale)
{
auto q_lengths = q.get_lengths();
index_t batch = q_lengths[0];
index_t nhead = q_lengths[1];
index_t seqlen_q = q_lengths[2];
index_t hdim = q_lengths[3];
auto v_lengths = v.get_lengths();
index_t seqlen_k = v_lengths[2];
index_t hdim_v = v_lengths[3];
index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ;
index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK;
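// For each (batch, head, query block): gather the K blocks marked active in block_relation,
// compute scaled Q*K^T scores over those keys only, apply a numerically stable softmax
// across the gathered scores, then accumulate the probability-weighted sum of V into output.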
for(index_t b = 0; b < batch; ++b)
{
for(index_t h = 0; h < nhead; ++h)
{
for(index_t qb = 0; qb < num_q_blocks; ++qb)
{
index_t q_start = qb * BLKQ;
if(q_start >= seqlen_q)
{
continue;
}
index_t q_end = std::min<index_t>(q_start + BLKQ, seqlen_q);
std::vector<index_t> relevant_k_indices;
for(index_t kb = 0; kb < num_k_blocks; ++kb)
{
// Treat block_relation as boolean; >0.5 marks an active block.
if(static_cast<float>(block_relation(b, h, qb, kb)) > 0.5f)
{
relevant_k_indices.push_back(kb);
}
}
if(relevant_k_indices.empty())
{
continue;
}
for(index_t sq = q_start; sq < q_end; ++sq)
{
std::vector<AccT> scores;
AccT max_score = -std::numeric_limits<AccT>::infinity();
for(auto kb : relevant_k_indices)
{
index_t k_start = kb * BLKK;
if(k_start >= seqlen_k)
{
continue;
}
index_t k_end = std::min<index_t>(k_start + BLKK, seqlen_k);
for(index_t sk = k_start; sk < k_end; ++sk)
{
AccT score = 0.0f;
for(index_t d = 0; d < hdim; ++d)
{
score +=
to_acc<AccT>(q(b, h, sq, d)) * to_acc<AccT>(k(b, h, sk, d));
}
score = score * scale;
scores.push_back(score);
max_score = std::max(max_score, score);
}
}
AccT sum_exp = 0.0f;
for(auto& s : scores)
{
s = std::exp(s - max_score);
sum_exp += s;
}
for(auto& s : scores)
{
s /= sum_exp;
}
for(index_t dv = 0; dv < hdim_v; ++dv)
{
AccT out_val = 0.0f;
size_t score_idx = 0;
for(auto kb : relevant_k_indices)
{
index_t k_start = kb * BLKK;
if(k_start >= seqlen_k)
{
continue;
}
index_t k_end = std::min<index_t>(k_start + BLKK, seqlen_k);
for(index_t sk = k_start; sk < k_end; ++sk)
{
out_val += scores[score_idx] * to_acc<AccT>(v(b, h, sk, dv));
score_idx++;
}
}
output(b, h, sq, dv) = static_cast<T>(out_val);
}
}
}
}
}
}
} // namespace ck_tile

View File

@@ -0,0 +1,47 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename ADataType, typename BDataType, typename ComputeDataType, typename ElementOp>
CK_TILE_HOST void reference_unary_elementwise(const HostTensor<ADataType>& a,
HostTensor<BDataType>& b,
ElementOp element_op)
{
// TODO: implement gpu version reference function
auto f = [&](auto i) {
auto v_a = type_convert<ComputeDataType>(a.mData[i]);
auto v_b = element_op(v_a);
b.mData[i] = ck_tile::type_convert<BDataType>(v_b);
};
make_ParallelTensorFunctor(f, b.get_element_space_size())(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ComputeDataType,
typename ElementOp>
CK_TILE_HOST void reference_binary_elementwise(const HostTensor<ADataType>& a,
const HostTensor<BDataType>& b,
HostTensor<CDataType>& c,
ElementOp element_op)
{
// TODO: implement gpu version reference function
auto f = [&](auto i) {
auto v_a = type_convert<ComputeDataType>(a.mData[i]);
auto v_b = type_convert<ComputeDataType>(b.mData[i]);
auto v_c = element_op(v_a, v_b);
c.mData[i] = ck_tile::type_convert<CDataType>(v_c);
};
make_ParallelTensorFunctor(f, c.get_element_space_size())(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,205 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice maps to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
// number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk
// (updated from the earlier estimate topk * input_tokens + num_experts * (M_a - 1))
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
// 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4
// -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *,
// c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
template <typename AccDataType, // you only need to explicitly set this one
typename Activation, // ck_tile::element_wise::Gelu
typename ADataType,
typename GDataType,
typename DDataType,
typename ODataType,
typename AScaleDataType,
typename GScaleDataType,
typename DScaleDataType,
typename YSmoothScaleDataType,
typename TopkWeightDataType,
typename IndexDataType>
void reference_fused_moe(
const ck_tile::HostTensor<ADataType>& a_host, // [tokens, hidden_size]
const ck_tile::HostTensor<GDataType>& g_host, // [experts, interme_size_0, hidden_size]
const ck_tile::HostTensor<DDataType>& d_host, // [experts, hidden_size, interme_size_1]
const ck_tile::HostTensor<AScaleDataType>& sa_host, // [tokens, 1],
const ck_tile::HostTensor<GScaleDataType>& sg_host, // [experts, 1, interme_size_0]
const ck_tile::HostTensor<DScaleDataType>& sd_host, // [experts, 1, hidden_size],
const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size_0]
ck_tile::HostTensor<ODataType>& o_host, // [tokens, hidden_size]
const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host, // [max_num_tokens_padded]
const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
const ck_tile::HostTensor<IndexDataType>&
sorted_expert_ids_host, // [(max_num_tokens_padded + block_size - 1) / block_size]
const ck_tile::HostTensor<IndexDataType>& num_sorted_tiles_host, // [1]
const ck_tile::HostTensor<IndexDataType>&
token_ids_host, // [tokens, topk] --> ugly!!! remove in the future
ck_tile::index_t block_m,
ck_tile::index_t tokens,
ck_tile::index_t experts,
ck_tile::index_t hidden_size,
ck_tile::index_t intermediate_size, // this size is for gate/up/down
ck_tile::index_t topk,
ck_tile::index_t gate_only)
{
assert(sorted_token_ids_host.get_num_of_dimension() == 1);
assert(sorted_weight_host.get_num_of_dimension() == 1);
assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
assert(num_sorted_tiles_host.get_element_size() == 1);
ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m;
ck_tile::index_t intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2);
ck_tile::index_t intermediate_size_1 = intermediate_size;
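// gate_only == 1: the first gemm produces intermediate_size values and y = act(acc_0).
// gate_only == 0: the first gemm produces 2 * intermediate_size values laid out as
// [gate | up], and y = act(gate) * up feeds the second (down) gemm.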
ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
// assert();
auto f = [&](auto i_flatten) {
ck_tile::index_t i_tile = i_flatten / block_m;
if(i_tile >= num_sorted_tiles)
return;
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
ck_tile::index_t i_topk = i_token >> 24;
i_token &= 0xffffff;
if(i_token >= tokens)
return;
(void)token_ids_host;
#else
// TODO: better remove this in the future, or modify the token_id value
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
{
if(token_ids_host(token_id_, i_) == expert_id_)
return i_;
}
throw std::runtime_error("not correct token/expert pair\n");
return -1; // TODO: not correct!!
};
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
if(i_token >= tokens)
return;
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
#endif
auto weight = sorted_weight_host.mData[i_flatten];
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
// first gemm
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
{
AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
{
acc += type_convert<AccDataType>(a_host(i_token, i_k)) *
type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
}
acc_0(0, i_n) = acc;
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
}
ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
if(gate_only)
{
if(intermediate_size_1 != intermediate_size_0)
throw std::runtime_error(
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
", 1:" + std::to_string(intermediate_size_1));
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{
Activation{}(y(0, i_n), acc_0(0, i_n));
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
}
}
else
{
if(intermediate_size_1 * 2 != intermediate_size_0)
throw std::runtime_error(
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
", 1:" + std::to_string(intermediate_size_1));
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{
AccDataType tmp;
Activation{}(tmp, acc_0(0, i_n));
y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
}
}
// second gemm, loop along gemm-n
ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
{
acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
}
acc_1(0, i_n) = acc * weight; // multiply by the routing weight here
}
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
out_topk_tokens(i_token, i_topk, i_n) = acc_1(0, i_n);
}
};
// make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f, max_num_tokens_padded)(1);
// reduce
auto r = [&](auto i_token) {
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
AccDataType acc = type_convert<AccDataType>(0);
for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
{
acc += out_topk_tokens(i_token, i_topk, i_n);
}
o_host(i_token, i_n) = type_convert<ODataType>(acc);
}
};
make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());
(void)num_sorted_tiles_host;
(void)sa_host;
(void)sg_host;
(void)sd_host;
(void)sy_host;
}
} // namespace ck_tile

File diff suppressed because it is too large

View File

@@ -0,0 +1,228 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cinttypes>
#include <cstdlib>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor<InDataType>& input,
const HostTensor<WeiDataType>& weight,
const HostTensor<OutDataType>& output,
std::vector<ck_tile::long_index_t> conv_strides,
std::vector<ck_tile::long_index_t> conv_dilations,
std::vector<ck_tile::long_index_t> in_left_pads,
std::vector<ck_tile::long_index_t>)
{
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
weight.get_num_of_dimension() == NDimSpatial + 3 &&
output.get_num_of_dimension() == NDimSpatial + 3))
{
printf("%" PRIu64 " %" PRIu64 " %" PRIu64,
input.get_num_of_dimension(),
weight.get_num_of_dimension(),
output.get_num_of_dimension());
throw std::runtime_error("wrong! inconsistent dimension");
}
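// Backward-data mapping used below: for each input position (per spatial dim, shown for wi),
// iterate filter taps x and recover the output position from
//   wo = (wi + pad - x * dilation) / stride,
// accumulating out(g, n, k, wo) * wei(g, k, c, x) over k only when the division is exact and
// wo lies inside the output extent.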
if constexpr(NDimSpatial == 1)
{
auto func = [&](auto g, auto n, auto c, auto wi) {
std::size_t K = weight.get_lengths()[1];
std::size_t X = weight.get_lengths()[3];
std::size_t Wo = output.get_lengths()[3];
float v_acc = 0;
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp = static_cast<ck_tile::long_index_t>(wi) +
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]);
if(w_tmp % conv_strides[0] == 0)
{
auto wo = static_cast<ck_tile::long_index_t>(w_tmp) /
static_cast<ck_tile::long_index_t>(conv_strides[0]);
if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
{
for(std::size_t k = 0; k < K; ++k)
{
OutDataType v_out = output(g, n, k, wo);
WeiDataType v_wei = weight(g, k, c, x);
v_acc += ck_tile::type_convert<float>(v_out) *
ck_tile::type_convert<float>(v_wei);
}
}
}
}
InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
input(g, n, c, wi) = v_acc_converted;
};
make_ParallelTensorFunctor(func,
input.get_lengths()[0],
input.get_lengths()[1],
input.get_lengths()[2],
input.get_lengths()[3])(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 2)
{
auto func = [&](auto g, auto n, auto c, auto hi, auto wi) {
std::size_t K = weight.get_lengths()[1];
std::size_t Y = weight.get_lengths()[3];
std::size_t X = weight.get_lengths()[4];
std::size_t Ho = output.get_lengths()[3];
std::size_t Wo = output.get_lengths()[4];
float v_acc = 0;
for(std::size_t y = 0; y < Y; ++y)
{
auto h_tmp = static_cast<ck_tile::long_index_t>(hi) +
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]);
if(h_tmp % conv_strides[0] == 0)
{
auto ho = static_cast<ck_tile::long_index_t>(h_tmp) /
static_cast<ck_tile::long_index_t>(conv_strides[0]);
if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
{
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp = static_cast<ck_tile::long_index_t>(wi) +
static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]);
if(w_tmp % conv_strides[1] == 0)
{
auto wo = static_cast<ck_tile::long_index_t>(w_tmp) /
static_cast<ck_tile::long_index_t>(conv_strides[1]);
if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
{
for(std::size_t k = 0; k < K; ++k)
{
OutDataType v_out = output(g, n, k, ho, wo);
WeiDataType v_wei = weight(g, k, c, y, x);
v_acc += ck_tile::type_convert<float>(v_out) *
ck_tile::type_convert<float>(v_wei);
}
}
}
}
}
}
}
InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
input(g, n, c, hi, wi) = v_acc_converted;
};
make_ParallelTensorFunctor(func,
input.get_lengths()[0],
input.get_lengths()[1],
input.get_lengths()[2],
input.get_lengths()[3],
input.get_lengths()[4])(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 3)
{
auto func = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
std::size_t K = weight.get_lengths()[1];
std::size_t Z = weight.get_lengths()[3];
std::size_t Y = weight.get_lengths()[4];
std::size_t X = weight.get_lengths()[5];
std::size_t Do = output.get_lengths()[3];
std::size_t Ho = output.get_lengths()[4];
std::size_t Wo = output.get_lengths()[5];
float v_acc = 0;
for(std::size_t z = 0; z < Z; ++z)
{
auto d_tmp = static_cast<ck_tile::long_index_t>(di) +
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]);
if(d_tmp % conv_strides[0] == 0)
{
auto do_ = static_cast<ck_tile::long_index_t>(d_tmp) /
static_cast<ck_tile::long_index_t>(conv_strides[0]);
if(do_ >= 0 && ck_tile::type_convert<std::size_t>(do_) < Do)
{
for(std::size_t y = 0; y < Y; ++y)
{
auto h_tmp = static_cast<ck_tile::long_index_t>(hi) +
static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]);
if(h_tmp % conv_strides[1] == 0)
{
auto ho = static_cast<ck_tile::long_index_t>(h_tmp) /
static_cast<ck_tile::long_index_t>(conv_strides[1]);
if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
{
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp =
static_cast<ck_tile::long_index_t>(wi) +
static_cast<ck_tile::long_index_t>(in_left_pads[2]) -
static_cast<ck_tile::long_index_t>(x *
conv_dilations[2]);
if(w_tmp % conv_strides[2] == 0)
{
auto wo =
static_cast<ck_tile::long_index_t>(w_tmp) /
static_cast<ck_tile::long_index_t>(conv_strides[2]);
if(wo >= 0 &&
ck_tile::type_convert<std::size_t>(wo) < Wo)
{
for(std::size_t k = 0; k < K; ++k)
{
OutDataType v_out =
output(g, n, k, do_, ho, wo);
WeiDataType v_wei = weight(g, k, c, z, y, x);
v_acc += ck_tile::type_convert<float>(v_out) *
ck_tile::type_convert<float>(v_wei);
}
}
}
}
}
}
}
}
}
}
InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
input(g, n, c, di, hi, wi) = v_acc_converted;
};
make_ParallelTensorFunctor(func,
input.get_lengths()[0],
input.get_lengths()[1],
input.get_lengths()[2],
input.get_lengths()[3],
input.get_lengths()[4],
input.get_lengths()[5])(std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error(
"Ref_conv_bwd_data: number of dimensions must be between 1 and 3.");
}
}
} // namespace ck_tile

View File

@@ -0,0 +1,167 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdlib>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
CK_TILE_HOST void
reference_grouped_conv_bwd_weight(const HostTensor<InDataType>& input,
HostTensor<WeiDataType>& weight,
const HostTensor<OutDataType>& output,
std::vector<ck_tile::long_index_t> conv_strides,
std::vector<ck_tile::long_index_t> conv_dilations,
std::vector<ck_tile::long_index_t> in_left_pads,
std::vector<ck_tile::long_index_t>)
{
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
weight.get_num_of_dimension() == NDimSpatial + 3 &&
output.get_num_of_dimension() == NDimSpatial + 3))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
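// Backward-weight mapping used below: dW(g, k, c, x) = sum over n and output positions wo of
// out(g, n, k, wo) * in(g, n, c, wi) with wi = wo * stride + x * dilation - pad, skipping
// positions where wi falls outside the input extent; the 2D/3D cases add y/z taps analogously.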
if constexpr(NDimSpatial == 1)
{
auto func = [&](auto g, auto k, auto c, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
{
for(std::size_t wo = 0; wo < output.get_lengths()[3]; ++wo)
{
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
{
InDataType v_in = input(g, n, c, wi);
OutDataType v_out = output(g, n, k, wo);
v_acc += ck_tile::type_convert<float>(v_out) *
ck_tile::type_convert<float>(v_in);
}
}
}
WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
weight(g, k, c, x) = v_acc_converted;
};
make_ParallelTensorFunctor(func,
weight.get_lengths()[0],
weight.get_lengths()[1],
weight.get_lengths()[2],
weight.get_lengths()[3])(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 2)
{
auto func = [&](auto g, auto k, auto c, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
{
for(std::size_t ho = 0; ho < output.get_lengths()[3]; ++ho)
{
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(std::size_t wo = 0; wo < output.get_lengths()[4]; ++wo)
{
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
if(hi >= 0 &&
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[3] &&
wi >= 0 &&
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[4])
{
InDataType v_in = input(g, n, c, hi, wi);
OutDataType v_out = output(g, n, k, ho, wo);
v_acc += ck_tile::type_convert<float>(v_out) *
ck_tile::type_convert<float>(v_in);
}
}
}
}
WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
weight(g, k, c, y, x) = v_acc_converted;
};
make_ParallelTensorFunctor(func,
weight.get_lengths()[0],
weight.get_lengths()[1],
weight.get_lengths()[2],
weight.get_lengths()[3],
weight.get_lengths()[4])(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 3)
{
auto func = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
float v_acc = 0;
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
{
for(std::size_t do_ = 0; do_ < output.get_lengths()[3]; ++do_)
{
auto di = static_cast<ck_tile::long_index_t>(do_ * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(std::size_t ho = 0; ho < output.get_lengths()[4]; ++ho)
{
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
for(std::size_t wo = 0; wo < output.get_lengths()[5]; ++wo)
{
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
static_cast<ck_tile::long_index_t>(in_left_pads[2]);
if(di >= 0 &&
ck_tile::type_convert<std::size_t>(di) < input.get_lengths()[3] &&
hi >= 0 &&
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[4] &&
wi >= 0 &&
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[5])
{
InDataType v_in = input(g, n, c, di, hi, wi);
OutDataType v_out = output(g, n, k, do_, ho, wo);
v_acc += ck_tile::type_convert<float>(v_out) *
ck_tile::type_convert<float>(v_in);
}
}
}
}
}
WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
weight(g, k, c, z, y, x) = v_acc_converted;
};
make_ParallelTensorFunctor(func,
weight.get_lengths()[0],
weight.get_lengths()[1],
weight.get_lengths()[2],
weight.get_lengths()[3],
weight.get_lengths()[4],
weight.get_lengths()[5])(std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error(
"Ref_conv_bwd_weight: number of dimensions must be between 1 and 3.");
}
}
} // namespace ck_tile

View File

@@ -0,0 +1,182 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdlib>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename Elfunc = ck_tile::element_wise::PassThrough,
typename Tuple = ck_tile::tuple<>>
CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input,
const HostTensor<WeiDataType>& weight,
HostTensor<OutDataType>& output,
std::vector<ck_tile::long_index_t> conv_strides,
std::vector<ck_tile::long_index_t> conv_dilations,
std::vector<ck_tile::long_index_t> in_left_pads,
std::vector<ck_tile::long_index_t>,
Elfunc elfunc = Elfunc{},
Tuple ds = {})
{
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
weight.get_num_of_dimension() == NDimSpatial + 3 &&
output.get_num_of_dimension() == NDimSpatial + 3))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
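// Forward mapping used below: out(g, n, k, wo) = sum over c and filter taps x of
// in(g, n, c, wi) * wei(g, k, c, x) with wi = wo * stride + x * dilation - pad, skipping
// out-of-bounds wi; elfunc is applied to the accumulator (optionally fused with the first
// D tensor) before the final type conversion.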
if constexpr(NDimSpatial == 1)
{
auto func = [&](auto g, auto n, auto k, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
{
for(std::size_t x = 0; x < weight.get_lengths()[3]; ++x)
{
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
{
InDataType v_in = input(g, n, c, wi);
WeiDataType v_wei = weight(g, k, c, x);
v_acc += ck_tile::type_convert<float>(v_in) *
ck_tile::type_convert<float>(v_wei);
}
}
}
if constexpr(Tuple::size() > 0)
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, wo));
else
elfunc(v_acc, v_acc);
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, wo) = v_acc_out;
};
make_ParallelTensorFunctor(func,
output.get_lengths()[0],
output.get_lengths()[1],
output.get_lengths()[2],
output.get_lengths()[3])(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 2)
{
auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
{
for(std::size_t y = 0; y < weight.get_lengths()[3]; ++y)
{
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(std::size_t x = 0; x < weight.get_lengths()[4]; ++x)
{
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
if(hi >= 0 &&
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[3] &&
wi >= 0 &&
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[4])
{
InDataType v_in = input(g, n, c, hi, wi);
WeiDataType v_wei = weight(g, k, c, y, x);
v_acc += ck_tile::type_convert<float>(v_in) *
ck_tile::type_convert<float>(v_wei);
}
}
}
}
if constexpr(Tuple::size() > 0)
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, ho, wo));
else
elfunc(v_acc, v_acc);
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, ho, wo) = v_acc_out;
};
make_ParallelTensorFunctor(func,
output.get_lengths()[0],
output.get_lengths()[1],
output.get_lengths()[2],
output.get_lengths()[3],
output.get_lengths()[4])(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 3)
{
auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
float v_acc = 0;
for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
{
for(std::size_t z = 0; z < weight.get_lengths()[3]; ++z)
{
auto di = static_cast<ck_tile::long_index_t>(d_o * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(std::size_t y = 0; y < weight.get_lengths()[4]; ++y)
{
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
for(std::size_t x = 0; x < weight.get_lengths()[5]; ++x)
{
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
static_cast<ck_tile::long_index_t>(in_left_pads[2]);
if(di >= 0 &&
ck_tile::type_convert<std::size_t>(di) < input.get_lengths()[3] &&
hi >= 0 &&
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[4] &&
wi >= 0 &&
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[5])
{
InDataType v_in = input(g, n, c, di, hi, wi);
WeiDataType v_wei = weight(g, k, c, z, y, x);
v_acc += ck_tile::type_convert<float>(v_in) *
ck_tile::type_convert<float>(v_wei);
}
}
}
}
}
if constexpr(Tuple::size() > 0)
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, d_o, ho, wo));
else
elfunc(v_acc, v_acc);
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, d_o, ho, wo) = v_acc_out;
};
make_ParallelTensorFunctor(func,
output.get_lengths()[0],
output.get_lengths()[1],
output.get_lengths()[2],
output.get_lengths()[3],
output.get_lengths()[4],
output.get_lengths()[5])(std::thread::hardware_concurrency());
}
else
{
throw std::runtime_error("Ref_Conv_fwd: number of dimensions must be between 1 and 3.");
}
}
} // namespace ck_tile

View File

@@ -0,0 +1,133 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename InDataType, typename OutDataType, index_t NDimSpatial>
CK_TILE_HOST void reference_im2col(const HostTensor<InDataType>& in_host,
HostTensor<OutDataType>& out_host,
const ck_tile::conv::ConvParam& conv_params)
{
const long_index_t G = in_host.get_lengths()[0];
const long_index_t N = in_host.get_lengths()[1];
const long_index_t C = in_host.get_lengths()[2];
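// Layout of the produced matrix (per group g): one row per output position, flattened as
// (n, spatial output index), and one column per (filter tap, input channel) pair in that
// order; taps that fall into padding are left untouched in out_host.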
if constexpr(NDimSpatial == 1)
{
const long_index_t Wo = conv_params.output_spatial_lengths_[0];
auto func = [&](auto g, auto n, auto wo) {
long_index_t row = n * Wo + wo;
long_index_t column = 0;
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[0]; ++x)
{
auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
for(long_index_t c = 0; c < C; ++c)
{
if(wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[3])
{
InDataType v_in = in_host(g, n, c, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
};
make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 2)
{
const long_index_t Ho = conv_params.output_spatial_lengths_[0];
const long_index_t Wo = conv_params.output_spatial_lengths_[1];
auto func = [&](auto g, auto n, auto ho, auto wo) {
long_index_t row = n * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[0]; ++y)
{
auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[1]; ++x)
{
auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[1]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[1]) -
static_cast<long_index_t>(conv_params.input_left_pads_[1]);
for(long_index_t c = 0; c < C; ++c)
{
if(hi >= 0 && type_convert<std::size_t>(hi) < in_host.get_lengths()[3] &&
wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[4])
{
InDataType v_in = in_host(g, n, c, hi, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
}
};
make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 3)
{
const long_index_t Do = conv_params.output_spatial_lengths_[0];
const long_index_t Ho = conv_params.output_spatial_lengths_[1];
const long_index_t Wo = conv_params.output_spatial_lengths_[2];
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(long_index_t z = 0; z < conv_params.filter_spatial_lengths_[0]; ++z)
{
auto di = static_cast<long_index_t>(d_o * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(z * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[1]; ++y)
{
auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[1]) +
static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[1]) -
static_cast<long_index_t>(conv_params.input_left_pads_[1]);
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[2]; ++x)
{
auto wi =
static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[2]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[2]) -
static_cast<long_index_t>(conv_params.input_left_pads_[2]);
for(long_index_t c = 0; c < C; ++c)
{
if(di >= 0 &&
type_convert<std::size_t>(di) < in_host.get_lengths()[3] &&
hi >= 0 &&
type_convert<std::size_t>(hi) < in_host.get_lengths()[4] &&
wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[5])
{
InDataType v_in = in_host(g, n, c, di, hi, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
}
}
};
make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(std::thread::hardware_concurrency());
}
}
} // namespace ck_tile

View File

@@ -0,0 +1,96 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
// Note: for simplicity, each functor only cares about a single M row
struct reference_layernorm2d_default_epilogue
{
template <typename OutDataType, typename AccDataType>
void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
{
const int N = acc.mDesc.get_lengths()[1];
for(int n = 0; n < N; ++n)
{
o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
}
}
template <typename OutDataType, typename AccDataType>
auto operator()(int m, const HostTensor<AccDataType>& acc)
{
HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
operator()(m, o, acc);
return o;
}
};
template <typename XDataType,
typename GammaDataType,
typename BetaDataType,
typename ComputeDataType,
typename YDataType,
typename MeanDataType,
typename InvStdDataType,
typename Epilogue = reference_layernorm2d_default_epilogue>
void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
const HostTensor<GammaDataType>& gamma_n,
const HostTensor<BetaDataType>& beta_n,
HostTensor<YDataType>& y_m_n,
HostTensor<MeanDataType>& mean_m,
HostTensor<InvStdDataType>& invStd_m,
ComputeDataType epsilon,
Epilogue epilogue_functor = {})
{
auto layernorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
int count = 0;
ComputeDataType mean = 0;
ComputeDataType variance = 0;
ComputeDataType divisor = 0;
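// Welford's online algorithm: update the running mean and the sum of squared deviations
// (variance * count) in a single pass over N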
for(int n = 0; n < N; ++n)
{
++count;
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType delta = x - mean;
mean += delta / count;
ComputeDataType delta2 = x - mean;
variance += delta * delta2;
}
// actual variance
variance = variance / count;
divisor = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(variance + epsilon);
if constexpr(!std::is_same_v<MeanDataType, ck_tile::null_type>)
mean_m(m) = ck_tile::type_convert<MeanDataType>(mean);
if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);
HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
for(int n = 0; n < N; ++n)
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
ComputeDataType beta = ck_tile::type_convert<ComputeDataType>(beta_n(n));
auto a_ = (x - mean) * divisor;
a_ = a_ * gamma + beta;
acc(m, n) = a_;
}
epilogue_functor(m, y_m_n, acc);
};
make_ParallelTensorFunctor(layernorm2d_fwd_func,
mean_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,318 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdlib>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2, 3: gemm1_split_k
typename ActivationOp = identity>
__global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
const ck_tile::index_t* p_sorted_expert_ids_,
const ck_tile::index_t* p_max_token_id_,
const ADataType* A,
const BDataType* B,
CDataType* C,
const AccDataType* expert_weight_ptr,
ck_tile::index_t Num_tokens,
ck_tile::index_t TokensPerBlock,
ck_tile::index_t TopK,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t strideA,
ck_tile::index_t strideB,
ck_tile::index_t strideC,
index_t scale_granularity_m,
index_t scale_granularity_n,
index_t scale_granularity_k,
float* scale_A_ptr,
float* scale_B_ptr,
float* expert_bias_ptr)
{
constexpr auto is_split_k = MoeGemmKind == 3;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
int row = idx / problem_N; // Compute row index
int col = idx % problem_N; // Compute column index
index_t gather_token_id = 0;
index_t scatter_token_id = 0;
index_t expert_id = 0;
if(row < p_max_token_id_[0])
{
expert_id = p_sorted_expert_ids_[row / TokensPerBlock];
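// sorted token ids pack the token id in the low 24 bits and the topk slot in the high 8 bits
// (see MOE_SORTING_MOCK_ID in the moe-sorting reference)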
gather_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
scatter_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
if(gather_token_id >= Num_tokens)
{
return;
}
if(MoeGemmKind == 2)
{
gather_token_id = gather_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
}
else
{
scatter_token_id = scatter_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
}
}
else
{
return;
}
if(row < M)
{
AccDataType acc = 0.0;
AccDataType acc_up = 0.0;
AccDataType acc_temp = 0.0;
AccDataType acc_up_temp = 0.0;
float scale_A = 0;
float scale_B = 0;
float scale_B_up = 0;
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;
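// K is walked in scale_granularity_k chunks: the unscaled partial sums acc_temp/acc_up_temp are
// folded into acc/acc_up with the current scale_A * scale_B at each chunk boundary (and once more
// after the loop), which dequantizes block-scaled inputs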
for(int k = 0; k < K; ++k)
{
if(k % scale_granularity_k == 0)
{
// update acc
acc += acc_temp * scale_A * scale_B;
acc_up += acc_up_temp * scale_A * scale_B_up;
// reset acc temp
acc_temp = 0.0;
acc_up_temp = 0.0;
// update scale factors
scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
(k / scale_granularity_k) * scale_A_stride];
scale_B =
scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
(k / scale_granularity_k) * scale_B_stride];
if constexpr(MoeGemmKind == 1)
scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
(col + problem_N) / scale_granularity_n +
(k / scale_granularity_k) * scale_B_stride];
}
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? gather_token_id * strideA + k
: k * strideA + gather_token_id;
long b_index =
long(expert_id) * N * K +
((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
: k * strideB + col);
long b_index_up;
if constexpr(MoeGemmKind == 1)
b_index_up = long(expert_id) * N * K +
((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? (col + problem_N) * strideB + k
: k * strideB + col + problem_N);
AccDataType v_a;
AccDataType v_b;
AccDataType v_b_up;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
if constexpr(MoeGemmKind == 1)
{
const fp32x2_t fp32_val_up =
pk_int4_t_to_fp32x2_t(B[b_index_up / packed_size_b]);
if(k % 2 == 1)
v_b_up = fp32_val_up.hi;
else
v_b_up = fp32_val_up.lo;
}
}
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
if constexpr(MoeGemmKind == 1)
{
const fp32x2_t fp32_val_up =
pk_fp4_to_fp32x2(B[b_index_up / packed_size_b], 1.0f);
if(k % 2 == 1)
v_b_up = fp32_val_up.hi;
else
v_b_up = fp32_val_up.lo;
}
}
else
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
if constexpr(MoeGemmKind == 1)
v_b_up = ck_tile::type_convert<AccDataType>(B[b_index_up]);
}
acc_temp += v_a * v_b;
if constexpr(MoeGemmKind == 1)
acc_up_temp += v_a * v_b_up;
}
acc += acc_temp * scale_A * scale_B;
acc_up += acc_up_temp * scale_A * scale_B_up;
float bias = 0.f, bias_up = 0.f;
if(expert_bias_ptr != nullptr && !is_split_k)
{
bias = expert_bias_ptr[expert_id * N + col];
if constexpr(MoeGemmKind == 1)
bias_up = expert_bias_ptr[expert_id * N + col + problem_N];
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? scatter_token_id * strideC + col
: col * strideC + scatter_token_id;
if constexpr(MoeGemmKind < 2)
{
C[c_index] = ck_tile::type_convert<CDataType>(
ActivationOp{}(acc + bias, MoeGemmKind == 1 ? acc_up + bias_up : 1));
}
else
{
// moe gemm2 does not use an activation.
auto weight =
is_split_k ? ck_tile::type_convert<AccDataType>(1.0f) : expert_weight_ptr[row];
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * weight);
thread_buffer<CDataType, 2> add_v = 0;
if(c_index % 2)
{
// result is the second value of the fp16 pair.
add_v.template get_as<CDataType>()[1] = res;
}
else
{
// result is the first value of the fp16 pair.
add_v.template get_as<CDataType>()[0] = res;
}
// mask the last bit so the atomicAdd pointer is DWORD-aligned.
atomic_add_g<CDataType, 2>(reinterpret_cast<CDataType*>(C + (c_index & 0xffff'fffe)),
add_v);
}
}
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2, 3: gemm1_split_k
typename ActivationOp = identity>
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
const index_t* p_sorted_expert_ids_,
const index_t* p_max_token_id_,
const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
const AccDataType* expert_weight_ptr,
index_t Num_tokens,
index_t TokensPerBlock,
index_t TopK,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c,
index_t scale_granularity_m,
index_t scale_granularity_n,
index_t scale_granularity_k,
float* scale_A_ptr,
float* scale_B_ptr,
float* exp_bias = nullptr)
{
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
int totalElements = M * problem_N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
moe_gemm_kernel<ADataType,
BDataType,
AccDataType,
CDataType,
LayoutA,
LayoutB,
LayoutC,
MoeGemmKind,
ActivationOp><<<numBlocks, numThreadsPerBlock>>>(p_sorted_token_ids_,
p_sorted_expert_ids_,
p_max_token_id_,
a_ptr,
b_ptr,
c_ptr,
expert_weight_ptr,
Num_tokens,
TokensPerBlock,
TopK,
M,
N,
K,
stride_a,
stride_b,
stride_c,
scale_granularity_m,
scale_granularity_n,
scale_granularity_k,
scale_A_ptr,
scale_B_ptr,
exp_bias);
return;
}
} // namespace ck_tile

View File

@@ -0,0 +1,121 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_) & 0x00ffffff) | (((topk_id_) & 0xff) << 24))
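// Reference moe sorting: tokens are bucketed per expert, each bucket is padded up to a multiple
// of unit_size with a sentinel token id (num_token, or MOE_SORTING_MOCK_ID(num_token, topk) when
// mock ids are enabled) and zero weight, and the buckets are concatenated into
// p_sorted_token_ids / sorted_weight. sorted_expert_ids holds one expert id per unit_size block;
// unit_cnt returns the total number of emitted entries (number of blocks * unit_size).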
template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
const HostTensor<WeightType>& weights,
const HostTensor<IndexType>& local_expert_mask,
HostTensor<IndexType>& p_sorted_token_ids,
HostTensor<WeightType>& sorted_weight,
HostTensor<IndexType>& sorted_expert_ids,
index_t& unit_cnt,
const index_t experts,
const index_t unit_size,
const index_t tokens,
bool local_expert_masking,
bool skip_experts_with_zero_token = true)
{
// note: if tokens is smaller than topk_ids.mDesc.get_lengths()[0], this indicates the local-token case
const index_t num_token = tokens; // topk_ids.mDesc.get_lengths()[0];
const index_t topk = topk_ids.mDesc.get_lengths()[1];
// allocate a temp buffer and fill it with the sentinel value [num_token|topk]
std::vector<std::vector<IndexType>> expert_tokens(
experts,
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
std::vector<IndexType>(unit_size, MOE_SORTING_MOCK_ID(num_token, topk)));
#else
std::vector<IndexType>(unit_size, num_token));
#endif
std::vector<std::vector<WeightType>> expert_token_weights(
experts, std::vector<WeightType>(unit_size, 0));
// count number of unit-size slices in this expert
std::vector<IndexType> expert_slices(experts, 1);
// count the tokens used in this expert
std::vector<IndexType> expert_slice_idxs(experts, 0);
// TODO: the above two buffers seem redundant
for(index_t t = 0; t < num_token; t++)
{
for(index_t k = 0; k < topk; k++)
{
IndexType e = topk_ids(t, k);
WeightType w = weights(t, k);
index_t idx = expert_slice_idxs[e];
if(idx > expert_slices[e] * unit_size - 1)
{
expert_slices[e]++;
index_t new_size = expert_slices[e] * unit_size;
expert_tokens[e].resize(new_size);
expert_token_weights[e].resize(new_size);
for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
expert_tokens[e][i] = MOE_SORTING_MOCK_ID(num_token, topk);
#else
expert_tokens[e][i] = num_token;
#endif
expert_token_weights[e][i] = 0;
}
}
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
expert_tokens[e][idx] = MOE_SORTING_MOCK_ID(t, k);
#else
expert_tokens[e][idx] = t;
#endif
expert_token_weights[e][idx] = w;
expert_slice_idxs[e]++;
}
}
IndexType* out_tokens = p_sorted_token_ids.data();
WeightType* out_weights = sorted_weight.data();
IndexType* out_expert_id = sorted_expert_ids.data();
int curr_expert_id = 0;
for(index_t e = 0; e < experts; e++)
{
if(local_expert_masking)
{
if(local_expert_mask(e) == 0)
continue;
}
if(skip_experts_with_zero_token)
{
if(expert_slice_idxs[e] == 0)
{
curr_expert_id++;
continue;
}
}
memcpy(out_tokens, expert_tokens[e].data(), sizeof(IndexType) * expert_slices[e] * unit_size);
out_tokens += expert_slices[e] * unit_size;
memcpy(out_weights,
expert_token_weights[e].data(),
sizeof(WeightType) * expert_slices[e] * unit_size);
out_weights += expert_slices[e] * unit_size;
for(index_t s = 0; s < expert_slices[e]; s++)
{
out_expert_id[s] = curr_expert_id;
unit_cnt++;
}
out_expert_id += expert_slices[e];
curr_expert_id++;
}
unit_cnt *= unit_size;
return;
}
#undef MOE_SORTING_MOCK_ID
} // namespace ck_tile

View File

@@ -0,0 +1,76 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
#include <numeric>
#include <functional>
namespace ck_tile {
/*
this implements permute + contiguous()-like functionality, as in pytorch
*/
template <typename DataType>
CK_TILE_HOST void
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
{
const auto x_len = x.mDesc.get_lengths();
const auto y_len = y.mDesc.get_lengths();
assert(x_len.size() == y_len.size());
index_t rank = x_len.size();
const auto x_elm = std::accumulate(x_len.begin(), x_len.end(), 1, std::multiplies<index_t>());
const auto y_elm = std::accumulate(y_len.begin(), y_len.end(), 1, std::multiplies<index_t>());
assert(x_elm == y_elm);
(void)y_elm;
auto f = [&](auto i_element) {
std::vector<size_t> y_coord = [&]() {
std::vector<size_t> tmp(rank, 0);
size_t r = i_element;
for(index_t i = rank - 1; i >= 0; i--)
{
tmp[i] = r % y_len[i];
r = r / y_len[i];
}
return tmp;
}();
std::vector<size_t> x_coord = [&]() {
std::vector<size_t> tmp(rank, 0);
for(index_t i = 0; i < rank; i++)
{
tmp[perm[i]] = y_coord[i];
}
return tmp;
}();
// do permute
y(y_coord) = x(x_coord);
};
make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
}
template <typename DataType>
CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
{
auto x_shape = x.get_lengths();
ck_tile::index_t rank = perm.size();
std::vector<ck_tile::index_t> y_shape = [&]() {
std::vector<ck_tile::index_t> tmp(rank, 0);
for(int i = 0; i < static_cast<int>(rank); i++)
{
tmp[i] = x_shape[perm[i]];
}
return tmp;
}();
HostTensor<DataType> y(y_shape);
reference_permute(x, y, perm);
return y;
}
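// Illustrative usage sketch (shapes are hypothetical):
//   HostTensor<float> x({2, 3, 4});
//   auto y = reference_permute(x, {2, 0, 1}); // y has lengths {4, 2, 3}, like x.permute(2, 0, 1).contiguous()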
} // namespace ck_tile

View File

@@ -0,0 +1,198 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
#include <thread>
#include <cmath>
namespace ck_tile {
template <typename InDataType,
typename ComputeDataType,
typename OutDataType,
typename IndexDataType,
typename ReduceOp,
typename TensorShape,
typename WindowShape,
bool OutputIndex = false>
CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
HostTensor<OutDataType>& output,
HostTensor<IndexDataType>& output_index,
PoolKernelArgs<TensorShape, WindowShape> kargs,
ReduceOp reduce_op)
{
const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<1>{});
const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<2>{});
const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<3>{});
const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<1>{});
const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<2>{});
const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<0>{});
const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<1>{});
const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<0>{});
const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<1>{});
const ck_tile::index_t Dy = kargs.window_dilations.at(ck_tile::number<0>{});
const ck_tile::index_t Dx = kargs.window_dilations.at(ck_tile::number<1>{});
const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<0>{});
const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<1>{});
// Right padding is handled implicitly by bounds checking
auto f = [&](auto n, auto ho, auto wo, auto c) {
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
for(ck_tile::index_t y = 0; y < Y; ++y)
{
// Calculate input height index with stride, dilation, and padding
ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
for(ck_tile::index_t x = 0; x < X; ++x)
{
// Calculate input width index with stride, dilation, and padding
ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
if(hi >= 0 && hi < H && wi >= 0 && wi < W)
{
const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
if constexpr(OutputIndex)
{
IndexDataType flat_index = input.GetOffsetFromMultiIndex(n, hi, wi, c);
bool changed = false;
v_acc = reduce_op(v_acc, v_in, changed);
if(changed)
{
current_index = flat_index;
}
}
else
{
v_acc = reduce_op(v_acc, v_in);
}
}
// For positions outside bounds, we implicitly use identity value
}
}
output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
if constexpr(OutputIndex)
{
output_index(n, ho, wo, c) = current_index;
}
};
// Parallelize over all output dimensions
make_ParallelTensorFunctor(f, N, Ho, Wo, C)(std::thread::hardware_concurrency());
}
template <typename InDataType,
typename ComputeDataType,
typename OutDataType,
typename IndexDataType,
typename ReduceOp,
typename TensorShape,
typename WindowShape,
bool OutputIndex = false>
CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
HostTensor<OutDataType>& output,
HostTensor<IndexDataType>& output_index,
PoolKernelArgs<TensorShape, WindowShape> kargs,
ReduceOp reduce_op)
{
const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
const ck_tile::index_t D = kargs.input_shape.at(ck_tile::number<1>{});
const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<2>{});
const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<3>{});
const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<4>{});
const ck_tile::index_t Do = kargs.output_shape.at(ck_tile::number<1>{});
const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<2>{});
const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<3>{});
const ck_tile::index_t Z = kargs.window_lengths.at(ck_tile::number<0>{});
const ck_tile::index_t Y = kargs.window_lengths.at(ck_tile::number<1>{});
const ck_tile::index_t X = kargs.window_lengths.at(ck_tile::number<2>{});
const ck_tile::index_t Sz = kargs.window_strides.at(ck_tile::number<0>{});
const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<1>{});
const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<2>{});
const ck_tile::index_t Dz = kargs.window_dilations.at(ck_tile::number<0>{});
const ck_tile::index_t Dy = kargs.window_dilations.at(ck_tile::number<1>{});
const ck_tile::index_t Dx = kargs.window_dilations.at(ck_tile::number<2>{});
const ck_tile::index_t LeftPz = kargs.input_left_pads.at(ck_tile::number<0>{});
const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<1>{});
const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<2>{});
// Right padding is handled implicitly by bounds checking
auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
for(ck_tile::index_t z = 0; z < Z; ++z)
{
// Calculate input depth index with stride, dilation, and padding
ck_tile::index_t di = do_ * Sz + z * Dz - LeftPz;
for(ck_tile::index_t y = 0; y < Y; ++y)
{
// Calculate input height index with stride, dilation, and padding
ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
for(ck_tile::index_t x = 0; x < X; ++x)
{
// Calculate input width index with stride, dilation, and padding
ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
if(di >= 0 && di < D && hi >= 0 && hi < H && wi >= 0 && wi < W)
{
const ComputeDataType v_in =
type_convert<ComputeDataType>(input(n, di, hi, wi, c));
if constexpr(OutputIndex)
{
IndexDataType flat_index =
input.GetOffsetFromMultiIndex(n, di, hi, wi, c);
bool changed = false;
v_acc = reduce_op(v_acc, v_in, changed);
if(changed)
{
current_index = flat_index;
}
}
else
{
v_acc = reduce_op(v_acc, v_in);
}
}
// For positions outside bounds, we implicitly use identity value
}
}
}
output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
if constexpr(OutputIndex)
{
output_index(n, do_, ho, wo, c) = current_index;
}
};
// Parallelize over all output dimensions
make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,341 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <thread>
namespace ck_tile {
template <typename XDataType, typename ComputeDataType, typename YDataType, typename ReduceOp>
CK_TILE_HOST void
reference_reduce(const HostTensor<XDataType>& x_m_n, HostTensor<YDataType>& y_m, ReduceOp reduce_op)
{
auto f = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
for(int n = 0; n < N; ++n)
{
const ComputeDataType v_a = type_convert<ComputeDataType>(x_m_n(m, n));
v_acc = reduce_op(v_acc, v_a);
}
y_m(m) = ck_tile::type_convert<YDataType>(v_acc);
};
make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
// Generic reference reduce for arbitrary dimensions
template <
typename XDataType,
typename ComputeDataType,
typename YDataType,
typename ReduceOp,
typename KeptDim, // Expected type: ck_tile::sequence<...> containing dimension indices to keep
typename ReduceDims> // Expected type: ck_tile::sequence<...> containing dimension indices to
// reduce
CK_TILE_HOST void reference_reduce(const HostTensor<XDataType>& x_tensor,
HostTensor<YDataType>& y_tensor,
ReduceOp reduce_op,
KeptDim kept_dim,
ReduceDims reduce_dims)
{
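// e.g. for a 3D input with kept_dim = sequence<0> and reduce_dims = sequence<1, 2>, y_tensor is
// expected to be 1D of length x_lengths[0], holding the reduction over dims 1 and 2 of each row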
const auto& x_lengths = x_tensor.mDesc.get_lengths();
// Calculate total kept elements (product of all kept dimension lengths)
index_t total_kept_elements = 1;
static_for<0, kept_dim.size(), 1>{}(
[&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });
// Calculate total reduce elements (product of all reduce dimension lengths)
index_t total_reduce_elements = 1;
static_for<0, reduce_dims.size(), 1>{}(
[&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });
auto f = [&](auto linear_kept_idx) {
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
// Convert linear kept index to multi-dimensional kept indices
std::vector<index_t> kept_indices(kept_dim.size());
index_t temp_kept = linear_kept_idx;
static_for<0, kept_dim.size(), 1>{}([&](auto i) {
constexpr auto dim_idx = kept_dim.size() - 1 - i;
constexpr auto dim = kept_dim.at(dim_idx);
const auto len = x_lengths[dim];
kept_indices[dim_idx] = temp_kept % len;
temp_kept /= len;
});
for(index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
{
// Convert linear reduce index to multi-dimensional reduce indices
std::vector<index_t> reduce_indices(reduce_dims.size());
index_t temp_reduce = reduce_idx;
static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
constexpr auto dim_idx = reduce_dims.size() - 1 - i;
constexpr auto dim = reduce_dims.at(dim_idx);
const auto len = x_lengths[dim];
reduce_indices[dim_idx] = temp_reduce % len;
temp_reduce /= len;
});
// Build full input tensor indices by combining kept and reduce indices
std::vector<std::size_t> full_indices(x_lengths.size(), 0);
static_for<0, kept_dim.size(), 1>{}(
[&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
static_for<0, reduce_dims.size(), 1>{}(
[&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });
// Access input tensor element
const auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));
v_acc = reduce_op(v_acc, v_a);
}
// Calculate output tensor index using kept indices
// The output tensor has the same structure as the kept dimensions
std::vector<std::size_t> y_indices(kept_dim.size());
static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });
y_tensor(y_indices) = type_convert<YDataType>(v_acc);
};
make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
}
template <typename XDataType,
typename ComputeDataType,
typename YDataType,
typename YRefTuple,
typename ReduceOps, // Expected type: ck_tile::tuple<...> containing reduce operations
typename KeptDim, // Expected type: ck_tile::sequence<...> containing dimension indices to
// keep
typename ReduceDims, // Expected type: ck_tile::sequence<...> containing dimension indices
// to reduce
typename ElementWiseOps,
typename AccElementWiseOps>
CK_TILE_HOST void reference_multiple_reduce(const HostTensor<XDataType>& x_tensor,
YRefTuple& y_tensor_tuple,
ReduceOps reduce_ops,
KeptDim kept_dim,
ReduceDims reduce_dims,
ElementWiseOps elementwise_ops,
AccElementWiseOps accumulator_ops)
{
const auto& x_lengths = x_tensor.mDesc.get_lengths();
// Calculate total kept elements (product of all kept dimension lengths)
index_t total_kept_elements = 1;
static_for<0, kept_dim.size(), 1>{}(
[&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });
// Calculate total reduce elements (product of all reduce dimension lengths)
index_t total_reduce_elements = 1;
static_for<0, reduce_dims.size(), 1>{}(
[&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });
auto f = [&](auto linear_kept_idx) {
// Initialize accumulators for each reduction operation
auto v_acc_tuple = ck_tile::generate_tuple(
[&](auto i) {
return reduce_ops.template at<i>().template GetIdentityValue<ComputeDataType>();
},
number<reduce_ops.size()>{});
// Convert linear kept index to multi-dimensional kept indices
std::vector<index_t> kept_indices(kept_dim.size());
index_t temp_kept = linear_kept_idx;
static_for<0, kept_dim.size(), 1>{}([&](auto i) {
constexpr auto dim_idx = kept_dim.size() - 1 - i;
constexpr auto dim = kept_dim.at(dim_idx);
const auto len = x_lengths[dim];
kept_indices[dim_idx] = temp_kept % len;
temp_kept /= len;
});
for(index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
{
// Convert linear reduce index to multi-dimensional reduce indices
std::vector<index_t> reduce_indices(reduce_dims.size());
index_t temp_reduce = reduce_idx;
static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
constexpr auto dim_idx = reduce_dims.size() - 1 - i;
constexpr auto dim = reduce_dims.at(dim_idx);
const auto len = x_lengths[dim];
reduce_indices[dim_idx] = temp_reduce % len;
temp_reduce /= len;
});
// Build full input tensor indices by combining kept and reduce indices
std::vector<std::size_t> full_indices(x_lengths.size(), 0);
static_for<0, kept_dim.size(), 1>{}(
[&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
static_for<0, reduce_dims.size(), 1>{}(
[&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });
// Access input tensor element
auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));
// Apply each reduction operation
static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
// Apply element-wise operation before reduction
elementwise_ops.at(i)(v_a, v_a);
v_acc_tuple.template at<i>() =
reduce_ops.template at<i>()(v_acc_tuple.template at<i>(), v_a);
});
}
static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
// Apply accumulator element-wise operation after reduction
accumulator_ops.at(i)(v_acc_tuple.template at<i>(), v_acc_tuple.template at<i>());
});
// Calculate output tensor index using kept indices
// The output tensor has the same structure as the kept dimensions
std::vector<std::size_t> y_indices(kept_dim.size());
static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });
// Store results for each reduction operation in the output tensor
static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
y_tensor_tuple.template at<i>()(y_indices) =
type_convert<YDataType>(v_acc_tuple.template at<i>());
});
};
make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
}
template <typename XDataType,
typename ComputeDataType,
typename YDataType,
typename YRefTuple,
typename ReduceOps, // Expected type: ck_tile::tuple<...> containing reduce operations
typename KeptDim, // Expected type: ck_tile::sequence<...> containing dimension indices to
// keep
typename ReduceDims, // Expected type: ck_tile::sequence<...> containing dimension indices
// to reduce
typename ElementWiseOps,
typename AccElementWiseOps,
typename InterBlockReduceOps>
CK_TILE_HOST void reference_multiple_reduce_multiblock(const HostTensor<XDataType>& x_tensor,
YRefTuple& y_tensor_tuple,
ReduceOps reduce_ops,
KeptDim kept_dim,
ReduceDims reduce_dims,
ElementWiseOps elementwise_ops,
AccElementWiseOps accumulator_ops,
InterBlockReduceOps inter_block_reduce_ops,
ck_tile::index_t num_blocks)
{
const auto& x_lengths = x_tensor.mDesc.get_lengths();
// Calculate total kept elements (product of all kept dimension lengths)
index_t total_kept_elements = 1;
static_for<0, kept_dim.size(), 1>{}(
[&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });
// Calculate total reduce elements (product of all reduce dimension lengths)
index_t total_reduce_elements = 1;
static_for<0, reduce_dims.size(), 1>{}(
[&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });
// Initialize output tensors
static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
auto& y_tensor = y_tensor_tuple.template at<i>();
for(auto& val : y_tensor.mData)
{
val = inter_block_reduce_ops.template at<i>().template GetIdentityValue<YDataType>();
}
});
auto f = [&](auto linear_kept_idx) {
// Convert linear kept index to multi-dimensional kept indices
std::vector<index_t> kept_indices(kept_dim.size());
index_t temp_kept = linear_kept_idx;
static_for<0, kept_dim.size(), 1>{}([&](auto i) {
constexpr auto dim_idx = kept_dim.size() - 1 - i;
constexpr auto dim = kept_dim.at(dim_idx);
const auto len = x_lengths[dim];
kept_indices[dim_idx] = temp_kept % len;
temp_kept /= len;
});
// Calculate output tensor index using kept indices
std::vector<std::size_t> y_indices(kept_dim.size());
static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });
const auto max_element_per_block = (total_reduce_elements + num_blocks - 1) / num_blocks;
for(index_t block_id = 0; block_id < num_blocks; ++block_id)
{
// Initialize accumulators for each reduction operation for the current block
auto v_acc_tuple = ck_tile::generate_tuple(
[&](auto i) {
return reduce_ops.template at<i>().template GetIdentityValue<ComputeDataType>();
},
number<reduce_ops.size()>{});
const index_t element_offset = block_id * max_element_per_block;
const index_t element_end =
std::min(element_offset + max_element_per_block, total_reduce_elements);
for(index_t linear_reduce_idx = element_offset; linear_reduce_idx < element_end;
++linear_reduce_idx)
{
// Convert linear reduce index to multi-dimensional reduce indices
std::vector<index_t> reduce_indices(reduce_dims.size());
index_t temp_reduce = linear_reduce_idx;
static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
constexpr auto dim_idx = reduce_dims.size() - 1 - i;
constexpr auto dim = reduce_dims.at(dim_idx);
const auto len = x_lengths[dim];
reduce_indices[dim_idx] = temp_reduce % len;
temp_reduce /= len;
});
// Build full input tensor indices by combining kept and reduce indices
std::vector<std::size_t> full_indices(x_lengths.size(), 0);
static_for<0, kept_dim.size(), 1>{}(
[&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
static_for<0, reduce_dims.size(), 1>{}(
[&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });
// Access input tensor element
const auto v_a_in = type_convert<ComputeDataType>(x_tensor(full_indices));
// Apply each reduction operation
static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
auto v_a = v_a_in;
// Apply element-wise operation before reduction
elementwise_ops.at(i)(v_a, v_a);
v_acc_tuple.template at<i>() =
reduce_ops.template at<i>()(v_acc_tuple.template at<i>(), v_a);
});
}
static_for<0, reduce_ops.size(), 1>{}([&](auto i) {
// Apply accumulator element-wise operation after reduction
accumulator_ops.at(i)(v_acc_tuple.template at<i>(), v_acc_tuple.template at<i>());
// Update the output tensor with the partial result from this block
auto& y_tensor = y_tensor_tuple.template at<i>();
auto& y_val = y_tensor(y_indices);
y_val = inter_block_reduce_ops.template at<i>()(
y_val, type_convert<YDataType>(v_acc_tuple.template at<i>()));
});
}
};
make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,114 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
namespace ck_tile {
// Note: for simplicity, each functor only cares about a single M row
struct reference_rmsnorm2d_default_epilogue
{
template <typename OutDataType, typename AccDataType>
void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
{
const int N = acc.mDesc.get_lengths()[1];
for(int n = 0; n < N; ++n)
{
o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
}
}
template <typename OutDataType, typename AccDataType>
auto operator()(int m, const HostTensor<AccDataType>& acc)
{
HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
operator()(m, o, acc);
return o;
}
};
template <typename XDataType,
typename GammaDataType,
typename ComputeDataType,
typename YDataType,
typename InvRmsDataType,
typename UnquantYDataType,
typename Epilogue = reference_rmsnorm2d_default_epilogue>
void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
const HostTensor<GammaDataType>& gamma_n,
HostTensor<YDataType>& y_m_n,
HostTensor<InvRmsDataType>& invRms_m,
HostTensor<UnquantYDataType>& unquant_y_m_n,
ComputeDataType epsilon,
Epilogue epilogue_functor = {},
const int use_model_sensitive_rmsnorm =
static_cast<int>(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL))
{
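// rmsnorm: y(m, n) = x(m, n) / sqrt(mean_n(x(m, :)^2) + epsilon) * gamma(n)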
auto rmsnorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
ComputeDataType mean_square = 0;
ComputeDataType divisor = 0;
for(int n = 0; n < N; ++n)
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
mean_square += x * x;
}
mean_square = mean_square / N;
divisor = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(mean_square + epsilon);
if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);
HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
for(int n = 0; n < N; ++n)
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
if(use_model_sensitive_rmsnorm ==
static_cast<int>(
Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model
{
acc(m, n) = x * divisor * gamma;
}
else if(use_model_sensitive_rmsnorm ==
static_cast<int>(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model
{
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
{
const auto tmp0 = float_to_bf16<bf16_rounding_mode::standard>(x * divisor);
const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
type_convert<ComputeDataType>(tmp0) * gamma);
const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
acc(m, n) = rmsn_;
}
else
{
const auto tmp = type_convert<XDataType>(x * divisor);
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma;
acc(m, n) = rmsn_;
}
}
}
if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
{
epilogue_functor(m, unquant_y_m_n, y_m_n, acc);
}
else
{
epilogue_functor(m, y_m_n, acc);
}
};
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename XDataType, typename ScaleDataType, typename QXDataType>
CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor<XDataType>& x_m_n,
const HostTensor<ScaleDataType>& scale_m,
HostTensor<QXDataType>& qx_m_n)
{
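// per-row quantization: qx(m, n) = saturate<QXDataType>(x(m, n) / scale(m))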
auto f = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
for(int n = 0; n < N; ++n)
{
auto v_x = x_m_n(m, n);
// scale = amax / 127 for int8
auto v_scale = type_convert<XDataType>(scale_m(m));
auto v_qx = v_x / v_scale;
qx_m_n(m, n) = type_convert<QXDataType>(saturates<QXDataType>{}(v_qx));
}
};
make_ParallelTensorFunctor(f,
scale_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -0,0 +1,89 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST void
reference_softmax(const HostTensor<InputType>& x, HostTensor<OutputType>& y, index_t dim = -1)
{
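// numerically stable softmax along `dim`: y_i = exp(x_i - max_j x_j) / sum_j exp(x_j - max_j x_j)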
index_t rank = x.get_num_of_dimension();
assert(static_cast<std::size_t>(rank) == y.get_num_of_dimension());
assert(dim == -1 || dim < rank);
index_t target_dim = dim == -1 ? (rank - 1) : dim;
index_t softmax_len = x.get_length(target_dim);
index_t n_parallel = x.get_element_size() / softmax_len;
auto x_len = x.get_lengths();
auto f = [&](auto i_element) {
std::vector<size_t> coord = [&]() {
std::vector<size_t> t_(rank, 0);
size_t r = i_element;
for(index_t i = rank - 1; i >= 0; i--)
{
if(i == target_dim)
continue;
t_[i] = r % x_len[i];
r = r / x_len[i];
}
return t_;
}();
ComputeType v_max = -ck_tile::numeric<ComputeType>::infinity();
// compute max
for(auto idx = 0; idx < softmax_len; idx++)
{
auto c_ = coord;
c_[target_dim] = idx;
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
v_max = v_max < v_x ? v_x : v_max;
}
ComputeType v_exp_sum = static_cast<ComputeType>(0);
// sum
for(auto idx = 0; idx < softmax_len; idx++)
{
auto c_ = coord;
c_[target_dim] = idx;
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
v_exp_sum += ck_tile::exp(v_x - v_max);
}
// elementwise
for(auto idx = 0; idx < softmax_len; idx++)
{
auto c_ = coord;
c_[target_dim] = idx;
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
auto out = ck_tile::exp(v_x - v_max) / v_exp_sum;
y(c_) = ck_tile::type_convert<OutputType>(out);
}
};
make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST auto reference_softmax(const HostTensor<InputType>& x, index_t dim = -1)
{
HostTensor<OutputType> y(x.get_lengths(), x.get_strides());
reference_softmax<InputType, ComputeType, OutputType>(x, y, dim);
return y;
}
} // namespace ck_tile

View File

@@ -0,0 +1,125 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
#include <numeric>
#include <functional>
#include <utility>
#include <algorithm>
namespace ck_tile {
/*
similar to torch.topk()
x (Tensor) the input tensor.
k (int) the k in “top-k”
dim (int, optional) the dimension to sort along
largest (bool, optional) largest or smallest elements
sorted (bool, optional) elements in sorted order or not
output:
y_values
y_indices
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TopKImpl.h
*/
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
HostTensor<DataType>& y_values,
HostTensor<IndexType>& y_indices,
index_t k,
index_t dim = -1,
bool largest = true,
bool sorted = true)
{
// rank must be the same
index_t rank = x.get_num_of_dimension();
assert(static_cast<std::size_t>(rank) == y_values.get_num_of_dimension());
assert(static_cast<size_t>(rank) == y_indices.get_num_of_dimension());
assert(dim == -1 || dim < rank);
index_t topk_dim = dim == -1 ? (rank - 1) : dim;
index_t topk_src_len = x.get_length(topk_dim);
auto x_len = x.get_lengths();
assert(k <= topk_src_len);
assert(static_cast<size_t>(k) == y_values.get_length(topk_dim) &&
static_cast<size_t>(k) == y_indices.get_length(topk_dim));
index_t n_parallel = x.get_element_size() / topk_src_len;
// clang-format off
auto f = [&](auto i_element) {
std::vector<size_t> topk_coord = [&](){
std::vector<size_t> t_(rank, 0);
size_t r = i_element;
for(index_t i = rank - 1; i >= 0; i--) {
if(i == topk_dim) continue; // coordinate along the topk dim stays zero here
t_[i] = r % x_len[i]; r = r / x_len[i];
}
return t_;
}();
using elem_t = std::pair<DataType, IndexType>;
std::vector<elem_t> q = [&](){
std::vector<elem_t> t_(topk_src_len);
for(index_t i = 0; i < topk_src_len; i++) {
auto c_ = topk_coord; c_[topk_dim] = i;
t_[i].first = x(c_); t_[i].second = i;
}
return t_;
}();
// run topk
if(largest) {
std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
if(sorted) {
std::sort(q.begin(), q.begin() + k - 1,
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
}
} else {
std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
if(sorted) {
std::sort(q.begin(), q.begin() + k - 1,
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
}
}
// write out
for(index_t i = 0; i < k; i++) {
auto c_ = topk_coord; c_[topk_dim] = i;
y_values(c_) = q[i].first; y_indices(c_) = q[i].second;
}
};
// clang-format on
make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}
// TODO: if using this overload, the returned tensors are dense (default strides)
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST auto reference_topk(const HostTensor<DataType>& x,
index_t k,
index_t dim = -1,
bool largest = true,
bool sorted = true)
{
auto lens = x.get_lengths();
index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim;
assert(target_dim < lens.size());
assert(k <= lens[target_dim]);
lens[target_dim] = k;
HostTensor<DataType> y_values(lens);
HostTensor<IndexType> y_indices(lens);
reference_topk<DataType, IndexType>(x, y_values, y_indices, k, dim, largest, sorted);
return ck_tile::make_tuple(y_values, y_indices);
}
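// Illustrative usage sketch (shape and k are hypothetical):
//   HostTensor<float> x({4, 8});
//   auto result = reference_topk(x, 3); // tuple of (values, indices); top-3 along the last dim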
} // namespace ck_tile

View File

@@ -0,0 +1,33 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename ADataType, typename BDataType>
void reference_transpose_elementwise(const HostTensor<ADataType>& a, HostTensor<BDataType>& b)
{
ck_tile::index_t M = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[0]);
ck_tile::index_t N = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[1]);
// Ensure the b tensor is sized correctly for N x M
if(static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[0]) != N ||
static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[1]) != M)
{
throw std::runtime_error("Output tensor b has incorrect dimensions for transpose.");
}
auto f = [&](auto i, auto j) {
auto v_a = a(i, j);
b(j, i) = ck_tile::type_convert<BDataType>(v_a);
};
make_ParallelTensorFunctor(f, M, N)(std::thread::hardware_concurrency());
}
} // namespace ck_tile