mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
Merge remote-tracking branch 'origin/ginolu/add_wgmfma_dispatcher' into mtgu/cktile_mxfp4_flatmm_dev
This commit is contained in:
102
include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp
Normal file
102
include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp
Normal file
@@ -0,0 +1,102 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace element_wise {
|
||||
|
||||
struct Add
|
||||
{
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<double>(double& y, const double& x0, const double& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
y = x0 + type_convert<half_t>(x1);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const float& x0, const float& x1) const
|
||||
{
|
||||
y = type_convert<half_t>(x0 + x1);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
y = type_convert<half_t>(x0) + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const bf16_t& x1) const
|
||||
{
|
||||
const float x1_tmp = type_convert<float>(x1);
|
||||
y = x0 + x1_tmp;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bf16_t>(bf16_t& y, const bf16_t& x0, const bf16_t& x1) const
|
||||
{
|
||||
const float x1_tmp = type_convert<float>(x0);
|
||||
const float x2_tmp = type_convert<float>(x1);
|
||||
const float y_tmp = x1_tmp + x2_tmp;
|
||||
y = type_convert<bf16_t>(y_tmp);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bf16_t>(bf16_t& y, const float& x0, const bf16_t& x1) const
|
||||
{
|
||||
const float x2_tmp = type_convert<float>(x1);
|
||||
const float y_tmp = x0 + x2_tmp;
|
||||
y = type_convert<bf16_t>(y_tmp);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bf16_t>(bf16_t& y, const float& x0, const float& x1) const
|
||||
{
|
||||
const float y_tmp = x0 + x1;
|
||||
y = type_convert<bf16_t>(y_tmp);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace ck_tile
|
||||
125
include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp
Normal file
125
include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp
Normal file
@@ -0,0 +1,125 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct ElementWiseKernel
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using ElementWiseOperation = ck_tile::remove_cvref_t<typename Problem::ElementWiseOperation>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::kBlockSize;
|
||||
|
||||
template <typename... XDataType, typename Dims>
|
||||
CK_TILE_DEVICE void operator()(const Dims lens,
|
||||
const Dims input_strides,
|
||||
const Dims output_strides,
|
||||
const tuple<XDataType...>& input_tensors,
|
||||
YDataType* p_y) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Setup block-level coordinates and transforms
|
||||
const index_t iM = get_block_id() * S::kBlockM;
|
||||
const auto merge_transform = make_merge_transform(lens);
|
||||
|
||||
// Load all input tiles into registers.
|
||||
// The lambda structure here is intended to minimize the lifetime
|
||||
// of intermediate objects (views, windows) used for loading.
|
||||
const auto x_tiles = ck_tile::generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
input_tensors.get(i), lens, input_strides, number<S::kVectorM>{}, number<1>{});
|
||||
|
||||
const auto transformed_tensor = pad_tensor_view(
|
||||
transform_tensor_view(tensor_view,
|
||||
ck_tile::make_tuple(merge_transform),
|
||||
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
|
||||
ck_tile::make_tuple(sequence<0>{})),
|
||||
ck_tile::make_tuple(number<S::kBlockM>{}),
|
||||
sequence<Problem::kPad>{});
|
||||
|
||||
const auto x_window =
|
||||
make_tile_window(transformed_tensor,
|
||||
ck_tile::make_tuple(number<S::kBlockM>{}),
|
||||
{iM},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
return load_tile(x_window);
|
||||
},
|
||||
number<sizeof...(XDataType)>{});
|
||||
|
||||
// Setup output tile in registers.
|
||||
const auto& x_tile0 = x_tiles.get(number<0>{});
|
||||
auto y_tile = make_static_distributed_tensor<YDataType>(x_tile0.get_tile_distribution());
|
||||
|
||||
// Perform element-wise computation.
|
||||
const auto spans = x_tile0.get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx) {
|
||||
const auto tile_idx = make_tuple(idx);
|
||||
apply(
|
||||
[&](auto&&... tiles) {
|
||||
ElementWiseOperation{}(y_tile(tile_idx),
|
||||
type_convert<ComputeDataType>(tiles[tile_idx])...);
|
||||
},
|
||||
x_tiles);
|
||||
});
|
||||
|
||||
// Setup output window and store the result tile.
|
||||
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_y, lens, output_strides, number<S::kVectorM>{});
|
||||
|
||||
const auto transformed_y_m_n = pad_tensor_view(
|
||||
transform_tensor_view(y_m_n,
|
||||
ck_tile::make_tuple(merge_transform),
|
||||
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
|
||||
ck_tile::make_tuple(sequence<0>{})),
|
||||
ck_tile::make_tuple(number<S::kBlockM>{}),
|
||||
sequence<Problem::kPad>{});
|
||||
|
||||
auto y_window = make_tile_window(transformed_y_m_n,
|
||||
make_tuple(number<S::kBlockM>{}),
|
||||
{iM},
|
||||
y_tile.get_tile_distribution());
|
||||
|
||||
store_tile(y_window, cast_tile<YDataType>(y_tile));
|
||||
}
|
||||
|
||||
template <typename... Ints>
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const ck_tile::tuple<Ints...>& input_sizes)
|
||||
{
|
||||
int total_elements = 1;
|
||||
const auto kVectorM = Problem_::BlockShape::kVectorM;
|
||||
|
||||
apply([&](auto&&... args) { ((total_elements *= args), ...); }, input_sizes);
|
||||
|
||||
if((total_elements % kVectorM) != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met: total number of input elements (",
|
||||
total_elements,
|
||||
") should be multiple of the vectorization size (",
|
||||
kVectorM,
|
||||
")");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,29 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
struct ElementWiseDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>, // Replicate
|
||||
tuple<sequence<S::kRepeatM,
|
||||
S::kWarpPerBlockM,
|
||||
S::kThreadPerWarpM,
|
||||
S::kVectorM>>, // Hierarchical
|
||||
tuple<sequence<1>, sequence<1>>, // Parallel
|
||||
tuple<sequence<1>, sequence<2>>, // Parallel
|
||||
sequence<1, 1>, // Yield
|
||||
sequence<0, 3>>{} // Yield
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename BlockShape_,
|
||||
typename ElementWiseOperation_,
|
||||
bool kPad_ = true>
|
||||
struct ElementWisePipelineProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using ElementWiseOperation = remove_cvref_t<ElementWiseOperation_>;
|
||||
static constexpr bool kPad = kPad_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockWarps, typename BlockTile, typename WarpTile, typename ComputeDataType>
|
||||
struct ElementWiseShape
|
||||
{
|
||||
static constexpr index_t kBlockM = BlockTile::at(number<0>{});
|
||||
|
||||
static constexpr index_t kWarpM = WarpTile::at(number<0>{});
|
||||
|
||||
static constexpr index_t kVectorM =
|
||||
min(static_cast<index_t>(16 / sizeof(ComputeDataType)), kWarpM / get_warp_size());
|
||||
|
||||
static constexpr index_t kWarpPerBlockM = BlockWarps::at(number<0>{});
|
||||
|
||||
static constexpr index_t kThreadPerWarpM = get_warp_size();
|
||||
|
||||
static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kVectorM * kThreadPerWarpM);
|
||||
|
||||
static constexpr index_t kBlockSize =
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -110,6 +110,86 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q)
|
||||
return res;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a)
|
||||
{
|
||||
uint32_t src = static_cast<uint32_t>(a), src_hi;
|
||||
uint32_t fp8x4_lo, fp8x4_hi;
|
||||
float tmp_0, tmp_1;
|
||||
|
||||
asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n"
|
||||
"v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n"
|
||||
"v_cvt_pk_fp8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n"
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n"
|
||||
"v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n"
|
||||
"v_cvt_pk_fp8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n"
|
||||
: [v_tmp_0] "+v"(tmp_0),
|
||||
[v_tmp_1] "+v"(tmp_1),
|
||||
[v_hi_src] "+v"(src_hi),
|
||||
[v_dst_lo] "+v"(fp8x4_lo),
|
||||
[v_dst_hi] "+v"(fp8x4_hi),
|
||||
[v_src] "+v"(src)
|
||||
:);
|
||||
|
||||
return bit_cast<fp8x8_t>(((static_cast<uint64_t>(fp8x4_hi) << 32) | fp8x4_lo));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
|
||||
{
|
||||
float res;
|
||||
asm volatile("v_cvt_f32_fp8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src));
|
||||
return res;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
|
||||
{
|
||||
float res;
|
||||
asm volatile("v_cvt_f32_bf8 %0, %1, src0_sel:BYTE_0" : "=v"(res) : "v"(src));
|
||||
return res;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(int a)
|
||||
{
|
||||
uint32_t src = static_cast<uint32_t>(a), src_hi;
|
||||
uint32_t bf8x4_lo, bf8x4_hi;
|
||||
float tmp_0, tmp_1;
|
||||
|
||||
asm volatile("v_lshrrev_b32 %[v_hi_src], 4, %[v_src]\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_3\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_3\n"
|
||||
"v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_2\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_2\n"
|
||||
"v_cvt_pk_bf8_f32 %[v_dst_hi], %[v_tmp_1], %[v_tmp_0]\n"
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src], src0_sel:BYTE_1\n"
|
||||
"v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0], op_sel:[0, 0, 1]\n"
|
||||
|
||||
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
|
||||
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_hi_src]\n"
|
||||
"v_cvt_pk_bf8_f32 %[v_dst_lo], %[v_tmp_1], %[v_tmp_0]\n"
|
||||
: [v_tmp_0] "+v"(tmp_0),
|
||||
[v_tmp_1] "+v"(tmp_1),
|
||||
[v_hi_src] "+v"(src_hi),
|
||||
[v_dst_lo] "+v"(bf8x4_lo),
|
||||
[v_dst_hi] "+v"(bf8x4_hi),
|
||||
[v_src] "+v"(src)
|
||||
:);
|
||||
|
||||
return bit_cast<bf8x8_t>(((static_cast<uint64_t>(bf8x4_hi) << 32) | bf8x4_lo));
|
||||
}
|
||||
|
||||
struct PassThroughPack8
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
@@ -126,6 +206,16 @@ struct PassThroughPack8
|
||||
y.lo = i4_to_bhalf4(bit_cast<int>(x));
|
||||
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(fp8x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
y = amd_assembly_i4_to_fp8x8(bit_cast<int>(x));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(bf8x8_t& y, const pk_int4x4_t& x) const
|
||||
{
|
||||
y = amd_assembly_i4_to_bf8x8(bit_cast<int>(x));
|
||||
}
|
||||
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
@@ -172,223 +262,70 @@ struct PassThroughPack2
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
template <class T>
|
||||
using raw_t = std::remove_cv_t<std::remove_reference_t<T>>;
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
|
||||
template <class Y, class X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y&& y, const X& x) const
|
||||
{
|
||||
y = x;
|
||||
/* Only do the assignment when
|
||||
- y is an *l-value* and
|
||||
- y is *not* const */
|
||||
if constexpr(std::is_lvalue_reference_v<Y&&> && !std::is_const_v<raw_t<Y>>)
|
||||
{
|
||||
y = ck_tile::type_convert<raw_t<Y>>(x);
|
||||
}
|
||||
/* otherwise (r-value or const) → do nothing */
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
// Suppress unused parameter warning for ds
|
||||
((void)ds, ...);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const
|
||||
{
|
||||
y = type_convert<double>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y,
|
||||
const ck_tile::bf16_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
|
||||
const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y,
|
||||
const int8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y,
|
||||
const int8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
|
||||
{
|
||||
y = type_convert<int32_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const
|
||||
{
|
||||
y = type_convert<int4_t>(x);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp8_t, ck_tile::fp8_t>(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y,
|
||||
const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::fp8_t>(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp8_t, ck_tile::fp16_t>(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf8_t, ck_tile::bf8_t>(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y,
|
||||
const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::bf8_t>(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf8_t, ck_tile::fp16_t>(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = ck_tile::type_convert<ck_tile::bf8_t>(x);
|
||||
// Just assign e with c
|
||||
if constexpr(std::is_same_v<E, C>)
|
||||
{
|
||||
e = c;
|
||||
}
|
||||
else
|
||||
{
|
||||
e = ck_tile::type_convert<E>(c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct MultiDMultiply
|
||||
{
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
// Start with the base value c
|
||||
float result = ck_tile::type_convert<float>(c);
|
||||
|
||||
// Multiply by each D parameter using fold expression
|
||||
((result *= ck_tile::type_convert<float>(ds)), ...);
|
||||
|
||||
e = ck_tile::type_convert<E>(result);
|
||||
}
|
||||
};
|
||||
|
||||
struct MultiDAdd
|
||||
{
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
// Start with the base value c
|
||||
float result = ck_tile::type_convert<float>(c);
|
||||
|
||||
// Add by each D parameter using fold expression
|
||||
((result += ck_tile::type_convert<float>(ds)), ...);
|
||||
|
||||
e = ck_tile::type_convert<E>(result);
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
struct UnaryConvert
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
@@ -398,6 +335,7 @@ struct UnaryConvert
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
struct ConvertBF16RTN
|
||||
{
|
||||
// convert to bf16 using round to nearest (rtn)
|
||||
@@ -534,14 +472,14 @@ struct UnaryDivide
|
||||
|
||||
struct UnarySquare
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
|
||||
{
|
||||
static_assert(std::is_same_v<T, float> || std::is_same_v<T, ck_tile::fp16_t> ||
|
||||
std::is_same_v<T, double> || std::is_same_v<T, int32_t> ||
|
||||
std::is_same_v<T, int8_t>
|
||||
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t> ||
|
||||
std::is_same_v<X, double> || std::is_same_v<X, int32_t> ||
|
||||
std::is_same_v<X, int8_t>
|
||||
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|| std::is_same_v<T, int4_t>
|
||||
|| std::is_same_v<X, int4_t>
|
||||
#endif
|
||||
,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
Reference in New Issue
Block a user