mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Merge remote-tracking branch 'origin/ginolu/add_wgmfma_dispatcher' into mtgu/cktile_mxfp4_flatmm_dev
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -11,6 +11,196 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType,
|
||||
typename QDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
uint32_t QuantGroupSize,
|
||||
bool aquant,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<QDataType>& q,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto f_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0, v_block_acc = 0;
|
||||
|
||||
static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
|
||||
std::is_same_v<ADataType, bf8_t>);
|
||||
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
|
||||
std::is_same_v<BDataType, pk_int4_t>);
|
||||
static_assert(std::is_same_v<AccDataType, float>);
|
||||
static_assert(std::is_same_v<CDataType, float> ||
|
||||
std::is_same_v<CDataType, ck_tile::half_t>);
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
v_block_acc += v_a * v_b;
|
||||
|
||||
// Apply group dequant scale
|
||||
if((k + 1) % QuantGroupSize == 0)
|
||||
{
|
||||
float scale = 0.f;
|
||||
index_t outer_dim = (aquant) ? m : k / QuantGroupSize;
|
||||
index_t inner_dim = (aquant) ? k / QuantGroupSize : n;
|
||||
|
||||
if constexpr(std::is_same_v<QDataType, float>)
|
||||
{
|
||||
scale = q(outer_dim, inner_dim);
|
||||
}
|
||||
else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
scale = fp8_to_float_raw(q(outer_dim, inner_dim));
|
||||
}
|
||||
else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
scale = bf8_to_float_raw(q(outer_dim, inner_dim));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unexpected Q datatype.");
|
||||
}
|
||||
v_block_acc *= scale;
|
||||
v_acc += v_block_acc;
|
||||
v_block_acc = 0;
|
||||
}
|
||||
}
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<AQDataType>& aq_m_1,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
const HostTensor<BQDataType>& bq_1_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
|
||||
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
|
||||
static_assert(std::is_same_v<AccDataType, float>);
|
||||
static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
|
||||
static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto f_mn = [&](auto m, auto n) {
|
||||
// Init accumulator
|
||||
AccDataType v_acc = 0;
|
||||
// Get row scale for A and column scale for B
|
||||
float a_scale = aq_m_1(m, 0);
|
||||
float b_scale = bq_1_n(0, n);
|
||||
|
||||
// Compute the dot product
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
|
||||
// Process A data
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
}
|
||||
|
||||
// Process B data
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
}
|
||||
|
||||
v_acc += v_a * v_b;
|
||||
}
|
||||
|
||||
v_acc = v_acc * a_scale * b_scale;
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
|
||||
@@ -0,0 +1,227 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor<InDataType>& input,
|
||||
const HostTensor<WeiDataType>& weight,
|
||||
const HostTensor<OutDataType>& output,
|
||||
std::vector<ck_tile::long_index_t> conv_strides,
|
||||
std::vector<ck_tile::long_index_t> conv_dilations,
|
||||
std::vector<ck_tile::long_index_t> in_left_pads,
|
||||
std::vector<ck_tile::long_index_t>)
|
||||
{
|
||||
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
weight.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
output.get_num_of_dimension() == NDimSpatial + 3))
|
||||
{
|
||||
|
||||
printf("%lu %lu %lu",
|
||||
input.get_num_of_dimension(),
|
||||
weight.get_num_of_dimension(),
|
||||
output.get_num_of_dimension());
|
||||
|
||||
throw std::runtime_error("wrong! inconsistent dimension");
|
||||
}
|
||||
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
auto func = [&](auto g, auto n, auto c, auto wi) {
|
||||
std::size_t K = weight.get_lengths()[1];
|
||||
std::size_t X = weight.get_lengths()[3];
|
||||
|
||||
std::size_t Wo = output.get_lengths()[3];
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t x = 0; x < X; ++x)
|
||||
{
|
||||
auto w_tmp = static_cast<ck_tile::long_index_t>(wi) +
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]);
|
||||
|
||||
if(w_tmp % conv_strides[0] == 0)
|
||||
{
|
||||
auto wo = static_cast<ck_tile::long_index_t>(w_tmp) /
|
||||
static_cast<ck_tile::long_index_t>(conv_strides[0]);
|
||||
|
||||
if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
|
||||
{
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
OutDataType v_out = output(g, n, k, wo);
|
||||
WeiDataType v_wei = weight(g, k, c, x);
|
||||
v_acc += ck_tile::type_convert<float>(v_out) *
|
||||
ck_tile::type_convert<float>(v_wei);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
|
||||
input(g, n, c, wi) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
input.get_lengths()[0],
|
||||
input.get_lengths()[1],
|
||||
input.get_lengths()[2],
|
||||
input.get_lengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
auto func = [&](auto g, auto n, auto c, auto hi, auto wi) {
|
||||
std::size_t K = weight.get_lengths()[1];
|
||||
std::size_t Y = weight.get_lengths()[3];
|
||||
std::size_t X = weight.get_lengths()[4];
|
||||
|
||||
std::size_t Ho = output.get_lengths()[3];
|
||||
std::size_t Wo = output.get_lengths()[4];
|
||||
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t y = 0; y < Y; ++y)
|
||||
{
|
||||
auto h_tmp = static_cast<ck_tile::long_index_t>(hi) +
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
|
||||
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]);
|
||||
if(h_tmp % conv_strides[0] == 0)
|
||||
{
|
||||
auto ho = static_cast<ck_tile::long_index_t>(h_tmp) /
|
||||
static_cast<ck_tile::long_index_t>(conv_strides[0]);
|
||||
if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
|
||||
{
|
||||
for(std::size_t x = 0; x < X; ++x)
|
||||
{
|
||||
auto w_tmp = static_cast<ck_tile::long_index_t>(wi) +
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]);
|
||||
if(w_tmp % conv_strides[1] == 0)
|
||||
{
|
||||
auto wo = static_cast<ck_tile::long_index_t>(w_tmp) /
|
||||
static_cast<ck_tile::long_index_t>(conv_strides[1]);
|
||||
|
||||
if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
|
||||
{
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
OutDataType v_out = output(g, n, k, ho, wo);
|
||||
WeiDataType v_wei = weight(g, k, c, y, x);
|
||||
v_acc += ck_tile::type_convert<float>(v_out) *
|
||||
ck_tile::type_convert<float>(v_wei);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
|
||||
input(g, n, c, hi, wi) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
input.get_lengths()[0],
|
||||
input.get_lengths()[1],
|
||||
input.get_lengths()[2],
|
||||
input.get_lengths()[3],
|
||||
input.get_lengths()[4])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
auto func = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
|
||||
std::size_t K = weight.get_lengths()[1];
|
||||
std::size_t Z = weight.get_lengths()[3];
|
||||
std::size_t Y = weight.get_lengths()[4];
|
||||
std::size_t X = weight.get_lengths()[5];
|
||||
|
||||
std::size_t Do = output.get_lengths()[3];
|
||||
std::size_t Ho = output.get_lengths()[4];
|
||||
std::size_t Wo = output.get_lengths()[5];
|
||||
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t z = 0; z < Z; ++z)
|
||||
{
|
||||
auto d_tmp = static_cast<ck_tile::long_index_t>(di) +
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
|
||||
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]);
|
||||
if(d_tmp % conv_strides[0] == 0)
|
||||
{
|
||||
auto do_ = static_cast<ck_tile::long_index_t>(d_tmp) /
|
||||
static_cast<ck_tile::long_index_t>(conv_strides[0]);
|
||||
if(do_ >= 0 && ck_tile::type_convert<std::size_t>(do_) < Do)
|
||||
{
|
||||
for(std::size_t y = 0; y < Y; ++y)
|
||||
{
|
||||
auto h_tmp = static_cast<ck_tile::long_index_t>(hi) +
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
|
||||
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]);
|
||||
if(h_tmp % conv_strides[1] == 0)
|
||||
{
|
||||
auto ho = static_cast<ck_tile::long_index_t>(h_tmp) /
|
||||
static_cast<ck_tile::long_index_t>(conv_strides[1]);
|
||||
if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
|
||||
{
|
||||
for(std::size_t x = 0; x < X; ++x)
|
||||
{
|
||||
auto w_tmp =
|
||||
static_cast<ck_tile::long_index_t>(wi) +
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[2]) -
|
||||
static_cast<ck_tile::long_index_t>(x *
|
||||
conv_dilations[2]);
|
||||
|
||||
if(w_tmp % conv_strides[2] == 0)
|
||||
{
|
||||
auto wo =
|
||||
static_cast<ck_tile::long_index_t>(w_tmp) /
|
||||
static_cast<ck_tile::long_index_t>(conv_strides[2]);
|
||||
if(wo >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(wo) < Wo)
|
||||
{
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
OutDataType v_out =
|
||||
output(g, n, k, do_, ho, wo);
|
||||
WeiDataType v_wei = weight(g, k, c, z, y, x);
|
||||
v_acc += ck_tile::type_convert<float>(v_out) *
|
||||
ck_tile::type_convert<float>(v_wei);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
|
||||
input(g, n, c, di, hi, wi) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
input.get_lengths()[0],
|
||||
input.get_lengths()[1],
|
||||
input.get_lengths()[2],
|
||||
input.get_lengths()[3],
|
||||
input.get_lengths()[4],
|
||||
input.get_lengths()[5])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Ref_conv_bwd_data: number of dimensions must be between 1 and 3.");
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,167 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
CK_TILE_HOST void
|
||||
reference_grouped_conv_bwd_weight(const HostTensor<InDataType>& input,
|
||||
HostTensor<WeiDataType>& weight,
|
||||
const HostTensor<OutDataType>& output,
|
||||
std::vector<ck_tile::long_index_t> conv_strides,
|
||||
std::vector<ck_tile::long_index_t> conv_dilations,
|
||||
std::vector<ck_tile::long_index_t> in_left_pads,
|
||||
std::vector<ck_tile::long_index_t>)
|
||||
{
|
||||
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
weight.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
output.get_num_of_dimension() == NDimSpatial + 3))
|
||||
{
|
||||
throw std::runtime_error("wrong! inconsistent dimension");
|
||||
}
|
||||
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
auto func = [&](auto g, auto k, auto c, auto x) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
|
||||
{
|
||||
for(std::size_t wo = 0; wo < output.get_lengths()[3]; ++wo)
|
||||
{
|
||||
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
|
||||
{
|
||||
InDataType v_in = input(g, n, c, wi);
|
||||
OutDataType v_out = output(g, n, k, wo);
|
||||
v_acc += ck_tile::type_convert<float>(v_out) *
|
||||
ck_tile::type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
OutDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
|
||||
weight(g, k, c, x) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
weight.get_lengths()[0],
|
||||
weight.get_lengths()[1],
|
||||
weight.get_lengths()[2],
|
||||
weight.get_lengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
auto func = [&](auto g, auto k, auto c, auto y, auto x) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
|
||||
{
|
||||
for(std::size_t ho = 0; ho < output.get_lengths()[3]; ++ho)
|
||||
{
|
||||
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
for(std::size_t wo = 0; wo < output.get_lengths()[4]; ++wo)
|
||||
{
|
||||
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
|
||||
|
||||
if(hi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[3] &&
|
||||
wi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[4])
|
||||
{
|
||||
InDataType v_in = input(g, n, c, hi, wi);
|
||||
OutDataType v_out = output(g, n, k, ho, wo);
|
||||
|
||||
v_acc += ck_tile::type_convert<float>(v_out) *
|
||||
ck_tile::type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
|
||||
weight(g, k, c, y, x) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
weight.get_lengths()[0],
|
||||
weight.get_lengths()[1],
|
||||
weight.get_lengths()[2],
|
||||
weight.get_lengths()[3],
|
||||
weight.get_lengths()[4])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
auto func = [&](auto g, auto k, auto c, auto z, auto y, auto x) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t n = 0; n < output.get_lengths()[1]; ++n)
|
||||
{
|
||||
for(std::size_t do_ = 0; do_ < output.get_lengths()[3]; ++do_)
|
||||
{
|
||||
auto di = static_cast<ck_tile::long_index_t>(do_ * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
for(std::size_t ho = 0; ho < output.get_lengths()[4]; ++ho)
|
||||
{
|
||||
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
|
||||
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
|
||||
for(std::size_t wo = 0; wo < output.get_lengths()[5]; ++wo)
|
||||
{
|
||||
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[2]);
|
||||
if(di >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(di) < input.get_lengths()[3] &&
|
||||
hi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[4] &&
|
||||
wi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[5])
|
||||
{
|
||||
InDataType v_in = input(g, n, c, di, hi, wi);
|
||||
OutDataType v_out = output(g, n, k, do_, ho, wo);
|
||||
|
||||
v_acc += ck_tile::type_convert<float>(v_out) *
|
||||
ck_tile::type_convert<float>(v_in);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
|
||||
weight(g, k, c, z, y, x) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
weight.get_lengths()[0],
|
||||
weight.get_lengths()[1],
|
||||
weight.get_lengths()[2],
|
||||
weight.get_lengths()[3],
|
||||
weight.get_lengths()[4],
|
||||
weight.get_lengths()[5])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Ref_conv_bwd_weight: number of dimensions must be between 1 and 3.");
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -9,7 +9,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
|
||||
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
|
||||
static_cast<uint32_t>(((token_id_) & 0x00ffffff) | (((topk_id_) & 0xff) << 24))
|
||||
|
||||
template <typename WeightType, typename IndexType = index_t>
|
||||
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
|
||||
@@ -30,4 +30,82 @@ reference_reduce(const HostTensor<XDataType>& x_m_n, HostTensor<YDataType>& y_m,
|
||||
|
||||
make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
// Generic reference reduce for arbitrary dimensions
|
||||
template <
|
||||
typename XDataType,
|
||||
typename ComputeDataType,
|
||||
typename YDataType,
|
||||
typename ReduceOp,
|
||||
typename KeptDim, // Expected type: ck_tile::sequence<...> containing dimension indices to keep
|
||||
typename ReduceDims> // Expected type: ck_tile::sequence<...> containing dimension indices to
|
||||
// reduce
|
||||
CK_TILE_HOST void reference_reduce(const HostTensor<XDataType>& x_tensor,
|
||||
HostTensor<YDataType>& y_tensor,
|
||||
ReduceOp reduce_op,
|
||||
KeptDim kept_dim,
|
||||
ReduceDims reduce_dims)
|
||||
{
|
||||
const auto& x_lengths = x_tensor.mDesc.get_lengths();
|
||||
|
||||
// Calculate total kept elements (product of all kept dimension lengths)
|
||||
index_t total_kept_elements = 1;
|
||||
static_for<0, kept_dim.size(), 1>{}(
|
||||
[&](auto i) { total_kept_elements *= x_lengths[kept_dim.at(i)]; });
|
||||
|
||||
// Calculate total reduce elements (product of all reduce dimension lengths)
|
||||
index_t total_reduce_elements = 1;
|
||||
static_for<0, reduce_dims.size(), 1>{}(
|
||||
[&](auto i) { total_reduce_elements *= x_lengths[reduce_dims.at(i)]; });
|
||||
|
||||
auto f = [&](auto linear_kept_idx) {
|
||||
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
||||
|
||||
// Convert linear kept index to multi-dimensional kept indices
|
||||
std::vector<index_t> kept_indices(kept_dim.size());
|
||||
index_t temp_kept = linear_kept_idx;
|
||||
static_for<0, kept_dim.size(), 1>{}([&](auto i) {
|
||||
constexpr auto dim_idx = kept_dim.size() - 1 - i;
|
||||
constexpr auto dim = kept_dim.at(dim_idx);
|
||||
const auto len = x_lengths[dim];
|
||||
kept_indices[dim_idx] = temp_kept % len;
|
||||
temp_kept /= len;
|
||||
});
|
||||
|
||||
for(index_t reduce_idx = 0; reduce_idx < total_reduce_elements; ++reduce_idx)
|
||||
{
|
||||
// Convert linear reduce index to multi-dimensional reduce indices
|
||||
std::vector<index_t> reduce_indices(reduce_dims.size());
|
||||
index_t temp_reduce = reduce_idx;
|
||||
static_for<0, reduce_dims.size(), 1>{}([&](auto i) {
|
||||
constexpr auto dim_idx = reduce_dims.size() - 1 - i;
|
||||
constexpr auto dim = reduce_dims.at(dim_idx);
|
||||
const auto len = x_lengths[dim];
|
||||
reduce_indices[dim_idx] = temp_reduce % len;
|
||||
temp_reduce /= len;
|
||||
});
|
||||
|
||||
// Build full input tensor indices by combining kept and reduce indices
|
||||
std::vector<std::size_t> full_indices(x_lengths.size(), 0);
|
||||
static_for<0, kept_dim.size(), 1>{}(
|
||||
[&](auto i) { full_indices[kept_dim.at(i)] = kept_indices[i]; });
|
||||
static_for<0, reduce_dims.size(), 1>{}(
|
||||
[&](auto i) { full_indices[reduce_dims.at(i)] = reduce_indices[i]; });
|
||||
|
||||
// Access input tensor element
|
||||
const auto v_a = type_convert<ComputeDataType>(x_tensor(full_indices));
|
||||
|
||||
v_acc = reduce_op(v_acc, v_a);
|
||||
}
|
||||
|
||||
// Calculate output tensor index using kept indices
|
||||
// The output tensor has the same structure as the kept dimensions
|
||||
std::vector<std::size_t> y_indices(kept_dim.size());
|
||||
static_for<0, kept_dim.size(), 1>{}([&](auto i) { y_indices[i] = kept_indices[i]; });
|
||||
|
||||
y_tensor(y_indices) = type_convert<YDataType>(v_acc);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, total_kept_elements)(std::thread::hardware_concurrency());
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -14,7 +14,7 @@ CK_TILE_HOST void
|
||||
reference_softmax(const HostTensor<InputType>& x, HostTensor<OutputType>& y, index_t dim = -1)
|
||||
{
|
||||
index_t rank = x.get_num_of_dimension();
|
||||
assert(rank == y.get_num_of_dimension());
|
||||
assert(static_cast<std::size_t>(rank) == y.get_num_of_dimension());
|
||||
assert(dim == -1 || dim < rank);
|
||||
|
||||
index_t target_dim = dim == -1 ? (rank - 1) : dim;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -38,8 +38,8 @@ CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
|
||||
{
|
||||
// rank must be the same
|
||||
index_t rank = x.get_num_of_dimension();
|
||||
assert(rank == y_values.get_num_of_dimension());
|
||||
assert(rank == y_indices.get_num_of_dimension());
|
||||
assert(static_cast<std::size_t>(rank) == y_values.get_num_of_dimension());
|
||||
assert(static_cast<size_t>(rank) == y_indices.get_num_of_dimension());
|
||||
assert(dim == -1 || dim < rank);
|
||||
|
||||
index_t topk_dim = dim == -1 ? (rank - 1) : dim;
|
||||
@@ -47,7 +47,8 @@ CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
|
||||
auto x_len = x.get_lengths();
|
||||
|
||||
assert(k <= topk_src_len);
|
||||
assert(k == y_values.get_length(topk_dim) && k == y_indices.get_length(topk_dim));
|
||||
assert(static_cast<size_t>(k) == y_values.get_length(topk_dim) &&
|
||||
static_cast<size_t>(k) == y_indices.get_length(topk_dim));
|
||||
|
||||
index_t n_parallel = x.get_element_size() / topk_src_len;
|
||||
|
||||
|
||||
33
include/ck_tile/host/reference/reference_transpose.hpp
Normal file
33
include/ck_tile/host/reference/reference_transpose.hpp
Normal file
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType, typename BDataType>
|
||||
void reference_transpose_elementwise(const HostTensor<ADataType>& a, HostTensor<BDataType>& b)
|
||||
{
|
||||
ck_tile::index_t M = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[0]);
|
||||
ck_tile::index_t N = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[1]);
|
||||
|
||||
// Ensure the b tensor is sized correctly for N x M
|
||||
if(static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[0]) != N ||
|
||||
static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[1]) != M)
|
||||
{
|
||||
throw std::runtime_error("Output tensor b has incorrect dimensions for transpose.");
|
||||
}
|
||||
|
||||
auto f = [&](auto i, auto j) {
|
||||
auto v_a = a(i, j);
|
||||
b(j, i) = ck_tile::type_convert<BDataType>(v_a);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user