mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
fix and merge
This commit is contained in:
@@ -9,6 +9,8 @@
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#include "ck_tile/core/arch/amd_transpose_load_encoding.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
@@ -38,6 +40,7 @@
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/tensor/buffer_view.hpp"
|
||||
#include "ck_tile/core/tensor/load_tile.hpp"
|
||||
#include "ck_tile/core/tensor/load_tile_transpose.hpp"
|
||||
#include "ck_tile/core/tensor/null_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/null_tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/shuffle_tile.hpp"
|
||||
@@ -55,6 +58,7 @@
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_utils.hpp"
|
||||
#include "ck_tile/core/tensor/transpose_tile.hpp"
|
||||
|
||||
@@ -56,19 +56,24 @@ template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
tile_distribution_pattern DistributionPattern>
|
||||
tile_distribution_pattern DistributionPattern,
|
||||
index_t NumWaveGroups = 1>
|
||||
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
|
||||
{
|
||||
};
|
||||
|
||||
// Thread raked
|
||||
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::thread_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
tile_distribution_pattern::thread_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
{
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
@@ -83,45 +88,76 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
static constexpr index_t Y1 = warp_size / X0;
|
||||
static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
|
||||
|
||||
static constexpr index_t Y0 = num_warps;
|
||||
static constexpr index_t Y0 = num_warps / NumWaveGroups;
|
||||
// YPerWarp = YPerTile / Y0;
|
||||
// Y2 = YPerWarp / Y1;
|
||||
static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
|
||||
|
||||
static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * warp_ys * Y0 must cover whole workgroup!");
|
||||
static_assert(X0 * Y1 * Y0 * NumWaveGroups == BlockSize,
|
||||
"X0 * warp_ys * Y0 must cover whole workgroup!");
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
if constexpr(NumWaveGroups != 1)
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Y0>,
|
||||
tuple<sequence<Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<0>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{});
|
||||
if constexpr(NumWaveGroups != 1)
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Y0>,
|
||||
tuple<sequence<X0, X1>, sequence<Y1, Y2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Warp raked
|
||||
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::warp_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
tile_distribution_pattern::warp_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
{
|
||||
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
@@ -164,13 +200,17 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
};
|
||||
|
||||
// Block raked
|
||||
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::block_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
tile_distribution_pattern::block_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
{
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
|
||||
@@ -1438,8 +1438,10 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
static_assert(
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, fp16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, bf16_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, int32_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
@@ -1562,6 +1564,54 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
thread_buffer<float, 8> tmp;
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
else if constexpr(N == 32)
|
||||
{
|
||||
thread_buffer<float, 16> tmp;
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<2>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 8 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<3>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 12 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf16_t>::value) // bf16
|
||||
{
|
||||
@@ -1598,6 +1648,54 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
thread_buffer<float, 8> tmp;
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
else if constexpr(N == 32)
|
||||
{
|
||||
thread_buffer<float, 16> tmp;
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<2>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 8 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<3>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 12 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
}
|
||||
else // other datatype
|
||||
{
|
||||
@@ -2698,6 +2796,44 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
{
|
||||
|
||||
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
|
||||
"We need to have the compatible compiler version to build this instruction");
|
||||
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
|
||||
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
|
||||
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
|
||||
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
|
||||
reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
|
||||
reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "not implemented");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif // !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
|
||||
2597
include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Normal file
2597
include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Normal file
File diff suppressed because it is too large
Load Diff
86
include/ck_tile/core/arch/amd_transpose_load_encoding.hpp
Normal file
86
include/ck_tile/core/arch/amd_transpose_load_encoding.hpp
Normal file
@@ -0,0 +1,86 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Generates the wave-level tile distribution encoding.
// Only the primary template is declared here; the second (defaulted) parameter is the
// SFINAE slot that the specializations use to select on sizeof(T).
|
||||
template <typename T, typename = void>
|
||||
struct LaneGroupTransposeTraits;
|
||||
|
||||
template <typename T>
|
||||
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 2>>
|
||||
{
|
||||
// before transpose, 4x16
|
||||
static constexpr index_t ksecondDim = 4;
|
||||
static constexpr index_t kleadDim = 16;
|
||||
// after transpose, 16x4
|
||||
static constexpr index_t ksecondDimT = 16;
|
||||
static constexpr index_t kleadDimT = 4;
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 4>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, 4, 4>>,
|
||||
tuple<sequence<1, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 3>>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LaneGroupTransposeTraits<T, std::enable_if_t<sizeof(T) == 1>>
|
||||
{
|
||||
static constexpr index_t ksecondDim = 8;
|
||||
static constexpr index_t kleadDim = 16;
|
||||
|
||||
static constexpr index_t ksecondDimT = 16;
|
||||
static constexpr index_t kleadDimT = 8;
|
||||
|
||||
template <index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
using TileDistribution =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kOuterDistDim0, kOuterDistDim1, 8>,
|
||||
sequence<kInnerDistDim0, kInnerDistDim1, 2, 8>>,
|
||||
tuple<sequence<1, 2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 2, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<1, 1, 3>>;
|
||||
};
|
||||
|
||||
/*
|
||||
* @brief This function is used to generate the transposed distribution encoding
|
||||
* for the given data type and distribution dimensions.
|
||||
*
|
||||
* @tparam T The data type of the elements in the tensor.
|
||||
* @tparam kOuterDistDim0 The outer distribution dimension 0, which is outer dimension for stride.
|
||||
* @tparam kOuterDistDim1 The outer distribution dimension 1, which is inner dimension for stride.
|
||||
* @tparam kInnerDistDim0 The inner distribution dimension 0, which is outer dimension for
|
||||
* consecutive.
|
||||
* @tparam kInnerDistDim1 The inner distribution dimension 1, which is inner dimension for
|
||||
* consecutive.
|
||||
*/
|
||||
template <typename T,
|
||||
index_t kOuterDistDim0,
|
||||
index_t kOuterDistDim1,
|
||||
index_t kInnerDistDim0,
|
||||
index_t kInnerDistDim1>
|
||||
CK_TILE_DEVICE constexpr auto make_transposed_distr_encode()
|
||||
{
|
||||
using xdllevel_dstr_encoding = typename LaneGroupTransposeTraits<T>::
|
||||
template TileDistribution<kOuterDistDim0, kOuterDistDim1, kInnerDistDim0, kInnerDistDim1>;
|
||||
return xdllevel_dstr_encoding{};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -50,8 +50,11 @@ enum struct memory_operation_enum : std::uint16_t
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
|
||||
{
|
||||
// warpSize is defined by HIP
|
||||
return warpSize;
|
||||
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
|
||||
return 64;
|
||||
#else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }
|
||||
|
||||
@@ -35,7 +35,7 @@ CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta)
|
||||
#elif 1
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const uint32_t wrap_around_lane_delta = warpSize - lane_delta;
|
||||
const uint32_t wrap_around_lane_delta = get_warp_size() - lane_delta;
|
||||
|
||||
const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute(
|
||||
(__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast<int32_t>(v_local));
|
||||
|
||||
@@ -223,6 +223,10 @@
|
||||
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
#define CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
|
||||
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
|
||||
#endif
|
||||
@@ -236,17 +240,17 @@
|
||||
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
|
||||
#endif
|
||||
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#ifdef CK_TILE_USE_OCP_FP8
|
||||
#ifndef CK_TILE_USE_OCP_FP8
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
#if defined(__gfx950__) || defined(__gfx12__)
|
||||
#define CK_TILE_USE_OCP_FP8 1
|
||||
#else
|
||||
#define CK_TILE_USE_OCP_FP8 0
|
||||
#endif
|
||||
#elif defined(__gfx950__) || defined(__gfx12__) // for GPU code
|
||||
#define CK_TILE_USE_OCP_FP8 1
|
||||
#else // for GPU code
|
||||
#else
|
||||
#define CK_TILE_USE_OCP_FP8 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
#if __clang_major__ == 20
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using index_t = int32_t;
|
||||
using int32_t = int32_t;
|
||||
using long_index_t = int64_t;
|
||||
using int8_t = int8_t;
|
||||
|
||||
|
||||
@@ -5,7 +5,11 @@
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#if __clang_major__ == 20
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#else
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#endif
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
@@ -14,6 +18,7 @@
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -129,6 +134,28 @@ struct buffer_view<address_space_enum::generic,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
In the generic address space, we do not support the transpose instruction in the buffer view.
|
||||
Will report compilation error when developer wants to use it.
|
||||
*/
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto transpose_get(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
static_assert(false, "Error: transpose load not supported in global memory space.");
|
||||
ignore = i;
|
||||
ignore = linear_offset;
|
||||
ignore = is_valid_element;
|
||||
return;
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
@@ -355,6 +382,28 @@ struct buffer_view<address_space_enum::global,
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
In the global memory address space, we do not support the transpose instruction in the buffer
|
||||
view. Will report compilation error when developer wants to use it.
|
||||
*/
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto transpose_get(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
static_assert(false, "Error: transpose load not supported in global memory space.");
|
||||
ignore = i;
|
||||
ignore = linear_offset;
|
||||
ignore = is_valid_element;
|
||||
return;
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
@@ -849,6 +898,47 @@ struct buffer_view<address_space_enum::lds,
|
||||
smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto transpose_get([[maybe_unused]] index_t i,
|
||||
[[maybe_unused]] index_t linear_offset,
|
||||
bool is_valid_element) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
constexpr address_space_enum addr_space = get_address_space();
|
||||
return amd_transpose_load_to_vgpr<remove_cvref_t<T>, t_per_x, addr_space>(
|
||||
p_data_ + i + linear_offset);
|
||||
#else
|
||||
return X{numeric<remove_cvref_t<T>>::zero()};
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return X{numeric<remove_cvref_t<T>>::zero()};
|
||||
}
|
||||
else
|
||||
{
|
||||
return X{invalid_element_value_};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
@@ -920,6 +1010,15 @@ struct buffer_view<address_space_enum::lds,
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
|
||||
// int8 on thread buffer
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
|
||||
// ext_vector_type for pk_int4 must use int8_t as type
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
|
||||
@@ -942,6 +1041,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
|
||||
if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>))
|
||||
{
|
||||
@@ -952,6 +1053,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>))
|
||||
{
|
||||
@@ -962,6 +1065,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>))
|
||||
{
|
||||
@@ -972,6 +1077,8 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
|
||||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
|
||||
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>))
|
||||
{
|
||||
|
||||
362
include/ck_tile/core/tensor/load_tile_transpose.hpp
Normal file
362
include/ck_tile/core/tensor/load_tile_transpose.hpp
Normal file
@@ -0,0 +1,362 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace util {
|
||||
template <typename Suffix, typename Sequence>
|
||||
struct is_sequence_suffix
|
||||
{
|
||||
static constexpr bool size_check = (Suffix::size() <= Sequence::size());
|
||||
|
||||
static constexpr index_t start_pos = Sequence::size() - Suffix::size();
|
||||
using extract_indices = typename arithmetic_sequence_gen<start_pos, Sequence::size(), 1>::type;
|
||||
|
||||
static constexpr bool value =
|
||||
size_check && (Suffix{} == decltype(Sequence::extract(extract_indices{})){});
|
||||
};
|
||||
|
||||
template <index_t... Xs>
|
||||
struct is_sequence_suffix<sequence<>, sequence<Xs...>>
|
||||
{
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <typename Suffix, typename Sequence>
|
||||
constexpr bool is_sequence_suffix_v = is_sequence_suffix<Suffix, Sequence>::value;
|
||||
|
||||
} // namespace util
|
||||
|
||||
// Default policy: Retains original 2D transpose behavior
|
||||
template <typename DataType>
|
||||
struct DefaultTranspose
|
||||
{
|
||||
struct Quad16
|
||||
{
|
||||
using InputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4>, sequence<4, 4>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using OutputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<16>, sequence<4>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
|
||||
struct Quad8
|
||||
{
|
||||
using InputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<8>, sequence<2, 8>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
|
||||
using OutputEncoding = tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<16>, sequence<8>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
|
||||
// Select based on data size
|
||||
using QuadInputEncoding = std::conditional_t<sizeof(DataType) == 2,
|
||||
typename Quad16::InputEncoding,
|
||||
typename Quad8::InputEncoding>;
|
||||
|
||||
using QuadOutputEncoding = std::conditional_t<sizeof(DataType) == 2,
|
||||
typename Quad16::OutputEncoding,
|
||||
typename Quad8::OutputEncoding>;
|
||||
|
||||
// Always swap last two dimensions
|
||||
static constexpr auto transpose_dims = sequence<1, 0>{};
|
||||
|
||||
// Programmable: Element grouping function
|
||||
static constexpr auto group_func = [](auto idx) {
|
||||
return idx; // Identity mapping
|
||||
};
|
||||
|
||||
template <typename InDstrEncode>
|
||||
struct ValidationTraits
|
||||
{
|
||||
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
|
||||
static constexpr auto quad_hs_lengthss = QuadInputEncoding::hs_lengthss_;
|
||||
// 1. Must be 2D tensor
|
||||
static constexpr bool dims_valid = (InDstrEncode::NDimX == 2);
|
||||
// 2. Quad pattern must be suffix of input pattern
|
||||
static constexpr bool suffix_valid_dim0 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<0>()),
|
||||
decltype(input_hs_lengthss.template get<0>())>;
|
||||
static constexpr bool suffix_valid_dim1 =
|
||||
util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<1>()),
|
||||
decltype(input_hs_lengthss.template get<1>())>;
|
||||
|
||||
// 3. PS→RHS mapping constraints
|
||||
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
|
||||
static constexpr index_t ndimp_outer = input_ps_to_rhss_major.size() - 1;
|
||||
static constexpr index_t ndimp_inner =
|
||||
input_ps_to_rhss_major[number<ndimp_outer>{}].size() - 1;
|
||||
|
||||
static constexpr bool ps_mapping_valid =
|
||||
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner>{}] == 2) &&
|
||||
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner>{}] ==
|
||||
input_hs_lengthss[number<1>{}].size() - 2) &&
|
||||
(input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] == 1) &&
|
||||
(input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] ==
|
||||
input_hs_lengthss[number<0>{}].size() - 1);
|
||||
|
||||
// 4. YS→RHS mapping constraints
|
||||
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
|
||||
static constexpr bool ys_mapping_valid =
|
||||
(input_ys_to_rhs_major.back() == 2) &&
|
||||
(input_ys_to_rhs_minor.back() == input_hs_lengthss[number<1>{}].size() - 1) &&
|
||||
(input_ys_to_rhs_major[input_ys_to_rhs_major.size() - 2] == 1) &&
|
||||
(input_ys_to_rhs_minor[input_ys_to_rhs_minor.size() - 2] ==
|
||||
input_hs_lengthss[number<0>{}].size() - 2);
|
||||
|
||||
static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
|
||||
ps_mapping_valid && ys_mapping_valid;
|
||||
};
|
||||
};
|
||||
template <typename TileDistribution_, typename DataType_, typename Policy>
|
||||
struct TransposeTileDistrChecker
|
||||
{
|
||||
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
|
||||
|
||||
using Validator = typename Policy::template ValidationTraits<InDstrEncode>;
|
||||
|
||||
static constexpr bool distr_encoding_valid = Validator::value;
|
||||
};
|
||||
|
||||
// this is used to generate the transposed output tile distribution encoding
|
||||
// based on the input tile distribution encoding
|
||||
template <typename TileDistribution_,
|
||||
typename DataType_,
|
||||
typename Policy = DefaultTranspose<DataType_>>
|
||||
struct OutputTileDistributionTraits
|
||||
{
|
||||
using InDstrEncode = typename remove_cvref_t<TileDistribution_>::DstrEncode;
|
||||
static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
|
||||
static constexpr auto quad_input_hs_lengthss = Policy::QuadInputEncoding::hs_lengthss_;
|
||||
static constexpr auto quad_output_hs_lengthss = Policy::QuadOutputEncoding::hs_lengthss_;
|
||||
|
||||
static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
|
||||
static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
|
||||
static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
|
||||
static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
|
||||
|
||||
static constexpr auto quad_ps_to_rhss_major = Policy::QuadInputEncoding::ps_to_rhss_major_;
|
||||
static constexpr auto quad_ps_to_rhss_minor = Policy::QuadInputEncoding::ps_to_rhss_minor_;
|
||||
|
||||
// for transpose load
|
||||
// append the reversed quad output hs lengths to the input hs lengthss after removing
|
||||
// the quad_input_hs_lengthss
|
||||
// then reverse the whole sequence to get the dst_out_hs_lengthss
|
||||
static constexpr auto reversed_quad_output_hs_lengthss = tuple_reverse(quad_output_hs_lengthss);
|
||||
|
||||
static constexpr auto full_out_hs_lengthss = generate_tuple(
|
||||
[](auto i) {
|
||||
return input_hs_lengthss[i]
|
||||
.extract(typename arithmetic_sequence_gen<0,
|
||||
input_hs_lengthss[i].size() -
|
||||
quad_input_hs_lengthss[i].size(),
|
||||
1>::type{})
|
||||
.push_back(reversed_quad_output_hs_lengthss[i]);
|
||||
},
|
||||
number<InDstrEncode::NDimX>{});
|
||||
|
||||
static constexpr auto dst_out_hs_lengthss = tuple_reverse(full_out_hs_lengthss);
|
||||
|
||||
// for PS→RHS mapping(both major and minor), we need to modify the last element of the major
|
||||
// sequence
|
||||
static constexpr auto modified_ps_to_rhss_major = generate_tuple(
|
||||
[](auto i) {
|
||||
if constexpr(i == input_ps_to_rhss_major.size() - 1)
|
||||
{
|
||||
constexpr auto current_size = input_ps_to_rhss_major[i].size();
|
||||
constexpr auto reduce_size = quad_ps_to_rhss_major[number<0>{}].size();
|
||||
constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract(
|
||||
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
|
||||
return reduced_ps_to_rhss_major.push_back(number<2>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// For all other sequences, keep them unchanged
|
||||
return input_ps_to_rhss_major[i];
|
||||
}
|
||||
},
|
||||
number<input_ps_to_rhss_major.size()>{});
|
||||
|
||||
static constexpr auto minor_last_index =
|
||||
full_out_hs_lengthss[number<InDstrEncode::NDimX - 1>{}].size() - 1;
|
||||
static constexpr auto major_last_index = full_out_hs_lengthss[number<0>{}].size() - 1;
|
||||
|
||||
static constexpr auto dst_ps_to_rhss_minor = generate_tuple(
|
||||
[](auto i) {
|
||||
if constexpr(i == input_ps_to_rhss_minor.size() - 1)
|
||||
{
|
||||
constexpr auto current_size = input_ps_to_rhss_minor[i].size();
|
||||
constexpr auto reduce_size = quad_ps_to_rhss_minor[number<0>{}].size();
|
||||
constexpr auto reduced_ps_to_rhss_minor = input_ps_to_rhss_minor[i].extract(
|
||||
typename arithmetic_sequence_gen<0, current_size - reduce_size, 1>::type{});
|
||||
return reduced_ps_to_rhss_minor.push_back(number<minor_last_index>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// For all other sequences, keep them unchanged
|
||||
return input_ps_to_rhss_minor[i];
|
||||
}
|
||||
},
|
||||
number<input_ps_to_rhss_minor.size()>{});
|
||||
|
||||
// for major because of dst_out_hs_lengthss is reversed, this index also need to be reversed
|
||||
static constexpr auto swap_one_and_two = [](const index_t idx) {
|
||||
return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
|
||||
};
|
||||
static constexpr auto dst_ps_to_rhss_major = generate_tuple(
|
||||
[](auto i) { return modified_ps_to_rhss_major[i].transform(swap_one_and_two); },
|
||||
number<modified_ps_to_rhss_major.size()>{});
|
||||
|
||||
static constexpr auto modified_input_ys_to_rhs_major =
|
||||
input_ys_to_rhs_major.pop_back().push_back(number<1>{});
|
||||
|
||||
static constexpr auto dst_ys_to_rhs_major = generate_sequence_v2(
|
||||
[](auto i) { return number<swap_one_and_two(modified_input_ys_to_rhs_major[i])>{}; },
|
||||
number<modified_input_ys_to_rhs_major.size()>{});
|
||||
|
||||
static constexpr auto dst_ys_to_rhs_minor =
|
||||
input_ys_to_rhs_minor.pop_back().push_back(number<major_last_index>{});
|
||||
|
||||
using OutDstrEncode = tile_distribution_encoding<typename InDstrEncode::RsLengths,
|
||||
remove_cvref_t<decltype(dst_out_hs_lengthss)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_major)>,
|
||||
remove_cvref_t<decltype(dst_ps_to_rhss_minor)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_major)>,
|
||||
remove_cvref_t<decltype(dst_ys_to_rhs_minor)>>;
|
||||
};
|
||||
|
||||
template <typename InnerEncode,
|
||||
index_t kLeadIterPerWarp,
|
||||
index_t kSecondIterPerWarp,
|
||||
index_t kLeadNumWarps,
|
||||
index_t kSecondNumWarps>
|
||||
CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
|
||||
{
|
||||
constexpr auto block_outer_dst_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<kSecondIterPerWarp, kSecondNumWarps>,
|
||||
sequence<kLeadIterPerWarp, kLeadNumWarps>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto blk_distr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(block_outer_dst_encoding, InnerEncode{});
|
||||
|
||||
return blk_distr_encode;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief transpose loads tile from a tensor and returns the resulting tensor with a new
|
||||
* (transposed) tile distribution. use SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* This function is intended for use with statically distributed tensor tiles, where the input
|
||||
* and output tile distributions differ due to the transpose operation. It ensures that the
|
||||
* element space size and vector length remain consistent between the input and output
|
||||
* distributions.
|
||||
*
|
||||
* @tparam BottomTensorView_ The type of the bottom tensor view.
|
||||
* @tparam WindowLengths_ The type representing the window lengths.
|
||||
* @tparam TileDistribution_ The type representing the tile distribution.
|
||||
* @tparam NumCoord The number of coordinates (dimensions).
|
||||
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
|
||||
* the last is SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* @param tile_window The tile window with static distribution to load and transpose.
|
||||
*
|
||||
* @return A statically distributed tensor containing the transposed tile data.
|
||||
*
|
||||
* @note
|
||||
* - The function uses compile-time checks to ensure the input and output tile distributions
|
||||
* are compatible in terms of element space size and vector length.
|
||||
* - The transpose operation is performed according to the specified Policy.
|
||||
*/
|
||||
template <
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
|
||||
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE auto
|
||||
load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window)
|
||||
{
|
||||
using OutTileDstrEncode =
|
||||
typename OutputTileDistributionTraits<TileDistribution_,
|
||||
typename BottomTensorView_::DataType>::OutDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
auto trans_tensor = tile_window.template load_transpose<Policy>();
|
||||
constexpr auto input_distr = TileDistribution_{};
|
||||
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
|
||||
|
||||
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
|
||||
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
|
||||
|
||||
constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y();
|
||||
constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y();
|
||||
|
||||
constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths());
|
||||
constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());
|
||||
|
||||
constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size();
|
||||
constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();
|
||||
static_assert(y_in_element_space_size == y_out_element_space_size,
|
||||
"the element space size is not the same!");
|
||||
static_assert(y_in_lengths[NDimYIn - 1] == y_out_lengths[NDimYOut - 1],
|
||||
"the vector length is not the same!");
|
||||
constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1];
|
||||
constexpr index_t num_of_access =
|
||||
reduce_on_sequence(y_in_lengths, multiplies{}, number<1>{}) / vecLoadSize;
|
||||
|
||||
using DataVec = array<typename BottomTensorView_::DataType, vecLoadSize>;
|
||||
static_for<0, num_of_access, 1>{}([&](auto iAccess) {
|
||||
out_tensor.get_thread_buffer().template set_as<DataVec>(
|
||||
number<iAccess>{},
|
||||
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
|
||||
});
|
||||
|
||||
return out_tensor;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -129,7 +129,10 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
// set output vectors
|
||||
static_for<0, num_vec_out, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_out_tmp = generate_array(
|
||||
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_in ? static_cast<index_t>(idx_y_start[ii]) + i
|
||||
: static_cast<index_t>(idx_y_start[ii]);
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr auto idx_y_out =
|
||||
|
||||
@@ -253,6 +253,33 @@ struct tensor_view
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const
|
||||
{
|
||||
return buf_.template transpose_get<X>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_transpose_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element // flag
|
||||
) const
|
||||
{
|
||||
return buf_.template transpose_get<X>(coord.get_offset(), linear_offset, is_valid_element);
|
||||
}
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
|
||||
@@ -59,6 +59,38 @@ CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Template function that "unpacks" a tuple and applies an element-wise operation.
|
||||
*
|
||||
* @param in_element_func Function to apply element-wise.
|
||||
* @param t Any container containing elements to process, with known size and
|
||||
* tuple-like semantic.
|
||||
* @return Calls tile_elementwise_inout with unpacked tuple elements.
|
||||
*/
|
||||
template <typename InElementFunc, typename Tuple, size_t... I>
|
||||
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
|
||||
const Tuple& t,
|
||||
std::index_sequence<I...>)
|
||||
{
|
||||
return tile_elementwise_inout(in_element_func, t[number<I>{}]...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Template function that "unpacks" a tuple and applies an element-wise operation.
|
||||
*
|
||||
* @param in_element_func Function to apply element-wise.
|
||||
* @param t Any container containing elements to process, with known size and
|
||||
* tuple-like semantic.
|
||||
* @return Calls the overloaded function, passing an index sequence.
|
||||
*/
|
||||
template <typename InElementFunc, typename Tuple>
|
||||
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
|
||||
const Tuple& t)
|
||||
{
|
||||
static constexpr auto size = Tuple::size();
|
||||
return tile_elementwise_inout_unpack(in_element_func, t, std::make_index_sequence<size>{});
|
||||
}
|
||||
|
||||
template <typename DstrTensors, typename T>
|
||||
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
|
||||
{
|
||||
|
||||
@@ -33,6 +33,7 @@ template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename StaticPageIndexArray_,
|
||||
typename StaticValidArray_,
|
||||
index_t HsGatherDim = 0,
|
||||
index_t NumCoord = 1,
|
||||
index_t YsGatherDim = 0>
|
||||
@@ -42,6 +43,7 @@ struct tile_scatter_gather
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
using PageIdxArray = remove_cvref_t<StaticPageIndexArray_>;
|
||||
using ValidArray = remove_cvref_t<StaticValidArray_>;
|
||||
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
|
||||
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
|
||||
|
||||
@@ -152,12 +154,14 @@ struct tile_scatter_gather
|
||||
const WindowLengths& window_lengths,
|
||||
const BottomTensorIndex& window_origin,
|
||||
const TileDstr& tile_distribution,
|
||||
const PageIdxArray& page_idx)
|
||||
const PageIdxArray& page_idx,
|
||||
const ValidArray& valids)
|
||||
: bottom_tensor_view_{bottom_tensor_view},
|
||||
window_lengths_{window_lengths},
|
||||
window_origin_{window_origin},
|
||||
tile_dstr_{tile_distribution},
|
||||
page_idx_{page_idx},
|
||||
valids_{valids},
|
||||
pre_computed_coords_{}
|
||||
{
|
||||
#if 0 // debug
|
||||
@@ -336,12 +340,25 @@ struct tile_scatter_gather
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
const vector_t vec_value = [&]() {
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
{
|
||||
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
valids_[idx_gather],
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
}();
|
||||
#if 1
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
@@ -451,9 +468,23 @@ struct tile_scatter_gather
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
|
||||
const auto page_offset = page_idx_[idx_gather];
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
{
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
|
||||
}
|
||||
else
|
||||
{
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
valids_[idx_gather],
|
||||
0,
|
||||
pre_nop_);
|
||||
}
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
@@ -529,11 +560,24 @@ struct tile_scatter_gather
|
||||
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
|
||||
{
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
valids_[idx_gather],
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// printf("coord_offset:%d, scatter_offset:%d \n",
|
||||
// bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
@@ -570,14 +614,23 @@ struct tile_scatter_gather
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx)
|
||||
{
|
||||
page_idx_ = new_idx;
|
||||
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
|
||||
|
||||
// static_for<0, 2, 1>{}([&](auto k0) {
|
||||
// printf("update tid %d %d \n", threadIdx.x, page_idx_[k0]);
|
||||
// });
|
||||
// Replaces the stored valid array. When ValidArray is std::nullptr_t the call is a no-op.
CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
{
    if constexpr(!std::is_same_v<ValidArray, std::nullptr_t>)
    {
        valids_ = new_valids;
    }
}
|
||||
|
||||
// Convenience wrapper: updates the page indices first, then the valid array.
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray& new_idx,
                                               const ValidArray& new_valids)
{
    this->update_page_idx(new_idx);
    this->update_valids(new_valids);
}
|
||||
|
||||
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
@@ -657,6 +710,7 @@ struct tile_scatter_gather
|
||||
TileDstr tile_dstr_;
|
||||
|
||||
PageIdxArray page_idx_;
|
||||
ValidArray valids_;
|
||||
|
||||
// this contains:
|
||||
// per-thread coordinate for window adaptor
|
||||
@@ -684,9 +738,10 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
remove_cvref_t<StaticTileDistribution_>,
|
||||
remove_cvref_t<StaticPageIndexArray_>,
|
||||
std::nullptr_t,
|
||||
HsGatherDim,
|
||||
NumCoord>{
|
||||
tensor_view, window_lengths, origin, tile_distribution, page_idx};
|
||||
tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
@@ -728,4 +783,76 @@ CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
|
||||
number<HsGatherDim>{});
|
||||
}
|
||||
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename StaticPageIndexArray_,
|
||||
typename StaticValidArray_,
|
||||
index_t HsGatherDim = 0,
|
||||
index_t NumCoord = 1>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution_& tile_distribution,
|
||||
const StaticPageIndexArray_& page_idx,
|
||||
const StaticValidArray_& valids,
|
||||
number<HsGatherDim> = {},
|
||||
number<NumCoord> = {})
|
||||
{
|
||||
return tile_scatter_gather<remove_cvref_t<TensorView_>,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
remove_cvref_t<StaticTileDistribution_>,
|
||||
remove_cvref_t<StaticPageIndexArray_>,
|
||||
remove_cvref_t<StaticValidArray_>,
|
||||
HsGatherDim,
|
||||
NumCoord>{
|
||||
tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
typename WindowLengths,
|
||||
typename StaticTileDistribution,
|
||||
typename StaticPageIndexArray,
|
||||
typename StaticValidArray,
|
||||
index_t HsGatherDim>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
|
||||
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution& tile_distribution,
|
||||
const StaticPageIndexArray& page_idx,
|
||||
const StaticValidArray& valids,
|
||||
number<HsGatherDim> = {})
|
||||
{
|
||||
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
origin,
|
||||
tile_distribution,
|
||||
page_idx,
|
||||
valids,
|
||||
number<HsGatherDim>{});
|
||||
}
|
||||
|
||||
template <typename TensorView,
|
||||
typename WindowLengths,
|
||||
typename StaticTileDistribution,
|
||||
typename StaticPageIndexArray,
|
||||
typename StaticValidArray,
|
||||
index_t HsGatherDim>
|
||||
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(
|
||||
const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution,
|
||||
const StaticPageIndexArray& page_idx,
|
||||
const StaticValidArray& valids,
|
||||
number<HsGatherDim> = {})
|
||||
{
|
||||
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
tile_window.get_window_origin(),
|
||||
tile_distribution,
|
||||
page_idx,
|
||||
valids,
|
||||
number<HsGatherDim>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
@@ -34,166 +35,60 @@ template <typename BottomTensorView_,
|
||||
typename StaticTileDistribution_,
|
||||
index_t NumCoord>
|
||||
struct tile_window_with_static_distribution
|
||||
: public tile_window_with_tile_dstr_base<
|
||||
tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>,
|
||||
BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_>
|
||||
{
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
|
||||
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
|
||||
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
|
||||
|
||||
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
|
||||
|
||||
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
|
||||
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
|
||||
|
||||
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
|
||||
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
|
||||
using Base = tile_window_with_tile_dstr_base<
|
||||
tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>,
|
||||
BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_>;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static_assert(NumCoord == 1);
|
||||
|
||||
// TODO: check WindowLengths and StaticTileDistribution are consistent
|
||||
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
static_assert(TileDstr::is_static(), "wrong!");
|
||||
|
||||
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
|
||||
"wrong! inconsistent # of diemsnions");
|
||||
|
||||
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
|
||||
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
|
||||
|
||||
using WindowAdaptorCoord =
|
||||
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
|
||||
|
||||
using BottomTensorCoord =
|
||||
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
|
||||
|
||||
struct load_store_traits
|
||||
{
|
||||
private:
|
||||
static constexpr auto get_vector_dim_y_scalar_per_vector()
|
||||
{
|
||||
const auto [ys_vector_lengths, ys_vector_strides] =
|
||||
tile_window_with_static_distribution::
|
||||
get_window_adaptor_ys_safe_vector_length_strides();
|
||||
|
||||
index_t VectorDimY_ = 0;
|
||||
index_t ScalarPerVector_ = 1;
|
||||
|
||||
for(index_t i = 0; i < NDimY; ++i)
|
||||
{
|
||||
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
|
||||
{
|
||||
ScalarPerVector_ = ys_vector_lengths[i];
|
||||
VectorDimY_ = i;
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(VectorDimY_, ScalarPerVector_);
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
|
||||
static constexpr index_t ScalarPerVector =
|
||||
get_vector_dim_y_scalar_per_vector().template at<1>();
|
||||
|
||||
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
|
||||
// using vector_t = typename vector_type_t::type;
|
||||
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
|
||||
|
||||
private:
|
||||
static constexpr auto scalars_per_access_ = [] {
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
|
||||
|
||||
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
|
||||
constexpr auto NDimY_ = NDimY;
|
||||
|
||||
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
|
||||
}();
|
||||
|
||||
static constexpr auto get_space_filling_curve()
|
||||
{
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
constexpr auto thread_tensor_lengths_ys =
|
||||
to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
// FIXME: need logic to judge dim access order
|
||||
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
|
||||
|
||||
return space_filling_curve<decltype(thread_tensor_lengths_ys),
|
||||
DimAccessOrder,
|
||||
decltype(scalars_per_access_)>{};
|
||||
}
|
||||
|
||||
public:
|
||||
using SFC_Ys = decltype(get_space_filling_curve());
|
||||
|
||||
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
|
||||
|
||||
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
|
||||
static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
|
||||
};
|
||||
|
||||
static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
|
||||
static_assert(Base::Traits::NumAccess % NumCoord == 0,
|
||||
"wrong! # of access is not divisible by NumCoord");
|
||||
static constexpr index_t NumAccessPerCoord = Base::Traits::NumAccess / NumCoord;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_with_static_distribution() = default;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_with_static_distribution(
|
||||
const BottomTensorView& bottom_tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const BottomTensorIndex& window_origin,
|
||||
const TileDstr& tile_distribution)
|
||||
: bottom_tensor_view_{bottom_tensor_view},
|
||||
window_lengths_{window_lengths},
|
||||
window_origin_{window_origin},
|
||||
tile_dstr_{tile_distribution},
|
||||
pre_computed_coords_{}
|
||||
const typename Base::BottomTensorView& bottom_tensor_view,
|
||||
const typename Base::WindowLengths& window_lengths,
|
||||
const typename Base::BottomTensorIndex& window_origin,
|
||||
const typename Base::TileDstr& tile_distribution)
|
||||
: pre_computed_coords_{}
|
||||
{
|
||||
#if 0 // debug
|
||||
// TODO: this use more register for FA, but less register for GEMM
|
||||
// need investigation
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
|
||||
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
|
||||
|
||||
if constexpr(NDimP == 1)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp =
|
||||
make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
|
||||
}
|
||||
#else
|
||||
// TODO: this use less register for FA, but more register for GEMM
|
||||
// need investigation
|
||||
this->window_origin_ = window_origin;
|
||||
this->window_lengths_ = window_lengths;
|
||||
this->bottom_tensor_view_ = bottom_tensor_view;
|
||||
this->tile_dstr_ = tile_distribution;
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(tile_distribution),
|
||||
array<index_t, NDimY>{0}));
|
||||
#endif
|
||||
array<index_t, Base::NDimY>{0}));
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
|
||||
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
bottom_tensor_view.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
|
||||
// future load/store() calls (might allocate more registers)
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
@@ -204,9 +99,10 @@ struct tile_window_with_static_distribution
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
pre_computed_coords_(iCoord) =
|
||||
@@ -214,95 +110,12 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
|
||||
{
|
||||
return TileDstr::is_static();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr void
|
||||
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
|
||||
{
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// move thread's window adaptor coordinate and bottom tensor coordinate
|
||||
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
|
||||
template <typename ATopIndex>
|
||||
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
WindowAdaptorCoord& window_adaptor_thread_coord,
|
||||
BottomTensorCoord& bottom_tensor_thread_coord,
|
||||
const ATopIndex& idx_diff_adaptor_top) const
|
||||
{
|
||||
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
|
||||
|
||||
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
window_adaptor_thread_coord,
|
||||
idx_diff_adaptor_top,
|
||||
idx_diff_adaptor_bottom);
|
||||
|
||||
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
|
||||
bottom_tensor_thread_coord,
|
||||
idx_diff_adaptor_bottom);
|
||||
}
|
||||
|
||||
// return vector dimension among [y0, y1, ...]
|
||||
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
|
||||
{
|
||||
// bottom tensor top dimension vector lengths and strides
|
||||
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
|
||||
BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
|
||||
|
||||
// window vector lengths/strides
|
||||
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
|
||||
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
|
||||
|
||||
// window adaptor [p0, p1, ..., y0, y1, ...]
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
|
||||
-1};
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
|
||||
-1};
|
||||
|
||||
constexpr auto window_adaptor_bottom_dims =
|
||||
WindowAdaptor::get_bottom_dimension_hidden_ids();
|
||||
|
||||
set_container_subset(window_adaptor_vector_lengths,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_lengths);
|
||||
set_container_subset(window_adaptor_vector_strides,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_strides);
|
||||
|
||||
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
|
||||
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
|
||||
window_adaptor_vector_lengths, window_adaptor_vector_strides);
|
||||
|
||||
// [y0, y1, ...]
|
||||
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
|
||||
NDimWindowAdaptorTop,
|
||||
1>::type{};
|
||||
|
||||
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
|
||||
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; }
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
@@ -314,11 +127,11 @@ struct tile_window_with_static_distribution
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
@@ -334,9 +147,8 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
#if 1
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
@@ -344,33 +156,26 @@ struct tile_window_with_static_distribution
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j / Traits::PackedSize];
|
||||
vec_value
|
||||
.template get_as<typename Base::DataType>()[j / Traits::PackedSize];
|
||||
});
|
||||
#else
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
static_assert(d % Traits::ScalarPerVector == 0);
|
||||
|
||||
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
|
||||
number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
|
||||
#endif
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
@@ -386,22 +191,16 @@ struct tile_window_with_static_distribution
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using Traits = typename Base::Traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
static constexpr index_t YElementSize =
|
||||
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
|
||||
typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
|
||||
static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0);
|
||||
using vectorized_tbuf =
|
||||
array<vector_t, YElementSize / (Traits::PackedSize * Traits::ScalarPerVector)>;
|
||||
// StaticBuffer<address_space_enum::vgpr,
|
||||
// vector_t,
|
||||
// YElementSize / Traits::ScalarPerVector,
|
||||
// true>;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
|
||||
|
||||
@@ -427,7 +226,7 @@ struct tile_window_with_static_distribution
|
||||
Traits::PackedSize;
|
||||
static_assert(d % Traits::ScalarPerVector == 0);
|
||||
|
||||
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
|
||||
this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
|
||||
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
|
||||
bottom_tensor_thread_coord,
|
||||
0 /**/,
|
||||
@@ -444,10 +243,10 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
@@ -492,9 +291,8 @@ struct tile_window_with_static_distribution
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
@@ -516,7 +314,7 @@ struct tile_window_with_static_distribution
|
||||
}();
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, pre_nop_);
|
||||
|
||||
// move thread coordinate
|
||||
@@ -525,10 +323,10 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
m0_inc_with_memory(size_per_issue);
|
||||
@@ -548,7 +346,7 @@ struct tile_window_with_static_distribution
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
#if defined(__gfx950__)
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
@@ -570,7 +368,7 @@ struct tile_window_with_static_distribution
|
||||
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ +
|
||||
lds_coord.get_offset();
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
@@ -579,7 +377,7 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
@@ -610,7 +408,7 @@ struct tile_window_with_static_distribution
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
@@ -629,7 +427,7 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
@@ -638,10 +436,10 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
smem += size_per_issue; // Note we manually increase the per-issue
|
||||
@@ -655,18 +453,94 @@ struct tile_window_with_static_distribution
|
||||
#endif
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose() const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
this->template load_transpose<Policy>(
|
||||
dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
template <typename Policy,
|
||||
typename DistributedTensor,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = typename Base::Traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
constexpr auto group_func = Policy::group_func;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
this->get_bottom_tensor_view()
|
||||
.template get_transpose_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, 0);
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto orig_idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr auto grouped_idx_ys = group_func(orig_idx_ys);
|
||||
|
||||
constexpr index_t linear_distributed_index =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
|
||||
vec_value.template get_as<typename Base::DataType>()[j];
|
||||
});
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<typename Base::DataType,
|
||||
typename Base::TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
@@ -689,20 +563,20 @@ struct tile_window_with_static_distribution
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
vec_value,
|
||||
@@ -714,10 +588,10 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
@@ -725,15 +599,17 @@ struct tile_window_with_static_distribution
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1>
|
||||
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {}) const
|
||||
CK_TILE_DEVICE void
|
||||
store_raw(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
|
||||
dstr_tensor,
|
||||
number<i_access_unsupport_> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
static constexpr bool oob_conditional_check = true;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
@@ -756,16 +632,16 @@ struct tile_window_with_static_distribution
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view()
|
||||
this->get_bottom_tensor_view()
|
||||
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
|
||||
bottom_tensor_thread_coord, 0, vec_value);
|
||||
|
||||
@@ -775,10 +651,10 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
@@ -786,16 +662,18 @@ struct tile_window_with_static_distribution
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
CK_TILE_DEVICE void
|
||||
update(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
|
||||
dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
@@ -818,18 +696,18 @@ struct tile_window_with_static_distribution
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
vec_value,
|
||||
@@ -841,10 +719,10 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
@@ -852,17 +730,19 @@ struct tile_window_with_static_distribution
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
|
||||
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
CK_TILE_DEVICE void
|
||||
update_raw(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
|
||||
dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
@@ -885,18 +765,18 @@ struct tile_window_with_static_distribution
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
|
||||
this->get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
vec_value,
|
||||
@@ -909,70 +789,44 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// move thread's botom tensor coordiante
|
||||
// [x0', x1', ... ] ==> [offset]
|
||||
// also move window-origin
|
||||
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
|
||||
// Custom move behavior
|
||||
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex& step)
|
||||
{
|
||||
window_origin_ += step;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
|
||||
move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
|
||||
pre_computed_coords_(iCoord)(I1),
|
||||
step);
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
|
||||
#if 0 // debug
|
||||
// TODO: this use more register for FA, but less register for GEMM
|
||||
// need investigation
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
|
||||
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
|
||||
|
||||
if constexpr(NDimP == 1)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp =
|
||||
make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
|
||||
}
|
||||
#else
|
||||
// TODO: this use less register for FA, but more register for GEMM
|
||||
// need investigation
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
|
||||
#endif
|
||||
this->tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(this->tile_dstr_),
|
||||
array<index_t, Base::NDimY>{0}));
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
|
||||
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
|
||||
// future load/store() calls (might allocate more registers)
|
||||
using Traits = load_store_traits;
|
||||
using Traits = typename Base::Traits;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
@@ -983,9 +837,10 @@ struct tile_window_with_static_distribution
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
pre_computed_coords_(iCoord) =
|
||||
@@ -993,27 +848,11 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
|
||||
|
||||
// this is the bottom tensor view
|
||||
// [x0', x1', ...] ==> [offset]
|
||||
BottomTensorView bottom_tensor_view_;
|
||||
|
||||
//
|
||||
WindowLengths window_lengths_;
|
||||
|
||||
// origin ([x0', x1', ...]) of window on bottom tensor
|
||||
BottomTensorIndex window_origin_;
|
||||
|
||||
// Tile tensor distribution, which contains:
|
||||
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
|
||||
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
|
||||
TileDstr tile_dstr_;
|
||||
|
||||
// this contains:
|
||||
// per-thread coordinate for window adaptor
|
||||
// per-thread coordinate for bottom tensor
|
||||
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
|
||||
array<tuple<typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord>, NumCoord>
|
||||
pre_computed_coords_;
|
||||
};
|
||||
|
||||
// TODO: use strategy
|
||||
@@ -1083,62 +922,26 @@ CK_TILE_DEVICE void move_tile_window(
|
||||
*/
|
||||
template <typename BottomTensorView_, typename WindowLengths_>
|
||||
struct tile_window_with_static_lengths
|
||||
: public tile_window_base<tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>,
|
||||
BottomTensorView_,
|
||||
WindowLengths_>
|
||||
{
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
|
||||
using DataType = typename BottomTensorView::DataType;
|
||||
|
||||
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
|
||||
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
|
||||
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
|
||||
using Base =
|
||||
tile_window_base<tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>,
|
||||
BottomTensorView_,
|
||||
WindowLengths_>;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_with_static_lengths() = default;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_with_static_lengths(
|
||||
const BottomTensorView& bottom_tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const BottomTensorIndex& window_origin)
|
||||
: bottom_tensor_view_{bottom_tensor_view},
|
||||
window_lengths_{window_lengths},
|
||||
window_origin_{window_origin}
|
||||
const typename Base::BottomTensorView& bottom_tensor_view,
|
||||
const typename Base::WindowLengths& window_lengths,
|
||||
const typename Base::BottomTensorIndex& window_origin)
|
||||
{
|
||||
this->window_origin_ = window_origin;
|
||||
this->window_lengths_ = window_lengths;
|
||||
this->bottom_tensor_view_ = bottom_tensor_view;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
|
||||
|
||||
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr void
|
||||
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
|
||||
{
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// move window-origin
|
||||
CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }
|
||||
|
||||
// this is the bottom tensor view
|
||||
// [x0', x1', ...] ==> [offset]
|
||||
BottomTensorView bottom_tensor_view_;
|
||||
|
||||
//
|
||||
WindowLengths window_lengths_;
|
||||
|
||||
// origin ([x0', x1', ...]) of window on bottom tensor
|
||||
BottomTensorIndex window_origin_;
|
||||
};
|
||||
|
||||
template <typename TensorView_, typename WindowLengths_>
|
||||
|
||||
256
include/ck_tile/core/tensor/tile_window_base.hpp
Normal file
256
include/ck_tile/core/tensor/tile_window_base.hpp
Normal file
@@ -0,0 +1,256 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief This class provides description of tile windowed view on the device memory.
|
||||
*
|
||||
* @note This class does not provide any functions to read or modify device memory.
|
||||
*
|
||||
* @tparam TileWindowType_ Most-derived window type. `this` is cast to `TileWindowType_*`
* in set_window_origin() and move() so that the derived type's
* `*_extended` hooks are called without virtual dispatch.
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
|
||||
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
|
||||
*/
|
||||
template <typename TileWindowType_, typename BottomTensorView_, typename WindowLengths_>
|
||||
struct tile_window_base
|
||||
{
|
||||
|
||||
// Template arguments with reference (and, where applicable, cv) qualifiers stripped.
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
|
||||
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
|
||||
|
||||
// Number of dimensions reported by the bottom tensor descriptor.
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
|
||||
|
||||
// The window lengths type must be fully known at compile time.
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
|
||||
// Index type for a position in the bottom tensor (one index_t per dimension).
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
|
||||
|
||||
// Accessors. Each one returns `auto` by value, i.e. a copy of the stored member.
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
|
||||
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
|
||||
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
|
||||
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
|
||||
|
||||
// Replace the window origin, then give the derived type a chance to refresh any state
// that depends on it. window_origin_ is already updated when the hook runs.
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
|
||||
// Delegate to child if it implements extra logic
|
||||
// (the cast assumes TileWindowType_ derives from this base; a same-named method
// declared in the child hides the no-op below)
static_cast<TileWindowType_*>(this)->set_window_origin_extended(new_window_origin);
|
||||
}
|
||||
// Default no-op; can be overridden in child
|
||||
CK_TILE_DEVICE void set_window_origin_extended(const BottomTensorIndex&) {}
|
||||
|
||||
// Overwrite the data pointer held by the bottom tensor view's buffer (buf_.p_data_).
// No other member of the window is touched.
CK_TILE_DEVICE constexpr void
|
||||
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
|
||||
{
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// move window-origin
|
||||
// Adds `step` to window_origin_, then calls the derived type's move_extended(step).
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
|
||||
{
|
||||
window_origin_ += step;
|
||||
|
||||
// Delegate to child if it implements extra movement logic
|
||||
// NOTE(review): the origin has already been advanced above, so a child hook
// presumably should not add `step` to window_origin_ again -- confirm in each child.
static_cast<TileWindowType_*>(this)->move_extended(step);
|
||||
}
|
||||
|
||||
// Default no-op; can be overridden in child
|
||||
CK_TILE_DEVICE void move_extended(const BottomTensorIndex&) {}
|
||||
|
||||
// origin ([x0', x1', ...]) of window on bottom tensor
|
||||
BottomTensorIndex window_origin_;
|
||||
|
||||
// lengths of the window (type is required to be compile-time known, see static_assert above)
WindowLengths window_lengths_;
|
||||
|
||||
// this is the bottom tensor view
|
||||
// [x0', x1', ...] ==> [offset]
|
||||
BottomTensorView bottom_tensor_view_;
|
||||
};
|
||||
|
||||
// Intermediate base for tile windows that carry a static tile distribution.
// It extends tile_window_base with:
//   - the distribution object (tile_dstr_) and the types derived from it,
//   - a helper that moves a (window adaptor coord, bottom tensor coord) pair together,
//   - a Traits struct holding the compile-time vector access dimension/width and the
//     space-filling curve over the Y dimensions.
// TileWindowType_ is forwarded unchanged to tile_window_base.
template <typename TileWindowType_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_>
|
||||
struct tile_window_with_tile_dstr_base
|
||||
: public tile_window_base<TileWindowType_, BottomTensorView_, WindowLengths_>
|
||||
{
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
using TileWindowBase = tile_window_base<TileWindowType_, BottomTensorView_, WindowLengths_>;
|
||||
|
||||
// Adaptor type taken from the distribution; maps [p..., y...] to [x...]
// (see the mapping comment on the move helper below).
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
|
||||
|
||||
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
|
||||
|
||||
// Number of P and Y dimensions as reported by the distribution.
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
|
||||
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
|
||||
|
||||
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
|
||||
// using BottomTensorIndex = array<index_t, TileWindowBase::NDimBottomTensor>;
|
||||
|
||||
// Coordinate types, deduced from what the factory functions return for
// default-constructed arguments.
using WindowAdaptorCoord =
|
||||
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
|
||||
|
||||
using BottomTensorCoord = decltype(make_tensor_coordinate(
|
||||
typename TileWindowBase::BottomTensorDesc{}, typename TileWindowBase::BottomTensorIndex{}));
|
||||
|
||||
// The distribution must be static, and the adaptor's bottom dimensionality must match
// the bottom tensor's dimensionality.
static_assert(TileDstr::is_static(), "wrong!");
|
||||
static_assert(TileWindowBase::NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
|
||||
"wrong! inconsistent # of diemsnions");
|
||||
|
||||
// Returns a copy of the stored tile distribution.
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
|
||||
// Forwards to the bottom tensor view's init_raw().
CK_TILE_HOST_DEVICE void init_raw() { this->bottom_tensor_view_.init_raw(); }
|
||||
|
||||
// Reports TileDstr::is_static(); the static_assert above already requires this to be true.
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
|
||||
{
|
||||
return TileDstr::is_static();
|
||||
}
|
||||
|
||||
// move thread's window adaptor coordinate and bottom tensor coordinate
|
||||
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
|
||||
// Both coordinates are updated in place; the member function itself is const.
template <typename ATopIndex>
|
||||
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
WindowAdaptorCoord& window_adaptor_thread_coord,
|
||||
BottomTensorCoord& bottom_tensor_thread_coord,
|
||||
const ATopIndex& idx_diff_adaptor_top) const
|
||||
{
|
||||
// Passed to the adaptor move below and then reused as the step for the bottom tensor
// coordinate (presumably filled in as an output by move_tensor_adaptor_coordinate --
// it is not initialized here).
array<index_t, TileWindowBase::NDimBottomTensor> idx_diff_adaptor_bottom;
|
||||
|
||||
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
window_adaptor_thread_coord,
|
||||
idx_diff_adaptor_top,
|
||||
idx_diff_adaptor_bottom);
|
||||
|
||||
move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
|
||||
bottom_tensor_thread_coord,
|
||||
idx_diff_adaptor_bottom);
|
||||
}
|
||||
|
||||
// Compile-time access traits derived from the distribution and the bottom tensor
// descriptor: which Y dimension is accessed as a vector, how wide that vector is, and
// the space-filling curve that enumerates the accesses.
struct Traits
|
||||
{
|
||||
public:
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<typename TileWindowBase::DataType>>::PackedSize;
|
||||
|
||||
// Scans the Y dimensions and returns (VectorDimY, ScalarPerVector): the Y dimension
// with stride 1 and the largest safe vector length. The comparison is strict, so on a
// tie the earliest dimension wins. If no dimension qualifies the result is (0, 1).
static constexpr auto get_vector_dim_y_scalar_per_vector()
|
||||
{
|
||||
const auto [ys_vector_lengths, ys_vector_strides] =
|
||||
tile_window_with_tile_dstr_base::get_window_adaptor_ys_safe_vector_length_strides();
|
||||
|
||||
index_t VectorDimY_ = 0;
|
||||
index_t ScalarPerVector_ = 1;
|
||||
|
||||
for(index_t i = 0; i < NDimY; ++i)
|
||||
{
|
||||
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
|
||||
{
|
||||
ScalarPerVector_ = ys_vector_lengths[i];
|
||||
VectorDimY_ = i;
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(VectorDimY_, ScalarPerVector_);
|
||||
}
|
||||
|
||||
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
|
||||
static constexpr index_t ScalarPerVector =
|
||||
get_vector_dim_y_scalar_per_vector().template at<1>();
|
||||
// Buffer type for one vector access: ScalarPerVector / PackedSize elements of DataType.
using vector_t =
|
||||
thread_buffer<typename TileWindowBase::DataType, ScalarPerVector / PackedSize>;
|
||||
|
||||
// Per-Y-dimension scalars covered by one access: ScalarPerVector on VectorDimY and 1 on
// every other dimension, converted to a compile-time sequence.
static constexpr auto scalars_per_access_ = [] {
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
|
||||
|
||||
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
|
||||
// Local copy of NDimY so the macro receives a variable with automatic storage
// (per the TODO above).
constexpr auto NDimY_ = NDimY;
|
||||
|
||||
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
|
||||
}();
|
||||
|
||||
// Builds the space-filling curve type over the per-thread Y lengths, visiting
// dimensions in natural order (0, 1, ..., NDimY-1) with the snake flag set to false.
static constexpr auto get_space_filling_curve()
|
||||
{
|
||||
constexpr auto thread_tensor_lengths_ys =
|
||||
to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
// FIXME: need logic to judge dim access order
|
||||
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
|
||||
|
||||
return space_filling_curve<decltype(thread_tensor_lengths_ys),
|
||||
DimAccessOrder,
|
||||
decltype(scalars_per_access_),
|
||||
false /*!!! no snaked curve! */>{};
|
||||
}
|
||||
|
||||
using SFC_Ys = decltype(get_space_filling_curve());
|
||||
|
||||
// Total number of accesses along the curve; must be positive.
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
|
||||
|
||||
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
|
||||
};
|
||||
|
||||
// Returns a tuple (safe vector lengths, safe vector strides) restricted to the Y
// dimensions [y0, y1, ...]. Traits uses it to pick the vector dimension.
// The bottom tensor descriptor's per-dimension values are placed at the adaptor's bottom
// hidden ids, propagated to the adaptor's top dimensions, and the Y part is extracted.
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
|
||||
{
|
||||
// bottom tensor top dimension vector lengths and strides
|
||||
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
|
||||
TileWindowBase::BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
|
||||
|
||||
// window vector lengths/strides
|
||||
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
|
||||
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
|
||||
|
||||
// window adaptor [p0, p1, ..., y0, y1, ...]
|
||||
// NOTE(review): whether `{-1}` sets every element or only the first depends on
// ck_tile::array's constructor -- confirm.
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
|
||||
-1};
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
|
||||
-1};
|
||||
|
||||
constexpr auto window_adaptor_bottom_dims =
|
||||
WindowAdaptor::get_bottom_dimension_hidden_ids();
|
||||
|
||||
// Write the bottom tensor's values into the slots of the adaptor's bottom dimensions.
set_container_subset(window_adaptor_vector_lengths,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_lengths);
|
||||
set_container_subset(window_adaptor_vector_strides,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_strides);
|
||||
|
||||
// Ask the adaptor for the corresponding values on its top ([p..., y...]) dimensions.
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
|
||||
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
|
||||
window_adaptor_vector_lengths, window_adaptor_vector_strides);
|
||||
|
||||
// [y0, y1, ...]
|
||||
// Indices from the number of P dimensions up to NDimWindowAdaptorTop, step 1
// (presumably a half-open range, i.e. exactly the Y positions -- confirm).
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
|
||||
NDimWindowAdaptorTop,
|
||||
1>::type{};
|
||||
|
||||
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
|
||||
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
|
||||
}
|
||||
|
||||
// Number of accesses along Traits::SFC_Ys.
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return Traits::NumAccess; }
|
||||
// Tile tensor distribution, which contains:
|
||||
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
|
||||
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
|
||||
TileDstr tile_dstr_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
@@ -37,171 +38,48 @@ namespace ck_tile {
|
||||
// TODO: if using this struct, better use load_raw()/store_raw(), can control
|
||||
// the immediate offset on the fly
|
||||
// space-filling-curve is non-snaked here!
|
||||
//
|
||||
// This struct inherits from tile_window_with_tile_dstr_base, which is an intermediary base class
|
||||
// with the ultimate parent class being tile_window_base.
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename LinearBottomDims_>
|
||||
struct tile_window_linear
|
||||
: public tile_window_with_tile_dstr_base<tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
LinearBottomDims_>,
|
||||
BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_>
|
||||
{
|
||||
using Base = tile_window_with_tile_dstr_base<tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
LinearBottomDims_>,
|
||||
BottomTensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_>;
|
||||
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
|
||||
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
|
||||
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
|
||||
|
||||
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
|
||||
using LinearBottomDims = remove_cvref_t<LinearBottomDims_>;
|
||||
|
||||
static_assert(LinearBottomDims::size() == BottomTensorView::get_num_of_dimension());
|
||||
|
||||
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
|
||||
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
|
||||
|
||||
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
|
||||
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
|
||||
static_assert(LinearBottomDims::size() == Base::BottomTensorView::get_num_of_dimension());
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
// TODO: check WindowLengths and StaticTileDistribution are consistent
|
||||
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
static_assert(TileDstr::is_static(), "wrong!");
|
||||
|
||||
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
|
||||
"wrong! inconsistent # of diemsnions");
|
||||
|
||||
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
|
||||
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
|
||||
|
||||
using WindowAdaptorCoord =
|
||||
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
|
||||
|
||||
using BottomTensorCoord =
|
||||
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
|
||||
|
||||
struct traits
|
||||
{
|
||||
private:
|
||||
// return vector dimension among [y0, y1, ...]
|
||||
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
|
||||
{
|
||||
// bottom tensor top dimension vector lengths and strides
|
||||
const auto [bottom_tensor_top_dim_vector_lengths,
|
||||
bottom_tensor_top_dim_vector_strides] =
|
||||
BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
|
||||
|
||||
// window vector lengths/strides
|
||||
const auto window_adaptor_bottom_dim_vector_lengths =
|
||||
bottom_tensor_top_dim_vector_lengths;
|
||||
const auto window_adaptor_bottom_dim_vector_strides =
|
||||
bottom_tensor_top_dim_vector_strides;
|
||||
|
||||
// window adaptor [p0, p1, ..., y0, y1, ...]
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()>
|
||||
window_adaptor_vector_lengths{-1};
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()>
|
||||
window_adaptor_vector_strides{-1};
|
||||
|
||||
constexpr auto window_adaptor_bottom_dims =
|
||||
WindowAdaptor::get_bottom_dimension_hidden_ids();
|
||||
|
||||
set_container_subset(window_adaptor_vector_lengths,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_lengths);
|
||||
set_container_subset(window_adaptor_vector_strides,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_strides);
|
||||
|
||||
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
|
||||
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
|
||||
window_adaptor_vector_lengths, window_adaptor_vector_strides);
|
||||
|
||||
// [y0, y1, ...]
|
||||
constexpr auto y_dims =
|
||||
typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
|
||||
NDimWindowAdaptorTop,
|
||||
1>::type{};
|
||||
|
||||
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
|
||||
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
|
||||
}
|
||||
|
||||
static constexpr auto get_vector_dim_y_scalar_per_vector()
|
||||
{
|
||||
const auto [ys_vector_lengths, ys_vector_strides] =
|
||||
get_window_adaptor_ys_safe_vector_length_strides();
|
||||
|
||||
index_t VectorDimY_ = 0;
|
||||
index_t ScalarPerVector_ = 1;
|
||||
|
||||
for(index_t i = 0; i < NDimY; ++i)
|
||||
{
|
||||
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
|
||||
{
|
||||
ScalarPerVector_ = ys_vector_lengths[i];
|
||||
VectorDimY_ = i;
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(VectorDimY_, ScalarPerVector_);
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
|
||||
static constexpr index_t ScalarPerVector =
|
||||
get_vector_dim_y_scalar_per_vector().template at<1>();
|
||||
|
||||
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
|
||||
|
||||
private:
|
||||
static constexpr auto scalars_per_access_ = [] {
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
|
||||
|
||||
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
|
||||
constexpr auto NDimY_ = NDimY;
|
||||
|
||||
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
|
||||
}();
|
||||
|
||||
static constexpr auto get_space_filling_curve()
|
||||
{
|
||||
constexpr auto thread_tensor_lengths_ys =
|
||||
to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
// FIXME: need logic to judge dim access order
|
||||
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
|
||||
|
||||
return space_filling_curve<decltype(thread_tensor_lengths_ys),
|
||||
DimAccessOrder,
|
||||
decltype(scalars_per_access_),
|
||||
false /*!!! no snaked curve! */>{};
|
||||
}
|
||||
|
||||
public:
|
||||
using SFC_Ys = decltype(get_space_filling_curve());
|
||||
|
||||
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
|
||||
|
||||
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
|
||||
|
||||
private:
|
||||
static constexpr auto get_num_non_linear_access()
|
||||
{
|
||||
constexpr auto sfc_access_lens = SFC_Ys::access_lengths;
|
||||
using ys_to_rhs_major =
|
||||
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
|
||||
using ys_to_rhs_major = typename decltype(
|
||||
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
|
||||
constexpr auto non_linear = [&]() {
|
||||
index_t cnt = 1;
|
||||
static_for<0, NDimY, 1>{}([&](auto i_dim_y) {
|
||||
static_for<0, Base::NDimY, 1>{}([&](auto i_dim_y) {
|
||||
constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
|
||||
constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
|
||||
if constexpr(LinearBottomDims{}[target_h_dim] == 0)
|
||||
@@ -230,20 +108,20 @@ struct tile_window_linear
|
||||
// -> prefixsum : sequence<0, 2, 4, 6, 8>
|
||||
static constexpr auto get_non_linear_access_map()
|
||||
{
|
||||
constexpr auto sfc_access_lens = SFC_Ys::access_lengths;
|
||||
using ys_to_rhs_major =
|
||||
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
|
||||
using ys_to_rhs_major = typename decltype(
|
||||
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
constexpr auto non_linear_map = [&]() {
|
||||
array<index_t, NumAccess> m_{0};
|
||||
array<index_t, Base::Traits::NumAccess> m_{0};
|
||||
index_t cumulative_len_ = 1;
|
||||
index_t cumulative_non_linear_len_ = 1;
|
||||
static_for<0, NDimY, 1>{}([&](auto i_y) {
|
||||
constexpr auto i_dim_y = number<NDimY - i_y - 1>{}; // from right to left
|
||||
static_for<0, Base::NDimY, 1>{}([&](auto i_y) {
|
||||
constexpr auto i_dim_y = number<Base::NDimY - i_y - 1>{}; // from right to left
|
||||
constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y];
|
||||
constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
|
||||
constexpr auto is_linear_dim = LinearBottomDims{}[target_h_dim];
|
||||
|
||||
array<index_t, NumAccess> current_m_{0};
|
||||
array<index_t, Base::Traits::NumAccess> current_m_{0};
|
||||
constexpr auto current_len_ = sfc_access_lens[i_dim_y];
|
||||
|
||||
// copy cumulative length as current pattern
|
||||
@@ -266,13 +144,12 @@ struct tile_window_linear
|
||||
return m_;
|
||||
}();
|
||||
|
||||
return TO_SEQUENCE(non_linear_map, NumAccess);
|
||||
return TO_SEQUENCE(non_linear_map, Base::Traits::NumAccess);
|
||||
}
|
||||
|
||||
static constexpr auto get_non_linear_access_histogram()
|
||||
{
|
||||
constexpr auto m_ = get_non_linear_access_map();
|
||||
// m_.foo();
|
||||
|
||||
constexpr auto r_ =
|
||||
typename arithmetic_sequence_gen<0, get_num_non_linear_access() + 1, 1>::type{};
|
||||
@@ -296,7 +173,7 @@ struct tile_window_linear
|
||||
using AccessPrefixSum_NonLinear = decltype(get_non_linear_access_histogram_prefix_sum());
|
||||
};
|
||||
|
||||
static constexpr index_t NumAccess = traits::NumAccess;
|
||||
static constexpr index_t NumAccess = Base::Traits::NumAccess;
|
||||
static constexpr index_t NumAccess_NonLinear = traits::NumAccess_NonLinear;
|
||||
using AccessMap_NonLinear = typename traits::AccessMap_NonLinear;
|
||||
using AccessHistogram_NonLinear = typename traits::AccessHistogram_NonLinear;
|
||||
@@ -304,31 +181,34 @@ struct tile_window_linear
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_linear() = default;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_window_linear(const BottomTensorView& bottom_tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const BottomTensorIndex& window_origin,
|
||||
const TileDstr& tile_distribution)
|
||||
: bottom_tensor_view_{bottom_tensor_view},
|
||||
window_lengths_{window_lengths},
|
||||
window_origin_{window_origin},
|
||||
tile_dstr_{tile_distribution},
|
||||
CK_TILE_DEVICE constexpr tile_window_linear(
|
||||
const typename Base::BottomTensorView& bottom_tensor_view,
|
||||
const typename Base::WindowLengths& window_lengths,
|
||||
const typename Base::BottomTensorIndex& window_origin,
|
||||
const typename Base::TileDstr& tile_distribution)
|
||||
:
|
||||
cached_coords_{},
|
||||
cached_window_adaptor_coords_{},
|
||||
cached_flags_{}
|
||||
{
|
||||
this->bottom_tensor_view_ = bottom_tensor_view;
|
||||
this->window_lengths_ = window_lengths;
|
||||
this->window_origin_ = window_origin;
|
||||
this->tile_dstr_ = tile_distribution;
|
||||
auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(make_tuple(get_warp_id(), get_lane_id()),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimY>{})));
|
||||
container_concat(
|
||||
make_tuple(get_warp_id(), get_lane_id()),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimY>{})));
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
|
||||
auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// future load/store() calls (might allocate more registers)
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
static_for<0, NumAccess, 1>{}([&](auto i_access) {
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
|
||||
@@ -345,16 +225,16 @@ struct tile_window_linear
|
||||
// cached flag is independent from non-linear-coord
|
||||
// but need be updated in move_tile, with proper dims
|
||||
cached_flags_(i_access) = coordinate_has_valid_offset_assuming_top_index_is_valid(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp);
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp);
|
||||
|
||||
if constexpr(i_access != (NumAccess - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord_tmp,
|
||||
bottom_tensor_thread_coord_tmp,
|
||||
idx_diff_ps_ys);
|
||||
@@ -362,54 +242,13 @@ struct tile_window_linear
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
|
||||
{
|
||||
return TileDstr::is_static();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr void
|
||||
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
|
||||
{
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// move thread's window adaptor coordinate and bottom tensor coordinate
|
||||
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
|
||||
template <typename ATopIndex>
|
||||
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
WindowAdaptorCoord& window_adaptor_thread_coord,
|
||||
BottomTensorCoord& bottom_tensor_thread_coord,
|
||||
const ATopIndex& idx_diff_adaptor_top) const
|
||||
{
|
||||
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
|
||||
|
||||
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
window_adaptor_thread_coord,
|
||||
idx_diff_adaptor_top,
|
||||
idx_diff_adaptor_bottom);
|
||||
|
||||
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
|
||||
bottom_tensor_thread_coord,
|
||||
idx_diff_adaptor_bottom);
|
||||
}
|
||||
|
||||
template <index_t i_access>
|
||||
CK_TILE_DEVICE static constexpr auto get_bottom_linear_coordinate(number<i_access>)
|
||||
{
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
constexpr auto idx_ys = SFC_Ys::get_index(number<i_access>{});
|
||||
using ys_to_rhs_major =
|
||||
typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
using ys_to_rhs_major = typename decltype(
|
||||
typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
|
||||
|
||||
constexpr auto modified_idx_ys = generate_tuple(
|
||||
[&](auto i_dim_y) {
|
||||
@@ -424,9 +263,9 @@ struct tile_window_linear
|
||||
return number<idx_ys[i_dim_y]>{};
|
||||
}
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr auto adaptor_ = TileDstr{}.get_ps_ys_to_xs_adaptor();
|
||||
constexpr auto adaptor_ = typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor();
|
||||
constexpr auto idx_ =
|
||||
container_concat(make_tuple(number<0>{}, number<0>{}), modified_idx_ys);
|
||||
|
||||
@@ -443,8 +282,8 @@ struct tile_window_linear
|
||||
{
|
||||
// this case usually is a LDS window, everything is known at compile time.
|
||||
// we directly use BottomTensorView transform to compute the offset, in case padding
|
||||
auto bottom_tensor_coord =
|
||||
make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
|
||||
auto bottom_tensor_coord = make_tensor_coordinate(
|
||||
typename Base::BottomTensorView{}.get_tensor_descriptor(), linear_coord);
|
||||
return bottom_tensor_coord.get_offset();
|
||||
}
|
||||
else
|
||||
@@ -455,7 +294,7 @@ struct tile_window_linear
|
||||
// since that would introduce runtime length (so can't use linear offset)
|
||||
constexpr index_t linear_offset = [&]() {
|
||||
constexpr auto x_idx_ = linear_coord;
|
||||
constexpr auto x_len_ = TileDstr{}.get_lengths();
|
||||
constexpr auto x_len_ = typename Base::TileDstr{}.get_lengths();
|
||||
static_assert(x_idx_.size() == x_len_.size());
|
||||
constexpr index_t x_dims_ = x_idx_.size();
|
||||
index_t cu_stride_ = 1;
|
||||
@@ -471,17 +310,15 @@ struct tile_window_linear
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
|
||||
|
||||
template <index_t i_access = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
|
||||
auto issue = [&](auto i_access_) {
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
@@ -494,35 +331,30 @@ struct tile_window_linear
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset,
|
||||
bottom_tensor_flag,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
#if 1
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
|
||||
// write into distributed tensor
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
|
||||
return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j)
|
||||
: idx_diff_ys[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j / traits::PackedSize];
|
||||
vec_value
|
||||
.template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
|
||||
});
|
||||
#else
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
static_assert(d % traits::ScalarPerVector == 0);
|
||||
|
||||
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
|
||||
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
|
||||
#endif
|
||||
};
|
||||
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
@@ -535,10 +367,10 @@ struct tile_window_linear
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
|
||||
|
||||
@@ -553,35 +385,29 @@ struct tile_window_linear
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset,
|
||||
bottom_tensor_flag,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
#if 1
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
|
||||
// write into distributed tensor
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
|
||||
return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j)
|
||||
: idx_diff_ys[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j / traits::PackedSize];
|
||||
vec_value
|
||||
.template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
|
||||
});
|
||||
#else
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
static_assert(d % traits::ScalarPerVector == 0);
|
||||
|
||||
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
|
||||
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
|
||||
#endif
|
||||
};
|
||||
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
@@ -598,15 +424,17 @@ struct tile_window_linear
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
static constexpr index_t YElementSize =
|
||||
TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
|
||||
static_assert(YElementSize % (traits::PackedSize * traits::ScalarPerVector) == 0);
|
||||
typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
|
||||
static_assert(YElementSize % (Base::Traits::PackedSize * Base::Traits::ScalarPerVector) ==
|
||||
0);
|
||||
using vectorized_tbuf =
|
||||
array<vector_t, YElementSize / (traits::PackedSize * traits::ScalarPerVector)>;
|
||||
array<vector_t,
|
||||
YElementSize / (Base::Traits::PackedSize * Base::Traits::ScalarPerVector)>;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
|
||||
|
||||
@@ -614,7 +442,7 @@ struct tile_window_linear
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
constexpr auto pre_nop_ = [&]() {
|
||||
if constexpr(pre_nop && i_access_ == 0 &&
|
||||
BottomTensorView::buffer_view::get_address_space() ==
|
||||
Base::BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::global)
|
||||
return bool_constant<true>{};
|
||||
else
|
||||
@@ -630,11 +458,11 @@ struct tile_window_linear
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
|
||||
traits::PackedSize;
|
||||
static_assert(d % traits::ScalarPerVector == 0);
|
||||
Base::Traits::PackedSize;
|
||||
static_assert(d % Base::Traits::ScalarPerVector == 0);
|
||||
|
||||
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
|
||||
dst_vec_tbuf.template at<d / traits::ScalarPerVector>(),
|
||||
this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
|
||||
dst_vec_tbuf.template at<d / Base::Traits::ScalarPerVector>(),
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset /**/,
|
||||
bottom_tensor_flag,
|
||||
@@ -665,7 +493,7 @@ struct tile_window_linear
|
||||
// currently we only support everything is non linear dim
|
||||
// actually it's not performant if we have linear dim(e.g. fast changing)
|
||||
static_assert(NumAccess_NonLinear == NumAccess);
|
||||
static_assert(BottomTensorView::buffer_view::get_address_space() ==
|
||||
static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::global);
|
||||
|
||||
// issues * warps * lanes
|
||||
@@ -691,7 +519,7 @@ struct tile_window_linear
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
m0_set_with_memory(m0_init_value); // This should be wave independent
|
||||
|
||||
using vector_t = typename traits::vector_t;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
|
||||
LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
|
||||
|
||||
@@ -710,7 +538,7 @@ struct tile_window_linear
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess]; // get this flag anyway
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);
|
||||
|
||||
// move thread coordinate
|
||||
@@ -735,7 +563,7 @@ struct tile_window_linear
|
||||
// currently we only support everything is non linear dim
|
||||
// actually it's not performant if we have linear dim(e.g. fast changing)
|
||||
static_assert(NumAccess_NonLinear == NumAccess);
|
||||
static_assert(BottomTensorView::buffer_view::get_address_space() ==
|
||||
static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
|
||||
address_space_enum::global);
|
||||
|
||||
#if defined(__gfx950__)
|
||||
@@ -758,7 +586,7 @@ struct tile_window_linear
|
||||
lds_coord.get_offset();
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
@@ -800,7 +628,7 @@ struct tile_window_linear
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess];
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
@@ -817,16 +645,71 @@ struct tile_window_linear
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
}
|
||||
|
||||
template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose() const
|
||||
{
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
this->template load_transpose_linear<Policy>(
|
||||
dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename Policy,
|
||||
typename DistributedTensor,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_transpose_linear(DistributedTensor& dst_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
constexpr auto group_func = Policy::group_func;
|
||||
|
||||
auto issue = [&](auto i_access_) {
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
|
||||
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess];
|
||||
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
this->get_bottom_tensor_view().template get_transpose_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, 0);
|
||||
// write into distributed tensor
|
||||
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
},
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t linear_distributed_index =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
|
||||
vec_value.template get_as<typename Base::DataType>()[j];
|
||||
});
|
||||
};
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
}
|
||||
|
||||
template <index_t i_access = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<typename Base::DataType,
|
||||
typename Base::TileDstr>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
auto issue = [&](auto i_access_) {
|
||||
@@ -841,22 +724,23 @@ struct tile_window_linear
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset,
|
||||
bottom_tensor_flag,
|
||||
@@ -868,13 +752,15 @@ struct tile_window_linear
|
||||
}
|
||||
|
||||
template <index_t i_access = -1>
|
||||
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access> = {}) const
|
||||
CK_TILE_DEVICE void
|
||||
store_raw(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
|
||||
dstr_tensor,
|
||||
number<i_access> = {}) const
|
||||
{
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
static constexpr bool oob_conditional_check = true;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
@@ -890,20 +776,21 @@ struct tile_window_linear
|
||||
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
|
||||
Base::Traits::PackedSize;
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view()
|
||||
this->get_bottom_tensor_view()
|
||||
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
|
||||
bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
|
||||
};
|
||||
@@ -912,15 +799,17 @@ struct tile_window_linear
|
||||
}
|
||||
|
||||
template <index_t i_access = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
CK_TILE_DEVICE void
|
||||
update(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
|
||||
dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
auto issue = [&](auto i_access_) {
|
||||
@@ -936,22 +825,23 @@ struct tile_window_linear
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
|
||||
this->get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset,
|
||||
bottom_tensor_flag,
|
||||
@@ -963,16 +853,18 @@ struct tile_window_linear
|
||||
}
|
||||
|
||||
template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
|
||||
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
CK_TILE_DEVICE void
|
||||
update_raw(const static_distributed_tensor<typename Base::DataType, typename Base::TileDstr>&
|
||||
dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using vector_t = typename Base::Traits::vector_t;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
auto issue = [&](auto i_access_) {
|
||||
@@ -988,22 +880,23 @@ struct tile_window_linear
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) {
|
||||
static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
number<Base::NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
traits::PackedSize;
|
||||
Base::Traits::PackedSize;
|
||||
|
||||
vec_value.template get_as<DataType>()(j / traits::PackedSize) =
|
||||
vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
|
||||
this->get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset,
|
||||
bottom_tensor_flag,
|
||||
@@ -1014,14 +907,10 @@ struct tile_window_linear
|
||||
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
}
|
||||
|
||||
// move thread's bottom tensor coordinate
|
||||
// [x0', x1', ... ] ==> [offset]
|
||||
// also move window-origin
|
||||
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
|
||||
// *_extended() functions act like virtual functions with a default implementation existing
|
||||
// in the base class
|
||||
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex& step)
|
||||
{
|
||||
window_origin_ += step;
|
||||
|
||||
static_for<0, NumAccess, 1>{}([&](auto i_access) {
|
||||
constexpr auto IAccess = number<i_access>{};
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
|
||||
@@ -1030,7 +919,7 @@ struct tile_window_linear
|
||||
|
||||
if constexpr(need_update_non_linear_coord)
|
||||
{
|
||||
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
|
||||
move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
|
||||
cached_coords_(non_linear_id),
|
||||
step);
|
||||
}
|
||||
@@ -1039,30 +928,29 @@ struct tile_window_linear
|
||||
auto tmp_coords = cached_coords_[non_linear_id];
|
||||
constexpr auto linear_coord = get_bottom_linear_coordinate(IAccess);
|
||||
move_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord);
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord);
|
||||
|
||||
cached_flags_(IAccess) = coordinate_has_valid_offset_assuming_top_index_is_valid(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), tmp_coords);
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), tmp_coords);
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&)
|
||||
{
|
||||
window_origin_ = new_window_origin;
|
||||
|
||||
auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
TileDstr{}.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(make_tuple(get_warp_id(), get_lane_id()),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimY>{})));
|
||||
typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(
|
||||
make_tuple(get_warp_id(), get_lane_id()),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimY>{})));
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
|
||||
auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// future load/store() calls (might allocate more registers)
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
using SFC_Ys = typename Base::Traits::SFC_Ys;
|
||||
|
||||
static_for<0, NumAccess, 1>{}([&](auto i_access) {
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
|
||||
@@ -1072,16 +960,17 @@ struct tile_window_linear
|
||||
if constexpr(need_save_non_linear_coord)
|
||||
{
|
||||
cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
|
||||
cached_window_adaptor_coords_(non_linear_id) = window_adaptor_thread_coord_tmp;
|
||||
}
|
||||
|
||||
if constexpr(i_access != (NumAccess - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord_tmp,
|
||||
bottom_tensor_thread_coord_tmp,
|
||||
idx_diff_ps_ys);
|
||||
@@ -1089,28 +978,10 @@ struct tile_window_linear
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
|
||||
|
||||
// this is the bottom tensor view
|
||||
// [x0', x1', ...] ==> [offset]
|
||||
BottomTensorView bottom_tensor_view_;
|
||||
|
||||
//
|
||||
WindowLengths window_lengths_;
|
||||
|
||||
// origin ([x0', x1', ...]) of window on bottom tensor
|
||||
BottomTensorIndex window_origin_;
|
||||
|
||||
// Tile tensor distribution, which contains:
|
||||
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
|
||||
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
|
||||
TileDstr tile_dstr_;
|
||||
|
||||
// this contains:
|
||||
array<BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
|
||||
// added for gfx950
|
||||
array<WindowAdaptorCoord, traits::NumAccess_NonLinear> cached_window_adaptor_coords_;
|
||||
array<bool, traits::NumAccess> cached_flags_;
|
||||
array<typename Base::BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
|
||||
array<typename Base::WindowAdaptorCoord, traits::NumAccess_NonLinear> cached_window_adaptor_coords_;
|
||||
array<bool, Base::Traits::NumAccess> cached_flags_;
|
||||
};
|
||||
|
||||
#undef WINDOW_DISPATCH_ISSUE
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
|
||||
@@ -127,4 +128,44 @@ struct is_any_of<CompareTo, FirstType, Rest...>
|
||||
{
|
||||
};
|
||||
|
||||
// Helper to check if a type is a specialization of a given template
|
||||
template <typename Test, template <typename...> class RefTemplate>
|
||||
struct is_specialization_of : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <template <typename...> class RefTemplate, typename... Args>
|
||||
struct is_specialization_of<RefTemplate<Args...>, RefTemplate> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
// Helper to get a tuple element or default type
namespace detail {

// Primary template: the in-bounds case. The out-of-bounds case is taken by the partial
// specialization below, so std::tuple_element is only instantiated when IsWithinBounds is
// true and is never asked for an index past the end of the tuple.
template <bool IsWithinBounds, std::size_t Idx, typename Tuple, typename DefaultType>
struct tuple_element_or_default_dispatch
{
    using type = typename std::tuple_element<Idx, Tuple>::type;
};

// Out-of-bounds case: ignore Idx and Tuple and yield the fallback type.
template <std::size_t Idx, typename Tuple, typename DefaultType>
struct tuple_element_or_default_dispatch<false, Idx, Tuple, DefaultType>
{
    using type = DefaultType;
};

} // namespace detail
|
||||
|
||||
template <typename Tuple_, std::size_t Idx, typename DefaultType>
|
||||
struct tuple_element_or_default
|
||||
{
|
||||
using Tuple = remove_cvref_t<Tuple_>;
|
||||
static constexpr bool is_within_bounds = Idx < std::tuple_size_v<Tuple>;
|
||||
using type = typename detail::
|
||||
tuple_element_or_default_dispatch<is_within_bounds, Idx, Tuple, DefaultType>::type;
|
||||
};
|
||||
template <typename Tuple_, std::size_t Idx, typename DefaultType>
|
||||
using tuple_element_or_default_t =
|
||||
typename tuple_element_or_default<Tuple_, Idx, DefaultType>::type;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -9,7 +9,9 @@
|
||||
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/host/fill.hpp"
|
||||
#include "ck_tile/host/flush_icache.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/joinable_thread.hpp"
|
||||
@@ -25,6 +27,7 @@
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_fused_moe.hpp"
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_im2col.hpp"
|
||||
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
|
||||
@@ -34,5 +37,7 @@
|
||||
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
|
||||
#include "ck_tile/host/reference/reference_softmax.hpp"
|
||||
#include "ck_tile/host/reference/reference_topk.hpp"
|
||||
#include "ck_tile/host/rotating_buffers.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/host/timer.hpp"
|
||||
|
||||
@@ -18,16 +18,36 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/** @brief 8-bit floating point type */
|
||||
using F8 = ck_tile::fp8_t;
|
||||
/** @brief 8-bit brain floating point type */
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
/** @brief 16-bit floating point (half precision) type */
|
||||
using F16 = ck_tile::half_t;
|
||||
/** @brief 16-bit brain floating point type */
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
/** @brief 32-bit floating point (single precision) type */
|
||||
using F32 = float;
|
||||
/** @brief 8-bit signed integer type */
|
||||
using I8 = int8_t;
|
||||
/** @brief 32-bit signed integer type */
|
||||
using I32 = int32_t;
|
||||
|
||||
/**
|
||||
* @brief Calculate relative error threshold for numerical comparisons
|
||||
*
|
||||
* Calculates the relative error threshold based on the mantissa bits and characteristics
|
||||
* of the data types involved in the computation.
|
||||
*
|
||||
* @tparam ComputeDataType Type used for computation
|
||||
* @tparam OutDataType Type used for output
|
||||
* @tparam AccDataType Type used for accumulation (defaults to ComputeDataType)
|
||||
* @param number_of_accumulations Number of accumulation operations performed
|
||||
* @return Relative error threshold based on data type characteristics
|
||||
*/
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
{
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using F16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(
|
||||
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
|
||||
@@ -72,16 +92,23 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
return std::max(acc_error, midway_error);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate absolute error threshold for numerical comparisons
|
||||
*
|
||||
* Calculates the absolute error threshold based on the maximum possible value and
|
||||
* the characteristics of the data types involved in the computation.
|
||||
*
|
||||
* @tparam ComputeDataType Type used for computation
|
||||
* @tparam OutDataType Type used for output
|
||||
* @tparam AccDataType Type used for accumulation (defaults to ComputeDataType)
|
||||
* @param max_possible_num Maximum possible value in the computation
|
||||
* @param number_of_accumulations Number of accumulation operations performed
|
||||
* @return Absolute error threshold based on data type characteristics and maximum value
|
||||
*/
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
|
||||
CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
|
||||
const int number_of_accumulations = 1)
|
||||
{
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using F16 = ck_tile::half_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(
|
||||
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
|
||||
@@ -128,6 +155,16 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
return std::max(acc_error, midway_error);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Stream operator overload for vector output
|
||||
*
|
||||
* Provides a formatted string representation of a vector, useful for debugging and logging.
|
||||
*
|
||||
* @tparam T Type of vector elements
|
||||
* @param os Output stream
|
||||
* @param v Vector to output
|
||||
* @return Reference to the output stream
|
||||
*/
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
@@ -145,6 +182,66 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
return os << "]";
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check for size mismatch between output and reference ranges
|
||||
*
|
||||
* Verifies that the output and reference ranges are the same size.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if sizes mismatch
|
||||
* @return True if sizes mismatch, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
CK_TILE_HOST bool check_size_mismatch(const Range& out,
|
||||
const RefRange& ref,
|
||||
const std::string& msg = "Error: Incorrect results!")
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Report error statistics for numerical comparisons
|
||||
*
|
||||
* Outputs statistics about numerical comparison errors including count and maximum error.
|
||||
*
|
||||
* @param err_count Number of errors found
|
||||
* @param max_err Maximum error value encountered
|
||||
* @param total_size Total number of elements compared
|
||||
*/
|
||||
CK_TILE_HOST void report_error_stats(int err_count, double max_err, std::size_t total_size)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(total_size) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check errors between floating point ranges using the specified tolerances.
|
||||
*
|
||||
* Compares two ranges of floating point values within specified relative and absolute tolerances.
|
||||
* This overload handles standard floating point types except half precision floating point.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if check fails
|
||||
* @param rtol Relative tolerance
|
||||
* @param atol Absolute tolerance
|
||||
* @param allow_infinity_ref Whether to allow infinity in reference values
|
||||
* @return True if check passes, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
@@ -158,12 +255,9 @@ check_err(const Range& out,
|
||||
double atol = 3e-6,
|
||||
bool allow_infinity_ref = false)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
|
||||
if(check_size_mismatch(out, ref, msg))
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto is_infinity_error = [=](auto o, auto r) {
|
||||
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
||||
@@ -196,15 +290,27 @@ check_err(const Range& out,
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
report_error_stats(err_count, max_err, ref.size());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check errors between floating point ranges using the specified tolerances
|
||||
*
|
||||
* Compares two ranges of brain floating point values within specified relative and absolute
|
||||
* tolerances.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if check fails
|
||||
* @param rtol Relative tolerance
|
||||
* @param atol Absolute tolerance
|
||||
* @param allow_infinity_ref Whether to allow infinity in reference values
|
||||
* @return True if check passes, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
@@ -217,12 +323,8 @@ check_err(const Range& out,
|
||||
double atol = 1e-3,
|
||||
bool allow_infinity_ref = false)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
if(check_size_mismatch(out, ref, msg))
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto is_infinity_error = [=](auto o, auto r) {
|
||||
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
||||
@@ -256,15 +358,28 @@ check_err(const Range& out,
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
report_error_stats(err_count, max_err, ref.size());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check errors between half precision floating point ranges
|
||||
*
|
||||
* Compares two ranges of half precision floating point values within specified tolerances.
|
||||
* This specialization handles the specific requirements and characteristics of half precision
|
||||
* floating point comparisons.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if check fails
|
||||
* @param rtol Relative tolerance
|
||||
* @param atol Absolute tolerance
|
||||
* @param allow_infinity_ref Whether to allow infinity in reference values
|
||||
* @return True if check passes, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
typename std::enable_if<
|
||||
std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
@@ -277,12 +392,8 @@ check_err(const Range& out,
|
||||
double atol = 1e-3,
|
||||
bool allow_infinity_ref = false)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
if(check_size_mismatch(out, ref, msg))
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto is_infinity_error = [=](auto o, auto r) {
|
||||
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
||||
@@ -315,15 +426,26 @@ check_err(const Range& out,
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
report_error_stats(err_count, max_err, ref.size());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check errors between integer ranges
|
||||
*
|
||||
* Compares two ranges of integer values with an absolute tolerance.
|
||||
* This specialization handles integer types and optionally int4_t when the
|
||||
* experimental bit int extension is enabled.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if check fails
|
||||
* @param atol Absolute tolerance
|
||||
* @return True if check passes, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_integral_v<ranges::range_value_t<Range>> &&
|
||||
@@ -339,12 +461,8 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
double = 0,
|
||||
double atol = 0)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
if(check_size_mismatch(out, ref, msg))
|
||||
return false;
|
||||
}
|
||||
|
||||
bool res{true};
|
||||
int err_count = 0;
|
||||
@@ -370,15 +488,28 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
report_error_stats(err_count, static_cast<double>(max_err), ref.size());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check errors between FP8 ranges
|
||||
*
|
||||
* Specialized comparison for 8-bit floating point values that takes into account
|
||||
* the unique characteristics and limitations of FP8 arithmetic, including
|
||||
* rounding point distances and special handling of infinity values.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if check fails
|
||||
* @param max_rounding_point_distance Maximum allowed distance between rounding points
|
||||
* @param atol Absolute tolerance
|
||||
* @param allow_infinity_ref Whether to allow infinity in reference values
|
||||
* @return True if check passes, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, fp8_t>),
|
||||
@@ -390,12 +521,8 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
double atol = 1e-1,
|
||||
bool allow_infinity_ref = false)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
if(check_size_mismatch(out, ref, msg))
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto is_infinity_error = [=](auto o, auto r) {
|
||||
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
||||
@@ -447,15 +574,27 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
report_error_stats(err_count, max_err, ref.size());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check errors between BF8 ranges
|
||||
*
|
||||
* Specialized comparison for 8-bit brain floating point values that considers
|
||||
* the specific numerical properties and error characteristics of the BF8 format.
|
||||
*
|
||||
* @tparam Range Type of output range
|
||||
* @tparam RefRange Type of reference range
|
||||
* @param out Output range to check
|
||||
* @param ref Reference range to check against
|
||||
* @param msg Error message to display if check fails
|
||||
* @param rtol Relative tolerance
|
||||
* @param atol Absolute tolerance
|
||||
* @param allow_infinity_ref Whether to allow infinity in reference values
|
||||
* @return True if check passes, false otherwise
|
||||
*/
|
||||
template <typename Range, typename RefRange>
|
||||
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
|
||||
std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
|
||||
@@ -467,12 +606,8 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
double atol = 1e-3,
|
||||
bool allow_infinity_ref = false)
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
if(check_size_mismatch(out, ref, msg))
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto is_infinity_error = [=](auto o, auto r) {
|
||||
const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
|
||||
@@ -505,11 +640,7 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
|
||||
}
|
||||
if(!res)
|
||||
{
|
||||
const float error_percent =
|
||||
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
|
||||
std::cerr << "max err: " << max_err;
|
||||
std::cerr << ", number of errors: " << err_count;
|
||||
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
|
||||
report_error_stats(err_count, max_err, ref.size());
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -20,10 +20,35 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Container for storing data in GPU device memory
|
||||
* @brief Manages device memory allocation and host-device data transfers
|
||||
*
|
||||
* DeviceMem encapsulates GPU memory management operations using HIP runtime API.
|
||||
* It provides functionality for allocating device memory, transferring data between
|
||||
* host and device, and performing basic memory operations.
|
||||
*
|
||||
* Key features:
|
||||
* - Automatic memory allocation and deallocation
|
||||
* - Host-to-device and device-to-host data transfers
|
||||
* - Memory initialization operations
|
||||
* - Integration with HostTensor for simplified data handling
|
||||
*
|
||||
* Usage example:
|
||||
* ```
|
||||
* // Allocate device memory
|
||||
* HostTensor<float> AHostData({256});
|
||||
* DeviceMem d_mem(AHostData.get_element_space_size_in_bytes());
|
||||
*
|
||||
* // Transfer data to device
|
||||
* HostTensor<float> AHostTensor({256});
|
||||
* d_mem.ToDevice(AHostData.data());
|
||||
*
|
||||
* // Retrieve data from device
|
||||
* HostTensor<float> ResultHostTensor({256});
|
||||
* d_mem.FromDevice(ResultHostTensor.data());
|
||||
* ```
|
||||
*/
|
||||
struct DeviceMem
|
||||
|
||||
{
|
||||
DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
|
||||
DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
|
||||
@@ -163,8 +188,8 @@ struct DeviceMem
|
||||
}
|
||||
}
|
||||
|
||||
void* mpDeviceBuf;
|
||||
std::size_t mMemSize;
|
||||
void* mpDeviceBuf; ///< pointer to device buffer
|
||||
std::size_t mMemSize; ///< size of device buffer in bytes
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
56
include/ck_tile/host/device_prop.hpp
Normal file
56
include/ck_tile/host/device_prop.hpp
Normal file
@@ -0,0 +1,56 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
 * @brief FNV-1a hash of a string, usable in constant expressions (e.g. `case` labels)
 *
 * Each byte is XOR-ed into the running hash, which is then multiplied by 16777619 with
 * unsigned wrap-around. The default seed 2166136261 is the 32-bit FNV offset basis.
 *
 * The loop is iterative so that evaluation depth does not grow with the string length.
 * A recursive formulation hits the compiler's constexpr depth limit and can exhaust the
 * stack at run time on long inputs.
 *
 * @param str Bytes to hash
 * @param h   Starting hash value; pass a previous result to continue hashing a prefix
 * @return Hash of `str` starting from `h` (`h` itself when `str` is empty)
 */
constexpr unsigned int fnv1a_hash(std::string_view str, unsigned int h = 2166136261u)
{
    for(const char c : str)
    {
        // Go through unsigned char so bytes >= 0x80 are not sign-extended before the XOR.
        h = (h ^ static_cast<unsigned char>(c)) * 16777619u;
    }
    return h;
}
|
||||
inline std::string get_device_name()
|
||||
{
|
||||
hipDeviceProp_t props{};
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return std::string();
|
||||
}
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return std::string();
|
||||
}
|
||||
const std::string raw_name(props.gcnArchName);
|
||||
const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
|
||||
switch(fnv1a_hash(name))
|
||||
{
|
||||
// https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
|
||||
case fnv1a_hash("Ellesmere"):
|
||||
case fnv1a_hash("Baffin"):
|
||||
case fnv1a_hash("RacerX"):
|
||||
case fnv1a_hash("Polaris10"):
|
||||
case fnv1a_hash("Polaris11"):
|
||||
case fnv1a_hash("Tonga"):
|
||||
case fnv1a_hash("Fiji"):
|
||||
case fnv1a_hash("gfx800"):
|
||||
case fnv1a_hash("gfx802"):
|
||||
case fnv1a_hash("gfx804"): return "gfx803";
|
||||
case fnv1a_hash("Vega10"):
|
||||
case fnv1a_hash("gfx901"): return "gfx900";
|
||||
case fnv1a_hash("10.3.0 Sienna_Cichlid 18"): return "gfx1030";
|
||||
default: return name;
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif
|
||||
@@ -17,13 +17,31 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Functor for filling a range with randomly generated values from a uniform distribution.
|
||||
*
|
||||
* This struct provides functionality to fill iterators or ranges with random values
|
||||
* generated from a uniform distribution. It supports both single-threaded and
|
||||
* multi-threaded operation.
|
||||
*
|
||||
* @tparam T The target type for the generated values.
|
||||
*
|
||||
* @note The multi-threaded implementation is not guaranteed to provide perfectly
|
||||
* distributed values across threads.
|
||||
*
|
||||
* @example
|
||||
*
|
||||
* // Direct usage without creating a separate variable:
|
||||
* ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_host_tensor);
|
||||
*/
|
||||
template <typename T>
|
||||
struct FillUniformDistribution
|
||||
{
|
||||
float a_{-5.f};
|
||||
float b_{5.f};
|
||||
std::optional<uint32_t> seed_{11939};
|
||||
// ATTENTION: threaded does not guarantee the distribution between thread
|
||||
// ATTENTION: Whether to use multi-threading (note: not guaranteed to be perfectly distributed
|
||||
// across threads).
|
||||
bool threaded = false;
|
||||
|
||||
template <typename ForwardIter>
|
||||
|
||||
30
include/ck_tile/host/flush_icache.hpp
Normal file
30
include/ck_tile/host/flush_icache.hpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
/**
 * @brief Kernel that executes a single `s_icache_inv` instruction followed by sixteen
 *        `s_nop 0` instructions.
 *
 * The kernel takes no arguments and has no operands or clobbers in its inline assembly.
 * NOTE(review): going by the mnemonic and the file name (flush_icache.hpp), the intent is
 * to invalidate the GPU instruction cache, e.g. between benchmark runs - confirm against
 * the target ISA guide and the call sites.
 */
static __global__ void flush_cache()
|
||||
{
|
||||
// The trailing s_nop padding presumably gives the invalidate time to take effect before
// the wave ends - TODO confirm the required count for the targeted architectures.
asm __volatile__("s_icache_inv \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t" ::
|
||||
:);
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -85,6 +85,19 @@ CK_TILE_HOST auto construct_f_unpack_args(F, T args)
|
||||
return construct_f_unpack_args_impl<F>(args, std::make_index_sequence<N>{});
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Descriptor for tensors in host memory.
|
||||
*
|
||||
* HostTensorDescriptor manages the shape (dimensions) and memory layout (strides)
|
||||
* of a tensor in host memory. It provides functionality to:
|
||||
* - Store tensor dimensions and strides
|
||||
* - Calculate default strides for contiguous memory layout
|
||||
* - Convert multi-dimensional indices to linear memory offsets
|
||||
* - Query tensor metadata (dimensions, element counts, etc.)
|
||||
*
|
||||
* The class supports both automatic stride calculation for contiguous memory layout
|
||||
* and custom strides for more complex memory patterns.
|
||||
*/
|
||||
struct HostTensorDescriptor
|
||||
{
|
||||
HostTensorDescriptor() = default;
|
||||
@@ -138,12 +151,35 @@ struct HostTensorDescriptor
|
||||
}
|
||||
|
||||
std::size_t get_num_of_dimension() const { return mLens.size(); }
|
||||
/**
|
||||
* @brief Calculates the total number of elements in the tensor.
|
||||
*
|
||||
* Computes the product of all dimension lengths to determine the
|
||||
* total element count in the tensor.
|
||||
*
|
||||
* @pre The lengths array (mLens) and strides array (mStrides) must have
|
||||
* the same size.
|
||||
*
|
||||
* @return The total number of elements in the tensor.
|
||||
*/
|
||||
std::size_t get_element_size() const
|
||||
{
|
||||
assert(mLens.size() == mStrides.size());
|
||||
return std::accumulate(
|
||||
mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies<std::size_t>());
|
||||
}
|
||||
/**
|
||||
* @brief Calculates the total element space required for the tensor in memory.
|
||||
*
|
||||
* This method computes the minimum size of contiguous memory needed to store
|
||||
* all elements of the tensor, taking into account the tensor's dimensions and
|
||||
* strides. The calculation is based on the formula: 1 + max((length_i - 1) * stride_i)
|
||||
* across all dimensions.
|
||||
*
|
||||
* Dimensions with length 0 are skipped in this calculation.
|
||||
*
|
||||
* @return The size of the tensor's element space (number of elements).
|
||||
*/
|
||||
std::size_t get_element_space_size() const
|
||||
{
|
||||
std::size_t space = 1;
|
||||
@@ -165,6 +201,18 @@ struct HostTensorDescriptor
|
||||
|
||||
const std::vector<std::size_t>& get_strides() const { return mStrides; }
|
||||
|
||||
/**
|
||||
* @brief Calculates the linear offset from multi-dimensional indices.
|
||||
*
|
||||
* Converts a set of N-dimensional indices into a single linear offset by computing
|
||||
* the inner product of the indices with the tensor's strides.
|
||||
*
|
||||
* @tparam Is Parameter pack of index types (should be convertible to std::size_t)
|
||||
* @param is Variable number of indices, one for each dimension of the tensor
|
||||
* @return std::size_t Linear offset corresponding to the given multi-dimensional indices
|
||||
*
|
||||
* @pre The number of indices must match the number of dimensions in the tensor
|
||||
*/
|
||||
template <typename... Is>
|
||||
std::size_t GetOffsetFromMultiIndex(Is... is) const
|
||||
{
|
||||
@@ -173,7 +221,16 @@ struct HostTensorDescriptor
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
|
||||
std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
|
||||
/**
|
||||
* @brief Calculates the linear memory offset from a multi-dimensional index
|
||||
*
|
||||
* Computes the linear offset by performing an inner product between the provided
|
||||
* multi-dimensional indices and the tensor's strides.
|
||||
*
|
||||
* @param iss Vector containing the multi-dimensional indices
|
||||
* @return The calculated linear offset as a size_t
|
||||
*/
|
||||
std::size_t GetOffsetFromMultiIndex(const std::vector<std::size_t>& iss) const
|
||||
{
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
@@ -194,8 +251,8 @@ struct HostTensorDescriptor
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::size_t> mLens;
|
||||
std::vector<std::size_t> mStrides;
|
||||
std::vector<std::size_t> mLens; ///< Lengths of each dimension
|
||||
std::vector<std::size_t> mStrides; ///< Strides for each dimension
|
||||
};
|
||||
|
||||
template <typename New2Old>
|
||||
@@ -483,9 +540,12 @@ struct HostTensor
|
||||
return mData[GetOffsetFromMultiIndex(is...)];
|
||||
}
|
||||
|
||||
T& operator()(std::vector<std::size_t> idx) { return mData[GetOffsetFromMultiIndex(idx)]; }
|
||||
T& operator()(const std::vector<std::size_t>& idx)
|
||||
{
|
||||
return mData[GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
|
||||
const T& operator()(std::vector<std::size_t> idx) const
|
||||
const T& operator()(const std::vector<std::size_t>& idx) const
|
||||
{
|
||||
return mData[GetOffsetFromMultiIndex(idx)];
|
||||
}
|
||||
@@ -662,6 +722,8 @@ struct HostTensor
|
||||
file << type_convert<float>(itm) << std::endl;
|
||||
else if(dtype == "int")
|
||||
file << type_convert<int>(itm) << std::endl;
|
||||
else if(dtype == "int8_t")
|
||||
file << static_cast<int>(type_convert<ck_tile::int8_t>(itm)) << std::endl;
|
||||
else
|
||||
// TODO: we didn't implement operator<< for all custom
|
||||
// data types, here fall back to float in case compile error
|
||||
@@ -681,6 +743,24 @@ struct HostTensor
|
||||
Data mData;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Creates a host tensor descriptor with specified dimensions and layout
|
||||
*
|
||||
* Constructs a HostTensorDescriptor with appropriate strides based on whether the tensor
|
||||
* layout is row-major or column-major. This is determined via the compile-time template
|
||||
* parameter `is_row_major`.
|
||||
*
|
||||
* @tparam is_row_major Compile-time flag indicating if the layout is row-major (true) or
|
||||
* column-major (false)
|
||||
*
|
||||
* @param row Number of rows in the tensor
|
||||
* @param col Number of columns in the tensor
|
||||
* @param stride Stride between adjacent rows (for row-major) or columns (for column-major)
|
||||
*
|
||||
* @return HostTensorDescriptor with shape {row, col} and strides:
|
||||
* - For row-major: {stride, 1}
|
||||
* - For column-major: {1, stride}
|
||||
*/
|
||||
template <bool is_row_major>
|
||||
auto host_tensor_descriptor(std::size_t row,
|
||||
std::size_t col,
|
||||
@@ -698,6 +778,7 @@ auto host_tensor_descriptor(std::size_t row,
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool is_row_major>
|
||||
auto get_default_stride(std::size_t row,
|
||||
std::size_t col,
|
||||
@@ -718,5 +799,4 @@ auto get_default_stride(std::size_t row,
|
||||
else
|
||||
return stride;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -11,6 +11,13 @@
|
||||
#include <cstddef>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#define LOW_CU_PROCESSORS 80
|
||||
#define HIGH_CU_PROCESSORS 228
|
||||
#define OPTIMAL_LATENCY_LOW_CU_PROCESSORS 0.005
|
||||
#define OPTIMAL_LATENCY_HIGH_CU_PROCESSORS 0.0015
|
||||
#define OPTIMAL_LATENCY_SAFE_MARGIN 0.01
|
||||
|
||||
template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
#if CK_TILE_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
@@ -81,6 +88,8 @@ CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... calla
|
||||
template <typename... Callables>
|
||||
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables)
|
||||
{
|
||||
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
|
||||
|
||||
if(!s.time_kernel_)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
@@ -88,7 +97,7 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callable
|
||||
}
|
||||
|
||||
auto time_launches = [&](auto timer) {
|
||||
// warmup
|
||||
// Warmup
|
||||
for(int i = 0; i < s.cold_niters_; i++)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
@@ -114,4 +123,53 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callable
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PreprocessFunc, typename... Callables>
|
||||
CK_TILE_HOST float launch_kernel_preprocess(const stream_config& s,
|
||||
PreprocessFunc preprocess,
|
||||
Callables&&... callables)
|
||||
{
|
||||
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
|
||||
|
||||
if(!s.time_kernel_)
|
||||
{
|
||||
preprocess();
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto time_launches = [&](auto timer) {
|
||||
// Warmup
|
||||
for(int i = 0; i < s.cold_niters_; i++)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
}
|
||||
|
||||
timer.start(s.stream_id_);
|
||||
for(int i = 0; i < s.nrepeat_; i++)
|
||||
{
|
||||
preprocess();
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
}
|
||||
timer.stop(s.stream_id_);
|
||||
|
||||
hipDeviceProp_t deviceProps;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
|
||||
|
||||
float preprocess_offset = (deviceProps.multiProcessorCount >= HIGH_CU_PROCESSORS)
|
||||
? OPTIMAL_LATENCY_HIGH_CU_PROCESSORS
|
||||
: (deviceProps.multiProcessorCount == LOW_CU_PROCESSORS)
|
||||
? OPTIMAL_LATENCY_LOW_CU_PROCESSORS
|
||||
: OPTIMAL_LATENCY_SAFE_MARGIN;
|
||||
return (timer.duration() - preprocess_offset * s.nrepeat_) / s.nrepeat_;
|
||||
};
|
||||
|
||||
if(s.is_gpu_timer_)
|
||||
{
|
||||
return time_launches(gpu_timer{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return time_launches(cpu_timer{});
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -71,6 +71,58 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ACCElementOp,
|
||||
typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
|
||||
CK_TILE_HOST void
|
||||
reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0;
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = a_m_k(m, k);
|
||||
BDataType v_b = b_k_n(k, n);
|
||||
v_acc +=
|
||||
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
CDataType v_c = 0;
|
||||
if constexpr(DsDataType::size() == 0)
|
||||
{
|
||||
acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
|
||||
}
|
||||
else if constexpr(DsDataType::size() == 1)
|
||||
{
|
||||
acc_element_op(v_c,
|
||||
ck_tile::type_convert<float>(v_acc),
|
||||
ck_tile::type_convert<float>(ds_m_n[0](m, n)));
|
||||
}
|
||||
else if constexpr(DsDataType::size() == 2)
|
||||
{
|
||||
acc_element_op(v_c,
|
||||
ck_tile::type_convert<float>(v_acc),
|
||||
ck_tile::type_convert<float>(ds_m_n[0](m, n)),
|
||||
ck_tile::type_convert<float>(ds_m_n[1](m, n)));
|
||||
}
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
|
||||
165
include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp
Normal file
165
include/ck_tile/host/reference/reference_grouped_conv_fwd.hpp
Normal file
@@ -0,0 +1,165 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input,
|
||||
const HostTensor<WeiDataType>& weight,
|
||||
HostTensor<OutDataType>& output,
|
||||
std::vector<ck_tile::long_index_t> conv_strides,
|
||||
std::vector<ck_tile::long_index_t> conv_dilations,
|
||||
std::vector<ck_tile::long_index_t> in_left_pads,
|
||||
std::vector<ck_tile::long_index_t>)
|
||||
{
|
||||
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
weight.get_num_of_dimension() == NDimSpatial + 3 &&
|
||||
output.get_num_of_dimension() == NDimSpatial + 3))
|
||||
{
|
||||
throw std::runtime_error("wrong! inconsistent dimension");
|
||||
}
|
||||
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
auto func = [&](auto g, auto n, auto k, auto wo) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
|
||||
{
|
||||
for(std::size_t x = 0; x < weight.get_lengths()[3]; ++x)
|
||||
{
|
||||
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[3])
|
||||
{
|
||||
InDataType v_in = input(g, n, c, wi);
|
||||
WeiDataType v_wei = weight(g, k, c, x);
|
||||
v_acc += ck_tile::type_convert<float>(v_in) *
|
||||
ck_tile::type_convert<float>(v_wei);
|
||||
}
|
||||
}
|
||||
}
|
||||
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, wo) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
output.get_lengths()[0],
|
||||
output.get_lengths()[1],
|
||||
output.get_lengths()[2],
|
||||
output.get_lengths()[3])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
auto func = [&](auto g, auto n, auto k, auto ho, auto wo) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
|
||||
{
|
||||
for(std::size_t y = 0; y < weight.get_lengths()[3]; ++y)
|
||||
{
|
||||
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
for(std::size_t x = 0; x < weight.get_lengths()[4]; ++x)
|
||||
{
|
||||
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
|
||||
|
||||
if(hi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[3] &&
|
||||
wi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[4])
|
||||
{
|
||||
InDataType v_in = input(g, n, c, hi, wi);
|
||||
WeiDataType v_wei = weight(g, k, c, y, x);
|
||||
|
||||
v_acc += ck_tile::type_convert<float>(v_in) *
|
||||
ck_tile::type_convert<float>(v_wei);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, ho, wo) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
output.get_lengths()[0],
|
||||
output.get_lengths()[1],
|
||||
output.get_lengths()[2],
|
||||
output.get_lengths()[3],
|
||||
output.get_lengths()[4])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
auto func = [&](auto g, auto n, auto k, auto d_o, auto ho, auto wo) {
|
||||
float v_acc = 0;
|
||||
|
||||
for(std::size_t c = 0; c < weight.get_lengths()[2]; ++c)
|
||||
{
|
||||
for(std::size_t z = 0; z < weight.get_lengths()[3]; ++z)
|
||||
{
|
||||
auto di = static_cast<ck_tile::long_index_t>(d_o * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
for(std::size_t y = 0; y < weight.get_lengths()[4]; ++y)
|
||||
{
|
||||
auto hi = static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
|
||||
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
|
||||
for(std::size_t x = 0; x < weight.get_lengths()[5]; ++x)
|
||||
{
|
||||
auto wi = static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[2]);
|
||||
if(di >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(di) < input.get_lengths()[3] &&
|
||||
hi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(hi) < input.get_lengths()[4] &&
|
||||
wi >= 0 &&
|
||||
ck_tile::type_convert<std::size_t>(wi) < input.get_lengths()[5])
|
||||
{
|
||||
InDataType v_in = input(g, n, c, di, hi, wi);
|
||||
WeiDataType v_wei = weight(g, k, c, z, y, x);
|
||||
|
||||
v_acc += ck_tile::type_convert<float>(v_in) *
|
||||
ck_tile::type_convert<float>(v_wei);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
output(g, n, k, d_o, ho, wo) = v_acc_converted;
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func,
|
||||
output.get_lengths()[0],
|
||||
output.get_lengths()[1],
|
||||
output.get_lengths()[2],
|
||||
output.get_lengths()[3],
|
||||
output.get_lengths()[4],
|
||||
output.get_lengths()[5])(std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Ref_Conv_fwd: number of dimensions must be between 1 and 3.");
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -21,10 +21,12 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
index_t& unit_cnt,
|
||||
const index_t experts,
|
||||
const index_t unit_size,
|
||||
const index_t tokens,
|
||||
bool local_expert_masking,
|
||||
bool skip_experts_with_zero_token = true)
|
||||
{
|
||||
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
|
||||
// note: if tokens is smaller than topk_ids.mDesc.get_lengths()[0], indicating local_token case
|
||||
const index_t num_token = tokens; // topk_ids.mDesc.get_lengths()[0];
|
||||
const index_t topk = topk_ids.mDesc.get_lengths()[1];
|
||||
// allocate a temp buffer, and fill the value with [number_token|topk]
|
||||
std::vector<std::vector<IndexType>> expert_tokens(
|
||||
|
||||
102
include/ck_tile/host/rotating_buffers.hpp
Normal file
102
include/ck_tile/host/rotating_buffers.hpp
Normal file
@@ -0,0 +1,102 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType, typename BDataType>
|
||||
struct RotatingMemWrapper
|
||||
{
|
||||
RotatingMemWrapper() = delete;
|
||||
RotatingMemWrapper(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
std::size_t rotating_count_,
|
||||
std::size_t size_a_,
|
||||
std::size_t size_b_)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
rotating_count(rotating_count_),
|
||||
size_a(size_a_),
|
||||
size_b(size_b_)
|
||||
{
|
||||
p_a_grids.push_back(a_ptr);
|
||||
p_b_grids.push_back(b_ptr);
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
{
|
||||
void* pADeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf),
|
||||
const_cast<void*>(p_a_grids[0]),
|
||||
size_a_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_a_grids.push_back(pADeviceBuf);
|
||||
}
|
||||
|
||||
{
|
||||
void* pBDeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf),
|
||||
const_cast<void*>(p_b_grids[0]),
|
||||
size_b_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_b_grids.push_back(pBDeviceBuf);
|
||||
}
|
||||
}
|
||||
}
|
||||
void Next()
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
std::size_t idx = iter++ % rotating_count;
|
||||
a_ptr = p_a_grids[idx];
|
||||
b_ptr = p_b_grids[idx];
|
||||
}
|
||||
}
|
||||
void Print()
|
||||
{
|
||||
std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
|
||||
<< ", rotating_count: " << rotating_count << "}" << std::endl;
|
||||
}
|
||||
~RotatingMemWrapper() noexcept
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
// restore ptr
|
||||
a_ptr = p_a_grids[0];
|
||||
b_ptr = p_b_grids[0];
|
||||
|
||||
// free device mem
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
ck_tile::hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
|
||||
ck_tile::hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
std::size_t iter = 0;
|
||||
std::size_t rotating_count = 1;
|
||||
std::size_t size_a = 0;
|
||||
std::size_t size_b = 0;
|
||||
std::vector<const void*> p_a_grids;
|
||||
std::vector<const void*> p_b_grids;
|
||||
};
|
||||
inline void flush_icache()
|
||||
{
|
||||
hipDeviceProp_t deviceProps;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
|
||||
int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
|
||||
|
||||
ck_tile::flush_cache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -30,5 +30,7 @@ struct stream_config
|
||||
int cold_niters_ = 3;
|
||||
int nrepeat_ = 10;
|
||||
bool is_gpu_timer_ = true; // keep compatible
|
||||
bool flush_cache_ = false;
|
||||
int rotating_count_ = 1;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
45
include/ck_tile/host/stream_utils.hpp
Normal file
45
include/ck_tile/host/stream_utils.hpp
Normal file
@@ -0,0 +1,45 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime_api.h>
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
 * @brief Counts the compute units enabled in the CU mask of the given stream.
 *
 * The mask is read with `hipExtStreamGetCUMask` into a fixed buffer of 64 dwords, i.e. at
 * most 64 * 32 = 2048 compute units are considered. Each set bit counts as one compute unit.
 *
 * @param s Stream configuration whose `stream_id_` is queried.
 * @return Number of set bits in the stream's CU mask.
 */
static inline index_t get_available_compute_units(const stream_config& s)
{
    constexpr static uint32_t MAX_MASK_DWORDS = 64;

    // assume at most 64*32 = 2048 CUs
    uint32_t cu_mask[MAX_MASK_DWORDS]{};

    // Population count: each iteration clears the lowest set bit.
    auto count_set_bits = [](uint32_t dword) {
        index_t count = 0;
        while(dword != 0)
        {
            dword &= dword - 1;
            count++;
        }
        return count;
    };

    HIP_CHECK_ERROR(hipExtStreamGetCUMask(s.stream_id_, MAX_MASK_DWORDS, &cu_mask[0]));

    index_t num_cu = 0;
    for(uint32_t i = 0; i < MAX_MASK_DWORDS; i++)
    {
        num_cu += count_set_bits(cu_mask[i]);
    }

    return num_cu;
}
|
||||
|
||||
} // namespace ck_tile
|
||||
0
include/ck_tile/ops/common/utils.hpp
Executable file → Normal file
0
include/ck_tile/ops/common/utils.hpp
Executable file → Normal file
@@ -1479,5 +1479,6 @@ struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
|
||||
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -11,36 +11,52 @@ namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename DsDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
typename CLayout_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename CDElementwise_,
|
||||
index_t kBlockSize_,
|
||||
index_t kM_,
|
||||
index_t kN_,
|
||||
index_t kMWave_,
|
||||
index_t kNWave_,
|
||||
index_t kMPerXdl_,
|
||||
index_t kNPerXdl_,
|
||||
index_t kKPerXdl_,
|
||||
index_t MWave_,
|
||||
index_t NWave_,
|
||||
index_t MPerXdl_,
|
||||
index_t NPerXdl_,
|
||||
index_t KPerXdl_,
|
||||
bool isCTransposed_,
|
||||
memory_operation_enum MemoryOperation_>
|
||||
memory_operation_enum MemoryOperation_,
|
||||
index_t kNumWaveGroups_ = 1,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using CLayout = remove_cvref_t<CLayout_>;
|
||||
using DsDataType = remove_cvref_t<DsDataType_>;
|
||||
using DsLayout = remove_cvref_t<DsLayout_>;
|
||||
using ELayout = remove_cvref_t<ELayout_>;
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
static constexpr index_t kMWave = kMWave_;
|
||||
static constexpr index_t kNWave = kNWave_;
|
||||
static constexpr index_t kMPerXdl = kMPerXdl_;
|
||||
static constexpr index_t kNPerXdl = kNPerXdl_;
|
||||
static constexpr index_t kKPerXdl = kKPerXdl_;
|
||||
static constexpr index_t MWave = MWave_;
|
||||
static constexpr index_t NWave = NWave_;
|
||||
static constexpr index_t MPerXdl = MPerXdl_;
|
||||
static constexpr index_t NPerXdl = NPerXdl_;
|
||||
static constexpr index_t KPerXdl = KPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
@@ -51,34 +67,31 @@ struct CShuffleEpilogue
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
|
||||
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
using ELayout = remove_cvref_t<typename Problem::ELayout>;
|
||||
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
|
||||
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr index_t kMWave = Problem::kMWave;
|
||||
static constexpr index_t kNWave = Problem::kNWave;
|
||||
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
|
||||
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
|
||||
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
|
||||
static constexpr index_t MWave = Problem::MWave;
|
||||
static constexpr index_t NWave = Problem::NWave;
|
||||
static constexpr index_t MPerXdl = Problem::MPerXdl;
|
||||
static constexpr index_t NPerXdl = Problem::NPerXdl;
|
||||
static constexpr index_t KPerXdl = Problem::KPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
|
||||
static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
|
||||
|
||||
using WG = WarpGemmMfmaDispatcher<ADataType,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
kMPerXdl,
|
||||
kNPerXdl,
|
||||
kKPerXdl,
|
||||
isCTransposed>;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t NumDTensor = Problem::NumDTensor;
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
/**
|
||||
* @brief Get the vector store size for C tensor.
|
||||
*
|
||||
@@ -89,96 +102,242 @@ struct CShuffleEpilogue
|
||||
*
|
||||
* @return The vector store size for C tensor.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
|
||||
{
|
||||
constexpr index_t MaxVectorStoreSize = 16;
|
||||
return MaxVectorStoreSize / sizeof(ODataType);
|
||||
if constexpr(FixedVectorSize)
|
||||
{
|
||||
return VectorSizeC;
|
||||
}
|
||||
constexpr index_t max_vector_size = 16;
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(NPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(ODataType)));
|
||||
}
|
||||
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(MPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(ODataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported ELayout!");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the vector store size for Di tensor.
|
||||
*
|
||||
* @return The vector store size for Di tensor.
|
||||
*/
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
|
||||
{
|
||||
constexpr index_t max_vector_size = 16;
|
||||
using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(NPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(DiDataType)));
|
||||
}
|
||||
else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(MPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(DiDataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported DLayout!");
|
||||
}
|
||||
return max_vector_size / sizeof(DiDataType);
|
||||
}
|
||||
/**
|
||||
* @brief Shuffle tile configuration parameters
|
||||
*
|
||||
* @details These parameters control the number of XDL tiles processed per wave in each shuffle
|
||||
* iteration:
|
||||
* - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
|
||||
* - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
|
||||
*/
|
||||
static constexpr auto shuffle_tile_tuple = [] {
|
||||
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
|
||||
if constexpr(elem_per_thread >= GetVectorSizeC())
|
||||
{
|
||||
return std::make_tuple(1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
|
||||
(kMPerBlock % num_xdl_shuffles == 0),
|
||||
"kMPerBlock must be divisible by MPerXdl*MWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogue");
|
||||
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
|
||||
(kNPerBlock % num_xdl_shuffles == 0),
|
||||
"kNPerBlock must be divisible by NPerXdl*NWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogue");
|
||||
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
|
||||
}
|
||||
}
|
||||
}();
|
||||
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
|
||||
static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
|
||||
|
||||
static constexpr auto MNPerIterationShuffle = [] {
|
||||
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
|
||||
constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
|
||||
if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
|
||||
return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
|
||||
else
|
||||
return std::make_tuple(m_val, n_val);
|
||||
}();
|
||||
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
|
||||
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
|
||||
|
||||
using WG = WarpGemmMfmaDispatcher<ADataType,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
KPerXdl,
|
||||
isCTransposed>;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
|
||||
{
|
||||
// N is contiguous dimension
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
|
||||
make_tuple(number<kNWave * kNPerXdl>{}, number<1>{}));
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<NPerIterationShuffle>{}, number<1>{}));
|
||||
}
|
||||
// M is contiguous dimension
|
||||
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
|
||||
make_tuple(number<1>{}, number<kMWave * kMPerXdl>{}));
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<1>{}, number<MPerIterationShuffle>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported CLayout!");
|
||||
static_assert(false, "Unsupported ELayout!");
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
|
||||
{
|
||||
constexpr auto block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<NumNXdlPerWavePerShuffle, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
|
||||
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
|
||||
|
||||
return block_dstr_encoding;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType);
|
||||
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
|
||||
}
|
||||
|
||||
template <typename ODramWindow, typename OAccTile>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem)
|
||||
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem)
|
||||
{
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
|
||||
const index_t iMWarp = get_warp_id() / kNWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
|
||||
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
|
||||
|
||||
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<ODataType*>(p_smem), lds_block_desc);
|
||||
auto in_lds_window =
|
||||
make_tile_window(o_lds_block,
|
||||
make_tuple(number<kMPerXdl>{}, number<kNPerXdl>{}),
|
||||
{number<kMPerXdl>{} * iMWarp, number<kNPerXdl>{} * iNWarp});
|
||||
auto out_lds_window =
|
||||
make_tile_window(o_lds_block,
|
||||
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
|
||||
{0, 0});
|
||||
|
||||
auto in_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0},
|
||||
LdsTileDistr);
|
||||
|
||||
auto out_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<kMPerXdl * kMWave, kNPerXdl * kNWave>>;
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
|
||||
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPattern2D<kBlockSize,
|
||||
kMPerIteration,
|
||||
kNPerIteration,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked>;
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
CWarpTensor c_warp_in_tensor;
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
block_sync_lds();
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (kMPerXdl * kMWave)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (kNPerXdl * kNWave)>{};
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
|
||||
c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
|
||||
const auto c_warp_in_tensor_casted = cast_tile<ODataType>(c_warp_in_tensor);
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
|
||||
block_sync_lds();
|
||||
store_tile(in_lds_window, c_warp_in_tensor_casted);
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
block_sync_lds();
|
||||
|
||||
const auto c_out_tensor =
|
||||
load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
const auto ds_tensor = generate_tuple(
|
||||
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie(
|
||||
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
@@ -191,7 +350,13 @@ struct CShuffleEpilogue
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx],
|
||||
{step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -75,7 +75,6 @@ struct Default2DEpilogue
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
|
||||
{
|
||||
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
@@ -101,6 +100,15 @@ struct Default2DEpilogue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Overload that additionally accepts a "Ds" DRAM-windows argument.
//
// The Ds argument and the trailing void* are both ignored; the call is
// forwarded unchanged to the two-argument operator() with the same
// ODramWindowTmp / OAccTile template arguments, and its result is returned.
//
// NOTE(review): presumably exists so this epilogue can be invoked with the same
// (window, acc tile, ds windows, smem) argument list as epilogues that do
// consume Ds tensors — confirm against the callers.
template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& /* unused */,
|
||||
void* = nullptr)
|
||||
{
|
||||
return operator()<ODramWindowTmp, OAccTile>(o_dram_window_tmp, o_acc_tile);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
@@ -114,6 +122,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
|
||||
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
|
||||
@@ -149,7 +159,9 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
else
|
||||
{
|
||||
// In this case each thread has just a single item in Ndim
|
||||
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
|
||||
return (WG::WarpGemmAttribute::Impl::kCNLane *
|
||||
WG::WarpGemmAttribute::Impl::kBNBlock) /
|
||||
WG::kN;
|
||||
}
|
||||
}
|
||||
// M is contiguous dimension
|
||||
@@ -158,7 +170,9 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
if constexpr(isCTransposed)
|
||||
{
|
||||
// In this case each thread has just a single item in Mdim
|
||||
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
|
||||
return (WG::WarpGemmAttribute::Impl::kCNLane *
|
||||
WG::WarpGemmAttribute::Impl::kAMBlock) /
|
||||
WG::kN;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -177,6 +191,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
static_assert(false, "Unsupported CLayout!");
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; }
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -95,7 +95,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
// constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
// constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
// constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
|
||||
@@ -104,11 +104,11 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
if constexpr(LanesPerK >= warpSize)
|
||||
if constexpr(LanesPerK >= WarpSize)
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % warpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / warpSize;
|
||||
static_assert(LanesPerK % WarpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / WarpSize;
|
||||
if constexpr(wavesPerK > NumWarps)
|
||||
{
|
||||
// TODO: need multiple issues along K to load all data
|
||||
@@ -121,11 +121,11 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<wavesPerM>{}, // m1
|
||||
number<wavesPerK>{}, // k0
|
||||
number<warpSize>{}, // k1
|
||||
number<WarpSize>{}, // k1
|
||||
number<KVector>{}), // k2
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // k0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
|
||||
number<WarpSize * KVector + KPad>{}, // k0
|
||||
number<KVector>{}, // k1
|
||||
number<1>{}), // k2
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
@@ -136,7 +136,7 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<NumIssues>{}),
|
||||
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
|
||||
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
|
||||
make_merge_transform(make_tuple(number<WarpSize>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
@@ -146,8 +146,8 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
else
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(warpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
|
||||
static_assert(WarpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
|
||||
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
@@ -156,9 +156,9 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
number<NumWarps>{}, // m2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<Block_K>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // m2
|
||||
number<WarpSize * KVector + KPad>{}, // m2
|
||||
number<KVector>{}, // k0
|
||||
number<1>{}), // k1
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
|
||||
@@ -447,6 +447,7 @@ struct FlatmmKernel
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
|
||||
a_block_window, b_flat_block_window, num_loop, smem_ptr);
|
||||
|
||||
@@ -454,7 +455,7 @@ struct FlatmmKernel
|
||||
auto& c_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, smem_ptr);
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const
|
||||
|
||||
@@ -75,7 +75,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) || defined(USING_MFMA_32x32x16)
|
||||
constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
@@ -91,64 +90,68 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad;
|
||||
constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp;
|
||||
constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp;
|
||||
#endif
|
||||
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
});
|
||||
|
||||
#elif defined(USING_MFMA_32x32x16)
|
||||
static_for<0,
|
||||
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
|
||||
1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
#endif
|
||||
if constexpr(WG::kM == 16 && WG::kN == 16)
|
||||
{
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
});
|
||||
}
|
||||
else if constexpr(WG::kM == 32 && WG::kN == 32 &&
|
||||
(A_LDS_Read_Inst_Num / 2 >
|
||||
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
|
||||
{
|
||||
static_for<0,
|
||||
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
|
||||
1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
|
||||
@@ -19,55 +19,61 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
if constexpr(MPerXdl == 16 && NPerXdl == 16)
|
||||
{
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
#elif defined(USING_MFMA_32x32x16)
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
return a_lds_block_desc;
|
||||
#endif
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
/*xor*/
|
||||
#if 0
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
@@ -112,7 +118,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return a_lds_block_desc;
|
||||
return a_lds_block_desc;
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -138,6 +144,21 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
|
||||
}
|
||||
|
||||
// Compile-time helper that derives a per-load K element count for B from the
// warp tile shape of Problem::BlockGemmShape:
//   - warp-tile N == 32 -> warp-tile K / 2
//   - warp-tile N == 16 -> warp-tile K / 4
// Any other warp-tile N is rejected by the static_assert in the else branch.
//
// NOTE(review): the divisors 2 and 4 presumably equal (wave size 64) / N, i.e.
// the number of lane groups along K — confirm, and confirm that sizeof(BDataType)
// not appearing here is intentional.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
if constexpr(TileShape::WarpTile::at(TileShape::idxN) == 32)
|
||||
{
|
||||
return TileShape::WarpTile::at(TileShape::idxK) / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Only N == 16 is supported besides N == 32; fail the build otherwise.
static_assert(TileShape::WarpTile::at(TileShape::idxN) == 16);
|
||||
return TileShape::WarpTile::at(TileShape::idxK) / 4;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
@@ -189,7 +210,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
@@ -232,19 +253,17 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KBPerLoad =
|
||||
Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
|
||||
@@ -15,7 +15,36 @@
|
||||
#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM
|
||||
#define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
namespace internal {
|
||||
// Soft-sign logits capping combined with the softmax scale. Both paths evaluate
//
//   softmax_scale * logits / (1 + |logits * logits_soft_cap_rcp|)
//
// Parameters:
//   softmax_scale       - multiplier applied to the logit
//   logits              - the raw logit value
//   logits_soft_cap_rcp - multiplied with the logit inside the absolute value
//                         (by its name, the reciprocal of the soft-cap value)
//
// The hand-written asm path is compiled only when targeting gfx90a / gfx94,
// the default soft-cap mode is SOFTSIGN, and CK_TILE_ATTENTION_USE_SOFTSIGN_ASM
// is non-zero; otherwise the plain C++ expression in the #else branch is used.
__device__ inline float
|
||||
exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp)
|
||||
{
|
||||
#if(defined(__gfx90a__) || defined(__gfx94__)) && \
|
||||
(CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
|
||||
CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
|
||||
/// NOTICE: Make sure softmax_scale is stored in SGPR
|
||||
float result, numerator, denominator;
|
||||
// Instruction sequence, in order:
//   denominator = logits * logits_soft_cap_rcp
//   denominator = |denominator| + 1.0
//   denominator = rcp(denominator)
//   numerator   = softmax_scale * logits
//   result      = numerator * denominator
// numerator/denominator are early-clobber outputs ("=&v") because they are
// written before all inputs have been read; softmax_scale uses the "s"
// constraint, matching the SGPR requirement noted above.
asm volatile(
|
||||
"v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n"
|
||||
"v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n"
|
||||
"v_rcp_f32_e32 %[denominator], %[denominator]\n"
|
||||
"v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n"
|
||||
"v_mul_f32_e32 %[result], %[numerator], %[denominator]"
|
||||
: [numerator] "=&v"(numerator), [denominator] "=&v"(denominator), [result] "=v"(result)
|
||||
: [softmax_scale] "s"(softmax_scale),
|
||||
[logits] "v"(logits),
|
||||
[logits_soft_cap_rcp] "v"(logits_soft_cap_rcp));
|
||||
return result;
|
||||
#else
|
||||
// Portable fallback: same formula expressed with rcp<float> and abs.
return softmax_scale * logits * rcp<float>(1.f + abs(logits * logits_soft_cap_rcp));
|
||||
#endif
|
||||
}
|
||||
} // namespace internal
|
||||
|
||||
template <typename ImplMask>
|
||||
struct StandardAttentionParams
|
||||
@@ -169,8 +198,8 @@ struct LogitsSoftCap
|
||||
return params.logits_soft_cap *
|
||||
tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
|
||||
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
|
||||
return params.sm_scale * type_convert<float>(logits) *
|
||||
rcp<float>(1.f + abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
|
||||
return internal::exp2_soft_sign_impl(
|
||||
params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
@@ -239,9 +268,8 @@ struct ComposedAttention
|
||||
return params.logits_soft_cap *
|
||||
tanh_fast<float>(type_convert<float>(logits) * params.logits_soft_cap_rcp);
|
||||
#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN
|
||||
return params.sm_scale * type_convert<float>(logits) *
|
||||
rcp<float>(1.f +
|
||||
abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
|
||||
return internal::exp2_soft_sign_impl(
|
||||
params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
|
||||
@@ -316,56 +316,56 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const void* kv_last_page_lens,
|
||||
ck_tile::index_t page_block_size,
|
||||
#endif
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_randval,
|
||||
ck_tile::index_t batch_stride_lse,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -468,51 +468,51 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* rand_val_ptr,
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
ck_tile::index_t nhead_ratio_qk,
|
||||
int32_t num_total_pages,
|
||||
const void* kv_indptr,
|
||||
const void* kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
const void* kv_last_page_lens,
|
||||
ck_tile::index_t page_block_size,
|
||||
#endif
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
float scale_s,
|
||||
float scale_p,
|
||||
float scale_o,
|
||||
float logits_soft_cap,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_v,
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_randval,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -651,8 +651,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -672,7 +679,15 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -53,6 +53,8 @@ struct FmhaFwdKernel
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
|
||||
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
@@ -257,6 +259,11 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval = 0;
|
||||
};
|
||||
|
||||
// Kernel-argument group that carries the min_seqlen_q value.
// NOTE(review): presumably mixed into the kernel kargs only when kSkipMinSeqlenQ is enabled - confirm against the kargs base list.
struct FmhaFwdSkipMinSeqlenQKargs
|
||||
{
|
||||
// Defaults to 0 when not explicitly assigned.
ck_tile::index_t min_seqlen_q = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchModeKargs
|
||||
: FmhaFwdCommonKargs,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
|
||||
@@ -287,7 +294,8 @@ struct FmhaFwdKernel
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
|
||||
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
@@ -664,6 +672,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
@@ -698,6 +707,7 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
{}, // placeholder for dropout
|
||||
{}, // placeholder for logits_soft_cap
|
||||
{}, // placeholder for min_seqlen_q
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
@@ -753,6 +763,10 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
if constexpr(kSkipMinSeqlenQ)
|
||||
{
|
||||
kargs.min_seqlen_q = min_seqlen_q;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -794,6 +808,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
@@ -833,6 +848,7 @@ struct FmhaFwdKernel
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
min_seqlen_q,
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
@@ -875,6 +891,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset)
|
||||
@@ -914,6 +931,7 @@ struct FmhaFwdKernel
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
mask_type,
|
||||
min_seqlen_q,
|
||||
p_drop,
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
|
||||
@@ -969,7 +987,15 @@ struct FmhaFwdKernel
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -989,7 +1015,15 @@ struct FmhaFwdKernel
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1053,6 +1087,14 @@ struct FmhaFwdKernel
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
|
||||
if constexpr(kSkipMinSeqlenQ)
|
||||
{
|
||||
if(kargs.seqlen_q <= kargs.min_seqlen_q)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
|
||||
@@ -561,7 +561,16 @@ struct FmhaFwdSplitKVKernel
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(
|
||||
(gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
@@ -6,8 +6,9 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -498,6 +499,16 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
#else
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
#if(defined(__gfx90a__) || defined(__gfx94__)) && \
|
||||
(CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
|
||||
CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
|
||||
// Avoid data hazard if v_mfma is followed by inline asm consumer
|
||||
// instructions. In this case, compiler won't add s_nop for us
|
||||
if(i == s_acc.thread_buf_.size() / 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
#endif
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#endif
|
||||
@@ -691,12 +702,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
}
|
||||
|
||||
const auto p = [&]() {
|
||||
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
// For fp32 to fp16,
|
||||
// impl::cast_tile_pk_fp16_fp32 would cause precision issue,
|
||||
// since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero.
|
||||
return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#else
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
return impl::cast_tile_pk_fp16_fp32<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
else
|
||||
return cast_tile<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#endif
|
||||
}();
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
|
||||
@@ -53,6 +53,7 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
|
||||
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
|
||||
@@ -653,12 +653,19 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
}
|
||||
|
||||
const auto p = [&]() {
|
||||
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
// For fp32 to fp16,
|
||||
// impl::cast_tile_pk_fp16_fp32 would cause precision issue,
|
||||
// since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero.
|
||||
return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#else
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
return impl::cast_tile_pk_fp16_fp32<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
else
|
||||
return cast_tile<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
#endif
|
||||
}();
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
|
||||
@@ -28,6 +28,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
@@ -54,6 +55,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should be able to overwrite this
|
||||
@@ -127,7 +129,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding>
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
@@ -146,6 +150,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& /* unused */,
|
||||
const AttentionVariantParams& /* unused */,
|
||||
const BlockIndices& /* unused */,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
@@ -890,7 +897,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
@@ -901,6 +910,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
@@ -921,6 +933,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -448,19 +448,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad = KPack;
|
||||
|
||||
static_assert(warpSize * KVector >= kKPerBlock &&
|
||||
warpSize * KVector % kKPerBlock == 0);
|
||||
static_assert(WarpSize * KVector >= kKPerBlock &&
|
||||
WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector;
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK;
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK;
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
|
||||
return NumIssues * NumWarps * (warpSize * KVector + kPad);
|
||||
return NumIssues * NumWarps * (WarpSize * KVector + kPad);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -516,18 +516,18 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad =
|
||||
KPack; // for async-copy, this pad is between warps. Optimize this for lds_read speed
|
||||
|
||||
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK =
|
||||
kKPerBlock / KVector; // how many lane (within a wave) to load K
|
||||
constexpr index_t LaneGroups =
|
||||
warpSize /
|
||||
WarpSize /
|
||||
LanesPerK; // how many groups (within a wave), they may load different N, but same K
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
@@ -538,9 +538,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
number<NumWarps>{}, // n2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{},
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<kKPerBlock>{},
|
||||
number<warpSize * KVector + kPad>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<KVector>{},
|
||||
number<1>{}),
|
||||
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
|
||||
@@ -569,18 +569,18 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
|
||||
|
||||
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
// constexpr index_t SingleKSize = NumIssues * NumWarps * (warpSize * KVector + kPad);
|
||||
// constexpr index_t SingleKSize = NumIssues * NumWarps * (WarpSize * KVector + kPad);
|
||||
// constexpr index_t SingleVSize =
|
||||
// MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
constexpr index_t BufferSize =
|
||||
@@ -594,8 +594,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
number<kKPerBlock / KPack>{}, // k0
|
||||
number<KPack>{}), // k1
|
||||
make_tuple(number<BufferSize>{},
|
||||
number<NumWarps*(warpSize * KVector + kPad)>{},
|
||||
number<warpSize * KVector + kPad>{},
|
||||
number<NumWarps*(WarpSize * KVector + kPad)>{},
|
||||
number<WarpSize * KVector + kPad>{},
|
||||
number<kKPerBlock>{},
|
||||
number<KPack>{},
|
||||
number<1>{}),
|
||||
@@ -746,13 +746,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
|
||||
|
||||
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
|
||||
static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
|
||||
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
|
||||
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
|
||||
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
|
||||
|
||||
@@ -787,12 +787,29 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t N0 = kNPerBlock / N1; // P
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
|
||||
{
|
||||
constexpr index_t kNPack = 32;
|
||||
static_assert(kNPerBlock % kNPack == 0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = 2;
|
||||
constexpr index_t N1_m = kNPack / N2;
|
||||
constexpr index_t N0_m = kNPerBlock / kNPack;
|
||||
constexpr index_t K1 = get_warp_size() / N1_m;
|
||||
constexpr index_t K2_m = kKPerBlock / K1;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<1, 2, 1>, // N0 K2 N2
|
||||
sequence<0, 2, 2>>{});
|
||||
}
|
||||
else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
@@ -860,12 +877,28 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
|
||||
{
|
||||
constexpr index_t kNPack = 32;
|
||||
static_assert(kNPerBlock % kNPack == 0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = 2;
|
||||
constexpr index_t N1_m = kNPack / N2;
|
||||
constexpr index_t N0_m = kNPerBlock / kNPack;
|
||||
constexpr index_t K1 = get_warp_size() / N1_m;
|
||||
constexpr index_t K2_m = kKPerBlock / K1;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>, // K0, K1 N0
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<1, 1, 2>, // N0 K2 <-> N2
|
||||
sequence<0, 2, 2>>{});
|
||||
}
|
||||
else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
|
||||
@@ -19,7 +19,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kStoreLSE_,
|
||||
bool kHasDropout_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
|
||||
struct TileFmhaTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
@@ -33,6 +34,7 @@ struct TileFmhaTraits
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
|
||||
@@ -101,7 +101,7 @@ struct FusedMoeGemmShape
|
||||
static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
|
||||
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;
|
||||
|
||||
static constexpr index_t BlockSize = warpSize * NumWarps;
|
||||
static constexpr index_t BlockSize = get_warp_size() * NumWarps;
|
||||
|
||||
// some assert
|
||||
static_assert(Block_M0 == Block_M1);
|
||||
|
||||
@@ -127,37 +127,21 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_ex
|
||||
constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt
|
||||
// at least 2 lines, one for sub_token unroll, one for cumsum
|
||||
// should be enough
|
||||
if ((total_ / target_occupancy_) < ((cumsum_bufs+sub_unroll) * smem_cols)) {
|
||||
if ((total_ / 1) < ((cumsum_bufs+sub_unroll) * smem_cols))
|
||||
throw std::runtime_error("too many num_experts, can't allocate smem");
|
||||
target_occupancy_ = 1;
|
||||
}
|
||||
|
||||
int r = total_ / target_occupancy_ / smem_cols;
|
||||
|
||||
// Note: allocate at least cumsum_bufs + sub_unroll rows. Otherwise, fall back to the mp kernel
|
||||
if(r < (cumsum_bufs + sub_unroll))
|
||||
return cumsum_bufs;
|
||||
|
||||
// round to a multiple of sub_unroll
|
||||
int r_for_sub_token = r - cumsum_bufs;
|
||||
r_for_sub_token = min(r_for_sub_token, tokens_);
|
||||
r_for_sub_token = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll;
|
||||
r_for_sub_token = max(r_for_sub_token, 1);
|
||||
r_for_sub_token = r_for_sub_token / sub_unroll * sub_unroll;
|
||||
int r_token_min = (tokens_ + sub_unroll - 1) / sub_unroll * sub_unroll;
|
||||
r_for_sub_token = min(r_for_sub_token, r_token_min);
|
||||
|
||||
if(r_for_sub_token > 1)
|
||||
{
|
||||
int r_unroll_ = r_for_sub_token / sub_unroll;
|
||||
|
||||
|
||||
// round to 1x/2x/4x/8x number of sub_unroll
|
||||
int clz_ = __builtin_clz(r_unroll_); // 0b1:31 0b2:30, 0b3:30, 0b4:29
|
||||
int mask_ = (1 << (31 - clz_)) - 1;
|
||||
|
||||
|
||||
mask_ = mask_ > 0b111 ? 0b111 : mask_; //clamp to 8x at most
|
||||
mask_ = ~mask_;
|
||||
|
||||
r_for_sub_token = (r_unroll_ & mask_) * sub_unroll;
|
||||
}
|
||||
|
||||
// final check
|
||||
if( (r_for_sub_token + cumsum_bufs * smem_cols * target_occupancy_ ) >= total_ ) {
|
||||
// final check, but usually should not happen
|
||||
if( ((r_for_sub_token + cumsum_bufs) * smem_cols * target_occupancy_ ) > total_ ) {
|
||||
throw std::runtime_error("can't run this kernel, request LDS over size");
|
||||
}
|
||||
|
||||
@@ -167,6 +151,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_ex
|
||||
return ck_tile::make_tuple(smem_rows, smem_cols);
|
||||
}
|
||||
|
||||
// a return value of 0 or negative means LDS is not enough
|
||||
CK_TILE_HOST index_t moe_sorting_get_sub_token(int tokens_, int num_experts_)
|
||||
{
|
||||
auto [r_, c_] = moe_sorting_get_smem_row_col(tokens_, num_experts_);
|
||||
@@ -180,7 +165,8 @@ struct MoeSortingHostArgs
|
||||
const void* p_topk_ids; // [token, topk]
|
||||
const void* p_weights; // [token, topk]
|
||||
|
||||
const void* p_local_expert_mask;
|
||||
const void* p_local_expert_mask; // [experts]
|
||||
const void* p_local_tokens; // [1] if not nullptr, tokens read from here
|
||||
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
@@ -192,7 +178,7 @@ struct MoeSortingHostArgs
|
||||
void* p_ws; // size is moe_sorting_get_workspace_size()
|
||||
// if return zero, then could be nullptr
|
||||
// must be cleared before use
|
||||
index_t tokens;
|
||||
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens used for ws/LDS calculation
|
||||
index_t unit_size; // this is the M_a of fused-moe kernel
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
@@ -216,6 +202,7 @@ struct MoeSortingKernel
|
||||
const void* p_topk_ids;
|
||||
const void* p_weights;
|
||||
const void* p_local_expert_mask;
|
||||
const void* p_local_tokens; // [1] if not nullptr, tokens read from here
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
@@ -268,6 +255,7 @@ struct MoeSortingKernel
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_weights = h.p_weights;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_sorted_token_ids = h.p_sorted_token_ids;
|
||||
k.p_sorted_weights = h.p_sorted_weights;
|
||||
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
|
||||
@@ -278,9 +266,13 @@ struct MoeSortingKernel
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
|
||||
const auto blocks = BlockSize(h);
|
||||
// NOTE: tokens could come from p_local_tokens, so here this variable is useless
|
||||
// hence moe_align_block_size_kernel() will not behave properly if we have dynamic tokens
|
||||
// (indeed we can deprecate moe_align_block_size_kernel)
|
||||
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
|
||||
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
// NOTE: tokens could come from p_local_tokens, so here the LDS will be bigger than expected (but works)
|
||||
k.smem_rows = [&](){
|
||||
auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
|
||||
(void) c_;
|
||||
@@ -396,7 +388,7 @@ struct MoeSortingKernel
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
template <typename T, typename F, index_t wave_size_ = warpSize>
|
||||
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
|
||||
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
@@ -633,7 +625,7 @@ struct MoeSortingKernel
|
||||
{
|
||||
const index_t prefill_token = topk_mdiv.div(numel);
|
||||
// TODO: only support expert-tile like 8, 16, 32
|
||||
static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
|
||||
static constexpr index_t experts_per_wave = get_warp_size() / Problem::ExpertTile;
|
||||
{
|
||||
index_t eid = tid / experts_per_wave;
|
||||
index_t expert_offset = cumsum[eid] +
|
||||
@@ -701,7 +693,7 @@ struct MoeSortingKernel
|
||||
void* smem) const
|
||||
{
|
||||
const index_t tid = static_cast<index_t>(threadIdx.x);
|
||||
const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize);
|
||||
const index_t wid = __builtin_amdgcn_readfirstlane(tid / get_warp_size());
|
||||
const index_t lid = __lane_id();
|
||||
constexpr index_t block_size = 256; // blockDim.x;
|
||||
const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
|
||||
@@ -806,7 +798,7 @@ struct MoeSortingKernel
|
||||
// NOTE: under this block can never use __syncthreads!
|
||||
int i_e_ = 0;
|
||||
int local_cumsum_ = 0;
|
||||
for(; i_e_ < num_experts; i_e_ += warpSize)
|
||||
for(; i_e_ < num_experts; i_e_ += get_warp_size())
|
||||
{
|
||||
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
|
||||
int local_cnt = smem_cumsum(i_e_ + lid + 1);
|
||||
@@ -851,7 +843,7 @@ struct MoeSortingKernel
|
||||
// cumsum padded in case local cumsum is zero, but
|
||||
// pre_cumsum has value, which will result in
|
||||
// zero local cumsum(but we want at least padded)
|
||||
wave_cumsum<int, warpSize>(local_cumsum_);
|
||||
wave_cumsum<int, get_warp_size()>(local_cumsum_);
|
||||
|
||||
if((i_e_ + lid) < num_experts)
|
||||
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
|
||||
@@ -859,7 +851,7 @@ struct MoeSortingKernel
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
local_masking += pre_cumsum_masking;
|
||||
wave_cumsum<int, warpSize>(local_masking);
|
||||
wave_cumsum<int, get_warp_size()>(local_masking);
|
||||
if((i_e_ + lid) < num_experts)
|
||||
smem_cumdup(i_e_ + lid + 1) = local_masking;
|
||||
}
|
||||
@@ -869,7 +861,7 @@ struct MoeSortingKernel
|
||||
// than 0 (which is not what we want)
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
}
|
||||
if((lid + i_e_ - warpSize) == (num_experts - 1))
|
||||
if((lid + i_e_ - get_warp_size()) == (num_experts - 1))
|
||||
{
|
||||
*p_total_tokens_post_pad = local_cumsum_;
|
||||
}
|
||||
@@ -1024,8 +1016,19 @@ struct MoeSortingKernel
|
||||
}
|
||||
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
|
||||
extern __shared__ char smem[];
|
||||
|
||||
#if MOE_SORTING_USE_EX_KERNEL
|
||||
(void)numel;
|
||||
index_t tokens_ = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
return moe_align_block_size_kernel_ex(
|
||||
static_cast<const IndexType*>(kargs.p_topk_ids),
|
||||
static_cast<const WeightType*>(kargs.p_weights),
|
||||
@@ -1035,7 +1038,7 @@ struct MoeSortingKernel
|
||||
static_cast<IndexType*>(kargs.p_sorted_expert_ids),
|
||||
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
|
||||
kargs.num_experts,
|
||||
kargs.tokens,
|
||||
tokens_,
|
||||
kargs.unit_size_mdiv,
|
||||
kargs.topk_mdiv,
|
||||
kargs.expert_mdiv,
|
||||
@@ -1106,7 +1109,7 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
|
||||
return chunk * sizeof(index_t);
|
||||
};
|
||||
|
||||
template <typename T, typename F, index_t wave_size_ = warpSize>
|
||||
template <typename T, typename F, index_t wave_size_ = get_warp_size()>
|
||||
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
@@ -1260,6 +1263,7 @@ CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, long_index_t buf_by
|
||||
|
||||
} // namespace impl
|
||||
|
||||
// TODO: tokens could be from
|
||||
// prefer to run mp kernel if is not oneshot
|
||||
CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_)
|
||||
{
|
||||
@@ -1366,9 +1370,11 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
index_t tokens;
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
|
||||
// used for ws/LDS calculation
|
||||
index_t mesh_stride; // mesh_stride for p_expert_mesh
|
||||
mdiv topk_mdiv;
|
||||
};
|
||||
@@ -1388,11 +1394,12 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.tokens = h.tokens;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.tokens = h.tokens;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
return k;
|
||||
}
|
||||
|
||||
@@ -1409,7 +1416,26 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
|
||||
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
|
||||
index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
index_t rounded_tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
|
||||
Problem::SubTokenTile;
|
||||
}
|
||||
else
|
||||
return tokens;
|
||||
}();
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
|
||||
@@ -1420,8 +1446,15 @@ struct MoeSortingMultiPhaseKernel_P0
|
||||
IndexType eid = x[j.value]; // ext_vector_type must use int to []
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
if(static_cast<index_t>(curr_token_id) < tokens)
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
}
|
||||
else
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1471,7 +1504,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / warpSize * sizeof(IndexType);
|
||||
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -1513,8 +1546,8 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
|
||||
}
|
||||
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
|
||||
// reduce cross wave
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
@@ -1527,7 +1560,7 @@ struct MoeSortingMultiPhaseKernel_P1
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++)
|
||||
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
@@ -1557,6 +1590,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
{
|
||||
const void* p_topk_ids; // [tokens, topk]
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_expert_sem; // [1]
|
||||
@@ -1584,6 +1618,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
Kargs k;
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
@@ -1595,8 +1630,17 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
|
||||
k.wg_count = WGCounts(h);
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
k.wg_count = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return GridSize(h);
|
||||
}
|
||||
else
|
||||
{
|
||||
return WGCounts(h);
|
||||
}
|
||||
}();
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
return k;
|
||||
}
|
||||
|
||||
@@ -1616,19 +1660,52 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize()
|
||||
{
|
||||
return BLOCK_SIZE / warpSize * sizeof(IndexType);
|
||||
return BLOCK_SIZE / get_warp_size() * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
workgroup_barrier wb{reinterpret_cast<uint32_t*>(kargs.p_expert_sem)};
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
index_t rounded_tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
|
||||
Problem::SubTokenTile;
|
||||
}
|
||||
else
|
||||
return tokens;
|
||||
}();
|
||||
index_t wg_count = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
index_t total_elem = rounded_tokens * kargs.topk / Problem::SubTokenTile;
|
||||
index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
// no more than grid_size
|
||||
return min(elem_cnt, kargs.wg_count);
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.wg_count;
|
||||
}
|
||||
}();
|
||||
|
||||
{
|
||||
using topk_id_t = ext_vector_t<IndexType, Problem::SubTokenTile>;
|
||||
|
||||
const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
|
||||
IndexType* p_expert_mesh = reinterpret_cast<IndexType*>(kargs.p_expert_mesh);
|
||||
index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem;
|
||||
@@ -1640,10 +1717,19 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
kargs.topk_mdiv.divmod(
|
||||
i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1;
|
||||
// p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1;
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
if(static_cast<index_t>(curr_token_id) < tokens)
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
}
|
||||
else
|
||||
p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
|
||||
(curr_topk_id + 1) & 0xffff;
|
||||
});
|
||||
}
|
||||
if(static_cast<index_t>(blockIdx.x) < kargs.wg_count)
|
||||
if(static_cast<index_t>(blockIdx.x) < wg_count)
|
||||
{
|
||||
wb.inc();
|
||||
}
|
||||
@@ -1657,7 +1743,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
if(eid >= kargs.num_experts)
|
||||
return;
|
||||
|
||||
wb.wait_lt(kargs.wg_count);
|
||||
wb.wait_lt(wg_count);
|
||||
|
||||
for(; eid < kargs.num_experts; eid += gridDim.x)
|
||||
{
|
||||
@@ -1700,8 +1786,8 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
|
||||
}
|
||||
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
|
||||
// reduce cross wave
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
@@ -1715,7 +1801,7 @@ struct MoeSortingMultiPhaseKernel_P01
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
index_t c = 0;
|
||||
for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++)
|
||||
for(auto i = 0; i < (BLOCK_SIZE / get_warp_size()); i++)
|
||||
{
|
||||
c += s[i];
|
||||
}
|
||||
@@ -1746,6 +1832,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_total_tokens_post_pad; // [1]
|
||||
@@ -1762,6 +1849,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
{
|
||||
Kargs k;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
|
||||
@@ -1792,7 +1880,7 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
// return 2 * BLOCK_SIZE * sizeof(IndexType);
|
||||
return (4 + 2 * BLOCK_SIZE / warpSize) * sizeof(IndexType);
|
||||
return (4 + 2 * BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
@@ -1817,8 +1905,8 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
IndexType* p_sorted_expert_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
|
||||
IndexType prev_cumsum_a = 0;
|
||||
IndexType prev_cumsum_b = 0;
|
||||
@@ -1863,22 +1951,22 @@ struct MoeSortingMultiPhaseKernel_P2
|
||||
IndexType cumsum_b = b_;
|
||||
|
||||
// Note: we first cumsum local round, then add previous cumsum
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_b);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_b);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b;
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
@@ -1957,6 +2045,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
{
|
||||
const void* p_weights;
|
||||
const void* p_local_expert_mask;
|
||||
const void* p_local_tokens;
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_expert_mesh; // [token, expert]
|
||||
@@ -1973,6 +2062,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
Kargs k;
|
||||
k.p_weights = h.p_weights;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_sorted_token_ids = h.p_sorted_token_ids;
|
||||
k.p_sorted_weights = h.p_sorted_weights;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
@@ -1993,7 +2083,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
// in byte
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
|
||||
{
|
||||
return (4 + BLOCK_SIZE / warpSize) * sizeof(IndexType);
|
||||
return (4 + BLOCK_SIZE / get_warp_size()) * sizeof(IndexType);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
@@ -2009,9 +2099,19 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / warpSize;
|
||||
int lane_id = threadIdx.x % warpSize;
|
||||
int wave_id = threadIdx.x / get_warp_size();
|
||||
int lane_id = threadIdx.x % get_warp_size();
|
||||
int e_start = p_expert_cumsum[eid];
|
||||
int e_end = p_expert_cumsum[eid + 1];
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
@@ -2034,24 +2134,24 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
{
|
||||
int i_token = i * BLOCK_SIZE + threadIdx.x;
|
||||
IndexType x = 0;
|
||||
if(i_token < kargs.tokens)
|
||||
if(i_token < tokens)
|
||||
{
|
||||
x = p_expert_mesh[eid * kargs.mesh_stride + i_token];
|
||||
}
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2081,7 +2181,7 @@ struct MoeSortingMultiPhaseKernel_P3
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor);
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
|
||||
#else
|
||||
p_sorted_token_ids[i] = tokens;
|
||||
#endif
|
||||
@@ -2096,7 +2196,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
|
||||
{
|
||||
constexpr index_t BLOCK_SIZE = 256; // hardcoded 256
|
||||
const index_t expert_cumsum_elem = num_experts_ + 1;
|
||||
return (4 + 2 * BLOCK_SIZE / warpSize + expert_cumsum_elem) * sizeof(int);
|
||||
return (4 + 2 * BLOCK_SIZE / get_warp_size() + expert_cumsum_elem) * sizeof(int);
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
@@ -2120,6 +2220,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
{
|
||||
const void* p_weights;
|
||||
const void* p_local_expert_mask; // [expert]
|
||||
const void* p_local_tokens; // [1]
|
||||
void* p_expert_mesh; // [expert, tokens]
|
||||
void* p_expert_cumsum; // [expert + 1]
|
||||
void* p_total_tokens_post_pad; // [1]
|
||||
@@ -2142,6 +2243,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
Kargs k;
|
||||
k.p_weights = h.p_weights;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_local_tokens = h.p_local_tokens;
|
||||
k.p_expert_mesh = h.p_ws;
|
||||
k.p_expert_cumsum = reinterpret_cast<void*>(
|
||||
reinterpret_cast<char*>(h.p_ws) +
|
||||
@@ -2201,15 +2303,15 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
const IndexType* p_local_expert_mask =
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask);
|
||||
IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize;
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
|
||||
IndexType* p_total_tokens_post_pad =
|
||||
reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
|
||||
IndexType* p_sorted_expert_ids =
|
||||
reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
|
||||
|
||||
const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
index_t wave_id = threadIdx.x / warpSize;
|
||||
index_t lane_id = threadIdx.x % warpSize;
|
||||
index_t wave_id = threadIdx.x / get_warp_size();
|
||||
index_t lane_id = threadIdx.x % get_warp_size();
|
||||
|
||||
IndexType prev_cumsum_a = 0;
|
||||
IndexType prev_cumsum_b = 0;
|
||||
@@ -2254,22 +2356,22 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
IndexType cumsum_b = b_;
|
||||
|
||||
// Note: we first cumsum local round, then add previous cumsum
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, warpSize>(cumsum_b);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_a);
|
||||
impl::moe_sorting_wave_cumsum<IndexType, get_warp_size()>(cumsum_b);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b;
|
||||
s[4 + wave_id] = cumsum_a;
|
||||
s[4 + wave_id + BLOCK_SIZE / get_warp_size()] = cumsum_b;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev_a = s[4 + i_w];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize];
|
||||
IndexType prev_b = s[4 + i_w + BLOCK_SIZE / get_warp_size()];
|
||||
prev_a = wave_id > i_w ? prev_a : 0; // mask out
|
||||
prev_b = wave_id > i_w ? prev_b : 0; // mask out
|
||||
cumsum_a += prev_a;
|
||||
@@ -2339,13 +2441,13 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
IndexType* s = reinterpret_cast<IndexType*>(smem);
|
||||
MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
|
||||
IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize;
|
||||
IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / get_warp_size();
|
||||
const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
|
||||
WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
|
||||
|
||||
int eid = blockIdx.x;
|
||||
int wave_id = threadIdx.x / warpSize;
|
||||
int lane_id = threadIdx.x % warpSize;
|
||||
int wave_id = threadIdx.x / get_warp_size();
|
||||
int lane_id = threadIdx.x % get_warp_size();
|
||||
int e_start = p_expert_cumsum_smem[eid];
|
||||
int e_end = p_expert_cumsum_smem[eid + 1];
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
@@ -2361,6 +2463,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
return; // skip empty expert
|
||||
}
|
||||
|
||||
index_t tokens = [&]() {
|
||||
if constexpr(Problem::LocalToken)
|
||||
{
|
||||
return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return kargs.tokens;
|
||||
}
|
||||
}();
|
||||
|
||||
// cumsum one by one
|
||||
constexpr index_t index_pack = Problem::SubTokenTile; // always packed
|
||||
using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
|
||||
@@ -2372,7 +2485,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
{
|
||||
int i_token_pack = i * BLOCK_SIZE + threadIdx.x;
|
||||
r_t x_v = 0;
|
||||
if(i_token_pack < (kargs.tokens + index_pack - 1) / index_pack)
|
||||
if(i_token_pack < (tokens + index_pack - 1) / index_pack)
|
||||
{
|
||||
x_v = reinterpret_cast<r_t*>(p_expert_mesh +
|
||||
eid * kargs.mesh_stride)[i_token_pack];
|
||||
@@ -2405,17 +2518,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
int i_topk = x - 1; // topk of this token
|
||||
int i_show = x != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2456,17 +2569,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
cumsum_store += i_show[j];
|
||||
});
|
||||
int cumsum = cumsum_store;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2511,17 +2624,17 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
int i_topk_1 = x1 - 1; // topk of this token
|
||||
int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not
|
||||
int cumsum = i_show_0 + i_show_1;
|
||||
impl::moe_sorting_wave_cumsum<int, warpSize>(cumsum);
|
||||
impl::moe_sorting_wave_cumsum<int, get_warp_size()>(cumsum);
|
||||
|
||||
__syncthreads();
|
||||
if(lane_id == warpSize - 1)
|
||||
if(lane_id == get_warp_size() - 1)
|
||||
{
|
||||
s[4 + wave_id] = cumsum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// reduce cross wave
|
||||
static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) {
|
||||
static_for<0, BLOCK_SIZE / get_warp_size() - 1, 1>{}([&](auto i_w) {
|
||||
IndexType prev = s[4 + i_w];
|
||||
prev = wave_id > i_w ? prev : 0; // mask out
|
||||
cumsum += prev;
|
||||
@@ -2569,7 +2682,7 @@ struct MoeSortingMultiPhaseKernel_P23
|
||||
for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor);
|
||||
p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
|
||||
#else
|
||||
p_sorted_token_ids[i] = tokens;
|
||||
#endif
|
||||
|
||||
@@ -31,6 +31,7 @@ template <typename IndexType_,
|
||||
index_t SubTokenTile_, // 1,2,4,8, or 0 in the future
|
||||
bool SubTokenOneShot_, // if we only loop over once or not
|
||||
bool LocalExpertMasking_, // used in EP case
|
||||
bool LocalToken_, // used in EP case
|
||||
bool SkipExpertsWithZeroTokens_ = true,
|
||||
index_t ExpertTile_ = 0>
|
||||
struct MoeSortingProblemEx
|
||||
@@ -44,6 +45,7 @@ struct MoeSortingProblemEx
|
||||
static constexpr index_t SubTokenTile = SubTokenTile_;
|
||||
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
|
||||
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
|
||||
static constexpr bool LocalToken = LocalToken_;
|
||||
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
|
||||
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
|
||||
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
|
||||
@@ -54,6 +56,7 @@ template <typename IndexType_,
|
||||
typename MeshType_,
|
||||
index_t SubTokenTile_, // 1,2,4,8
|
||||
bool LocalExpertMasking_, // used in EP case
|
||||
bool LocalToken_, // used in EP case
|
||||
bool SkipExpertsWithZeroTokens_ = true>
|
||||
struct MoeSortingProblemMp
|
||||
{
|
||||
@@ -64,6 +67,7 @@ struct MoeSortingProblemMp
|
||||
|
||||
static constexpr index_t SubTokenTile = SubTokenTile_;
|
||||
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
|
||||
static constexpr bool LocalToken = LocalToken_;
|
||||
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
|
||||
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 ||
|
||||
SubTokenTile == 8 || SubTokenTile == 16);
|
||||
|
||||
@@ -303,7 +303,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
|
||||
@@ -312,11 +312,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
if constexpr(LanesPerK >= warpSize)
|
||||
if constexpr(LanesPerK >= WarpSize)
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % warpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / warpSize;
|
||||
static_assert(LanesPerK % WarpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / WarpSize;
|
||||
if constexpr(wavesPerK > NumWarps)
|
||||
{
|
||||
// TODO: need multiple issues along K to load all data
|
||||
@@ -329,11 +329,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<wavesPerM>{}, // m1
|
||||
number<wavesPerK>{}, // k0
|
||||
number<warpSize>{}, // k1
|
||||
number<WarpSize>{}, // k1
|
||||
number<KVector>{}), // k2
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // k0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
|
||||
number<WarpSize * KVector + KPad>{}, // k0
|
||||
number<KVector>{}, // k1
|
||||
number<1>{}), // k2
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
@@ -344,7 +344,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<NumIssues>{}),
|
||||
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
|
||||
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
|
||||
make_merge_transform(make_tuple(number<WarpSize>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
@@ -354,8 +354,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
else
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(warpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
|
||||
static_assert(WarpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
|
||||
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
@@ -364,9 +364,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
number<NumWarps>{}, // m2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<Block_K>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // m2
|
||||
number<WarpSize * KVector + KPad>{}, // m2
|
||||
number<KVector>{}, // k0
|
||||
number<1>{}), // k1
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
@@ -398,7 +398,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
|
||||
@@ -407,11 +407,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
if constexpr(LanesPerK >= warpSize)
|
||||
if constexpr(LanesPerK >= WarpSize)
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % warpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / warpSize;
|
||||
static_assert(LanesPerK % WarpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / WarpSize;
|
||||
if constexpr(wavesPerK >= NumWarps)
|
||||
{
|
||||
// TODO: need multiple issues along K to load all data
|
||||
@@ -424,11 +424,11 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<wavesPerM>{}, // m1
|
||||
number<wavesPerK>{}, // k0
|
||||
number<warpSize>{}, // k1
|
||||
number<WarpSize>{}, // k1
|
||||
number<KVector>{}), // k2
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // k0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
|
||||
number<WarpSize * KVector + KPad>{}, // k0
|
||||
number<KVector>{}, // k1
|
||||
number<1>{}), // k2
|
||||
number<KPack>{}, // lds load vector
|
||||
@@ -439,7 +439,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
|
||||
make_merge_transform(make_tuple(
|
||||
number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
|
||||
number<wavesPerK>{}, number<WarpSize>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
@@ -449,8 +449,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
else
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(warpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
|
||||
static_assert(WarpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
|
||||
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
@@ -459,9 +459,9 @@ struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
number<NumWarps>{}, // m2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
|
||||
number<Block_K>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // m2
|
||||
number<WarpSize * KVector + KPad>{}, // m2
|
||||
number<KVector>{}, // k0
|
||||
number<1>{}), // k1
|
||||
number<KPack>{}, // lds load vector
|
||||
|
||||
@@ -31,6 +31,8 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
|
||||
@@ -60,52 +60,105 @@ struct BlockGemmARegBRegCRegV1
|
||||
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
|
||||
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
|
||||
|
||||
static constexpr index_t MWarp = Traits::MWarp;
|
||||
static constexpr index_t NWarp = Traits::NWarp;
|
||||
static constexpr index_t MWarp = Traits::MWarp;
|
||||
static constexpr index_t NWarp = Traits::NWarp;
|
||||
static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1);
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
return a_block_dstr_encode;
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
return c_block_dstr_encode;
|
||||
return c_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
return c_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
@@ -201,19 +254,38 @@ struct BlockGemmARegBRegCRegV1
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
|
||||
@@ -12,7 +12,8 @@ template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
index_t kBlockSize_,
|
||||
typename BlockGemmShape_>
|
||||
typename BlockGemmShape_,
|
||||
index_t NumWaveGroups_ = 1>
|
||||
struct BlockGemmProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
@@ -20,7 +21,8 @@ struct BlockGemmProblem
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -215,7 +215,7 @@ struct BlockUniversalGemmAsBsCr
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
ALdsTile b_warp_tile_;
|
||||
BLdsTile b_warp_tile_;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
|
||||
struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs</*NumDTensor = 0*/>
|
||||
{
|
||||
CK_TILE_HOST BatchedGemmHostArgs() = default;
|
||||
CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
|
||||
@@ -26,18 +26,28 @@ struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
|
||||
ck_tile::index_t batch_stride_B_,
|
||||
ck_tile::index_t batch_stride_C_,
|
||||
ck_tile::index_t batch_count_)
|
||||
: GemmHostArgs(
|
||||
a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
|
||||
: GemmHostArgs(a_ptr_,
|
||||
b_ptr_,
|
||||
{},
|
||||
c_ptr_,
|
||||
k_batch_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
stride_A_,
|
||||
stride_B_,
|
||||
{},
|
||||
stride_C_),
|
||||
batch_stride_A(batch_stride_A_),
|
||||
batch_stride_B(batch_stride_B_),
|
||||
batch_stride_C(batch_stride_C_),
|
||||
batch_stride_E(batch_stride_C_),
|
||||
batch_count(batch_count_)
|
||||
{
|
||||
}
|
||||
|
||||
ck_tile::index_t batch_stride_A;
|
||||
ck_tile::index_t batch_stride_B;
|
||||
ck_tile::index_t batch_stride_C;
|
||||
ck_tile::index_t batch_stride_E;
|
||||
ck_tile::index_t batch_count;
|
||||
};
|
||||
|
||||
@@ -46,18 +56,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
{
|
||||
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
using GemmKernelArgs = typename ck_tile::GemmKernelArgs;
|
||||
using GemmKernelArgs = typename ck_tile::GemmKernelArgs<>;
|
||||
|
||||
using ADataType = typename Base::ADataType;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using CDataType = typename Base::CDataType;
|
||||
using CDataType = typename Base::EDataType;
|
||||
|
||||
using TilePartitioner = typename Base::TilePartitioner;
|
||||
using GemmPipeline = typename Base::GemmPipeline;
|
||||
using EpiloguePipeline = typename Base::EpiloguePipeline;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using CLayout = typename Base::CLayout;
|
||||
using CLayout = typename Base::ELayout;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -75,7 +85,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
{
|
||||
index_t batch_stride_A;
|
||||
index_t batch_stride_B;
|
||||
index_t batch_stride_C;
|
||||
index_t batch_stride_E;
|
||||
index_t batch_count;
|
||||
};
|
||||
|
||||
@@ -94,17 +104,19 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
{
|
||||
return BatchedGemmKernelArgs{{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.c_ptr,
|
||||
{},
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C,
|
||||
{},
|
||||
hostArgs.stride_E,
|
||||
hostArgs.k_batch},
|
||||
hostArgs.batch_stride_A,
|
||||
hostArgs.batch_stride_B,
|
||||
hostArgs.batch_stride_C,
|
||||
hostArgs.batch_stride_E,
|
||||
hostArgs.batch_count};
|
||||
}
|
||||
|
||||
@@ -135,14 +147,14 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B +
|
||||
splitk_batch_offset.b_k_split_offset;
|
||||
|
||||
const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
|
||||
const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;
|
||||
const auto batch_stride_E = __builtin_amdgcn_readfirstlane(kargs.batch_stride_E);
|
||||
const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_E);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr) + batch_offset_C;
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
this->RunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -9,74 +9,88 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The GEMM problem definition.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure defines the GEMM problem configuration by stating all required information
|
||||
/// like M,N,K sizes and respective strides.
|
||||
/// @brief Plain aggregate of the GEMM sizes (M, N, K) and the A/B/C strides.
struct GemmProblem
|
||||
{
|
||||
/// @brief Defaulted constructor; it does not assign any of the members.
CK_TILE_HOST GemmProblem() = default;
|
||||
/// @brief Stores each argument in the member of the same name (without the trailing underscore).
CK_TILE_HOST GemmProblem(
|
||||
index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
|
||||
: M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
|
||||
{
|
||||
}
|
||||
|
||||
/// @brief GEMM's M dimension size.
index_t M;
|
||||
/// @brief GEMM's N dimension size.
index_t N;
|
||||
/// @brief GEMM's K dimension size.
index_t K;
|
||||
/// @brief Stride of the A tensor.
index_t stride_A;
|
||||
/// @brief Stride of the B tensor.
index_t stride_B;
|
||||
/// @brief Stride of the C tensor.
index_t stride_C;
|
||||
};
|
||||
|
||||
/// @brief The GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref GemmKernel "GemmKernel" when creating kernel arguments
|
||||
/// object. It contains all the information required to build the proper kernel arguments
|
||||
/// and launch kernel on GPU.
|
||||
struct GemmHostArgs : public GemmProblem
|
||||
/// This structure defines the GEMM problem configuration by stating all required information
|
||||
/// like M,N,K sizes and respective strides.
|
||||
/// NumDTensor describes the number of D tensors.
|
||||
template <index_t NumDTensor = 0>
|
||||
struct GemmHostArgs
|
||||
{
|
||||
CK_TILE_HOST GemmHostArgs() = default;
|
||||
CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* c_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_)
|
||||
: GemmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
|
||||
a_ptr(a_ptr_),
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
c_ptr(c_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
e_ptr(e_ptr_),
|
||||
M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
stride_A(stride_A_),
|
||||
stride_B(stride_B_),
|
||||
stride_Ds(stride_Ds_),
|
||||
stride_E(stride_E_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
void* c_ptr;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
union
|
||||
{
|
||||
void* e_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
union
|
||||
{
|
||||
index_t stride_E;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
/// @brief The GEMM kernel device arguments.
|
||||
template <index_t NumDTensor = 0>
|
||||
struct GemmKernelArgs
|
||||
{
|
||||
/// @brief The A input tensor's pointer to device memory.
|
||||
const void* a_ptr;
|
||||
/// @brief The B input tensor's pointer to device memory.
|
||||
const void* b_ptr;
|
||||
/// @brief The C output tensor's pointer to device memory.
|
||||
void* c_ptr;
|
||||
/// @brief The Ds input tensor's pointer to device memory.
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
/// @brief The E output tensor's pointer to device memory.
|
||||
void* e_ptr;
|
||||
/// @brief GEMM's M dimension size.
|
||||
index_t M;
|
||||
/// @brief GEMM's N dimension size.
|
||||
@@ -90,8 +104,11 @@ struct GemmKernelArgs
|
||||
/// (in memory) of B tensor.
|
||||
index_t stride_B;
|
||||
/// @brief The distance between consecutive elements of non-contiguous dimension
|
||||
/// (in memory) of C tensor.
|
||||
index_t stride_C;
|
||||
/// (in memory) of Ds tensor.
|
||||
std::array<index_t, NumDTensor> stride_Ds;
|
||||
/// @brief The distance between consecutive elements of non-contiguous dimension
|
||||
/// (in memory) of E tensor.
|
||||
index_t stride_E;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
@@ -130,26 +147,51 @@ struct GemmKernelArgs
|
||||
/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix
|
||||
/// multiplication implementation. It is responsible for storing
|
||||
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
|
||||
/// the output C tensor in global memory.
|
||||
/// the output E tensor in global memory.
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct GemmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
// TODO: GemmPipeline::CLayout -> GemmPipeline::ELayout will be changed for multi-ABD
|
||||
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
// Get the persistent kernel if the pipeline has it available.
// `value` is GemmPipeline::UsePersistentKernel when is_detected reports that the
// pipeline exposes such a member, and false otherwise.
|
||||
struct has_persistent_kernel
|
||||
{
|
||||
// Only well-formed when T has a member named UsePersistentKernel.
template <typename T>
|
||||
using has_persistent_type = decltype(T::UsePersistentKernel);
|
||||
|
||||
// Immediately-invoked lambda so the choice can be made with `if constexpr`;
// the member is only named in the branch where it was detected.
static constexpr bool value = []() {
|
||||
if constexpr(is_detected<has_persistent_type, GemmPipeline>{})
|
||||
return GemmPipeline::UsePersistentKernel;
|
||||
else
|
||||
return false;
|
||||
}();
|
||||
};
|
||||
static constexpr bool PersistentKernel = has_persistent_kernel::value;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>{};
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
using KernelArgs = GemmKernelArgs<DsLayout::size()>;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -163,20 +205,41 @@ struct GemmKernel
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
* @param s Stream configuration; it is forwarded to `get_available_compute_units`.
|
||||
* @return The maximum occupancy grid size.
*         The result is one-dimensional: dim3(compute units * blocks per multiprocessor, 1, 1).
|
||||
* @note This function queries the maximum occupancy of the kernel using
|
||||
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
*       A failing query is handed to `hip_check_error`.
|
||||
*/
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
// Entry point that the occupancy query inspects, instantiated for this GemmKernel.
using Kernel = GemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
// NOTE(review): the literal 1 is kentry's second template argument; its meaning is not
// visible here (presumably a minimum-blocks-per-CU hint) - confirm against kentry.
const auto kernel = kentry<KernelBlockSize, 1, Kernel, KernelArgs>;
|
||||
// Written by the occupancy query below.
int occupancy;
|
||||
hip_check_error(
|
||||
// Block size is KernelBlockSize; the trailing 0 is the dynamic shared memory size per block.
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
|
||||
// Scale the per-multiprocessor block count by the number of available compute units.
const int grid_size = get_available_compute_units(s) * occupancy;
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
|
||||
CK_TILE_HOST static constexpr KernelArgs
|
||||
MakeKernelArgs(const GemmHostArgs<NumDTensor>& hostArgs)
|
||||
{
|
||||
return GemmKernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.c_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.k_batch};
|
||||
|
||||
return KernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.ds_ptr,
|
||||
hostArgs.e_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_Ds,
|
||||
hostArgs.stride_E,
|
||||
hostArgs.k_batch};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -186,8 +249,7 @@ struct GemmKernel
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
__device__ SplitKBatchOffset(const GemmKernelArgs& kargs,
|
||||
const std::size_t k_id = blockIdx.z)
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
|
||||
@@ -226,10 +288,10 @@ struct GemmKernel
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
|
||||
{
|
||||
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value)
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
@@ -325,7 +387,56 @@ struct GemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
bool DTesnorIsValid = {true};
|
||||
static_for<0, NumDTensor, 1>{}([&](auto index) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
||||
if(std::is_same_v<DiLayout, ELayout> == false)
|
||||
{
|
||||
DTesnorIsValid = false;
|
||||
}
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
|
||||
"NPerBlock without padding!");
|
||||
}
|
||||
DTesnorIsValid = false;
|
||||
}
|
||||
if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
|
||||
}
|
||||
DTesnorIsValid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
|
||||
"MPerBlock without padding!");
|
||||
}
|
||||
DTesnorIsValid = false;
|
||||
}
|
||||
if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
|
||||
}
|
||||
DTesnorIsValid = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
|
||||
{
|
||||
@@ -365,15 +476,17 @@ struct GemmKernel
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return DTesnorIsValid;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
const GemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
@@ -460,29 +573,54 @@ struct GemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr,
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
make_tuple(1, kargs.stride_E),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
|
||||
return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, e_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
@@ -524,35 +662,57 @@ struct GemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& d_tensor_view = views.at(I2);
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I2);
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_pad_view, c_pad_view);
|
||||
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, e_pad_view);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_pad_view = views.at(I1);
|
||||
const auto& c_pad_view = views.at(I2);
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -588,12 +748,32 @@ struct GemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
const auto ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_n, i_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto e_block_window = make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(a_block_window, b_block_window, c_block_window);
|
||||
return make_tuple(a_block_window, b_block_window, ds_block_window, e_block_window);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -601,7 +781,8 @@ struct GemmKernel
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param ds_ptr input Ds pointer
|
||||
* @param e_ptr output E pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset Utility structure used to calculate k batch.
|
||||
@@ -609,11 +790,13 @@ struct GemmKernel
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
template <bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr_0,
|
||||
const GemmKernelArgs& kargs,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
@@ -621,7 +804,7 @@ struct GemmKernel
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
@@ -632,15 +815,20 @@ struct GemmKernel
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I2);
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, smem_ptr_0);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -650,7 +838,8 @@ struct GemmKernel
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param ds_ptr input Ds pointer
|
||||
* @param e_ptr output E pointer
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
@@ -661,10 +850,11 @@ struct GemmKernel
|
||||
*/
|
||||
CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const GemmKernelArgs& kargs,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
@@ -672,7 +862,8 @@ struct GemmKernel
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
@@ -682,18 +873,22 @@ struct GemmKernel
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I2);
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, smem_ptr_0);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
|
||||
// Non-persistent kernel entry point
|
||||
template <bool U = !PersistentKernel, typename = std::enable_if_t<U>>
|
||||
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
|
||||
{
|
||||
const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
|
||||
@@ -701,12 +896,14 @@ struct GemmKernel
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
|
||||
// options
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
@@ -716,11 +913,12 @@ struct GemmKernel
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
@@ -733,9 +931,95 @@ struct GemmKernel
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
|
||||
RunGemm<scheduler_type>(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Persistent kernel entry point
|
||||
template <bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
|
||||
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
|
||||
{
|
||||
const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size());
|
||||
const auto num_tiles =
|
||||
__builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N));
|
||||
const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch);
|
||||
auto block_id = __builtin_amdgcn_readfirstlane(get_block_id());
|
||||
|
||||
while(block_id < num_work)
|
||||
{
|
||||
// Get the tile index for this block
|
||||
const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles);
|
||||
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
// Get the SplitK offset for this block
|
||||
const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles);
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs, k_batch);
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
// Run the GEMM
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
// Advance to the next work item
|
||||
block_id += grid_size;
|
||||
if(block_id >= num_work)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,6 +195,22 @@ struct OffsettedTile1DPartitioner
|
||||
const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
|
||||
return make_tuple(iM, iN);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief The function subtracts the block's start (offset) from a given block index.
|
||||
* @param [in] block_start Workgroup offset.
|
||||
* @param [in] M Gemm's M dimension.
|
||||
* @param [in] N Gemm's N dimension.
|
||||
* @param [in] block_idx Current block index of the workgroup.
|
||||
* @return Returns a `tuple` [Im, In] with shifted index.
|
||||
*/
|
||||
[[nodiscard]] CK_TILE_DEVICE static auto
|
||||
GetOffsetedTileIndex(index_t block_start, index_t M, index_t N, index_t block_idx) noexcept
|
||||
-> const tuple<index_t, index_t>
|
||||
{
|
||||
const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(block_idx - block_start);
|
||||
return make_tuple(iM, iN);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -230,7 +246,7 @@ struct GemmSpatiallyLocalTilePartitioner
|
||||
* @param N GEMM's N dimension.
|
||||
* @return index_t A total number of workgroups.
|
||||
*/
|
||||
CK_TILE_HOST static auto
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
|
||||
{
|
||||
const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
|
||||
|
||||
@@ -5,23 +5,30 @@
|
||||
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/literals.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct GemmTransKernelArg
|
||||
{
|
||||
GemmKernelArgs group_karg;
|
||||
GemmKernelArgs<> group_karg;
|
||||
ck_tile::index_t block_start;
|
||||
ck_tile::index_t block_end;
|
||||
|
||||
GemmTransKernelArg() = default;
|
||||
GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
|
||||
GemmTransKernelArg() = delete;
|
||||
GemmTransKernelArg(GemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end)
|
||||
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
|
||||
{
|
||||
}
|
||||
|
||||
GemmTransKernelArg(GemmKernelArgs<>&& karg) : group_karg{karg}, block_start{0}, block_end{0} {}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
@@ -32,7 +39,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using ELayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
@@ -40,8 +47,10 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
|
||||
using Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
using Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -51,19 +60,43 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
|
||||
(UsePersistentKernel ? "Persistent" : "NonPersistent"));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
__host__ static auto GetWorkSpaceSize(const std::vector<GemmHostArgs>& gemm_descs)
|
||||
-> std::size_t
|
||||
CK_TILE_HOST static auto
|
||||
GetWorkSpaceSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs) -> std::size_t
|
||||
{
|
||||
return gemm_descs.size() * sizeof(GemmTransKernelArg);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
|
||||
/// @brief Size in bytes of a workspace holding one GemmTransKernelArg per group.
/// @param group_count Number of GEMM groups.
CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
{
    constexpr std::size_t bytes_per_group = sizeof(GemmTransKernelArg);
    return group_count * bytes_per_group;
}
|
||||
|
||||
__host__ static constexpr auto GridSize(const std::vector<GemmHostArgs>& gemm_descs)
|
||||
CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }
|
||||
|
||||
/**
|
||||
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
|
||||
* @return The maximum occupancy grid size.
|
||||
* @note This function queries the maximum occupancy of the kernel using
|
||||
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
*/
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
|
||||
const auto kernel = kentry<KernelBlockSize, 1, Kernel, ConstantPointer, index_t>;
|
||||
int occupancy;
|
||||
HIP_CHECK_ERROR(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
|
||||
const int grid_size = get_available_compute_units(s) * occupancy;
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
|
||||
{
|
||||
index_t grid_size = 0;
|
||||
for(const auto& it_desc : gemm_descs)
|
||||
@@ -74,7 +107,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto MakeKargs(const std::vector<GemmHostArgs>& gemm_descs)
|
||||
CK_TILE_HOST static auto
|
||||
MakeKargs(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
|
||||
-> std::vector<GemmTransKernelArg>
|
||||
{
|
||||
std::vector<GemmTransKernelArg> gemm_kernel_args_;
|
||||
@@ -95,7 +129,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
const index_t stride_a = gemm_descs[i].stride_A;
|
||||
const index_t stride_b = gemm_descs[i].stride_B;
|
||||
const index_t stride_c = gemm_descs[i].stride_C;
|
||||
const index_t stride_e = gemm_descs[i].stride_E;
|
||||
|
||||
const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
|
||||
|
||||
@@ -104,16 +138,18 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
grid_size += grid_size_grp;
|
||||
|
||||
auto karg = GemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
|
||||
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
|
||||
type_convert<CDataType*>(gemm_descs[i].c_ptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
gemm_descs[i].k_batch};
|
||||
auto karg = GemmKernelArgs<>{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
|
||||
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
|
||||
{},
|
||||
type_convert<CDataType*>(gemm_descs[i].e_ptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
{},
|
||||
stride_e,
|
||||
gemm_descs[i].k_batch};
|
||||
|
||||
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
|
||||
}
|
||||
@@ -121,39 +157,120 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
return gemm_kernel_args_;
|
||||
}
|
||||
|
||||
/// @brief Checks every group's kernel arguments with Base::IsSupportedArgument.
/// @param kargs Per-group kernel arguments.
/// @return false as soon as one group is rejected, true otherwise (also for an empty list).
CK_TILE_HOST static bool IsSupportedArgument(const std::vector<GemmTransKernelArg>& kargs)
{
    for(const auto& group_arg : kargs)
    {
        const bool group_supported = Base::IsSupportedArgument(group_arg.group_karg);
        if(!group_supported)
        {
            return false;
        }
    }
    return true;
}
|
||||
|
||||
/// @brief Returns the larger of the GEMM pipeline's and the epilogue pipeline's GetSmemSize().
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
|
||||
{
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
/**
 * @brief Runs one tile of a grouped GEMM problem described by a GemmTransKernelArg.
 *
 * Forwards the group's kernel arguments to the Run overload taking GemmKernelArgs<>.
 *
 * @param kargs        Group descriptor; only its group_karg member is used here.
 * @param block_idx_2d Tile indices passed through unchanged.
 * @param block_idx_z  Z block index passed through unchanged.
 */
CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs,
                        const tuple<index_t, index_t>& block_idx_2d,
                        const index_t block_idx_z) const
{
    Run(kargs.group_karg, block_idx_2d, block_idx_z);
}
|
||||
|
||||
/**
 * @brief Runs one output tile of a single GEMM problem.
 *
 * @param kargs        Kernel arguments of the GEMM problem.
 * @param block_idx_2d (iM, iN) tile indices; scaled by MPerBlock / NPerBlock below.
 * @param block_idx_z  Index passed to SplitKBatchOffset to select the split-K batch.
 */
CK_TILE_DEVICE void Run(const GemmKernelArgs<>& kargs,
                        const tuple<index_t, index_t>& block_idx_2d,
                        const index_t block_idx_z) const
{
    const auto [iM, iN] = block_idx_2d;

    const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
    const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

    const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);

    // A and B are advanced by the split-K offsets; the output pointer is kargs.e_ptr.
    const ADataType* a_ptr =
        static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
    const BDataType* b_ptr =
        static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
    CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);

    // allocate LDS
    __shared__ char smem_ptr[GetSmemSize()];

    if constexpr(UsePersistentKernel)
    {
        RunGemmWithPipelineSelection(
            a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
    }
    else
    {
        // `{}` is passed for the Ds pointer argument.
        this->RunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
    }
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
index_t group_count) const
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note The GEMM pipeline is selected in-kernel based on the number of K-loops
|
||||
* and the tail-number. This is needed for the persistent tile-loop when
|
||||
* we didn't have access to the K dimension on the host.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
CK_TILE_DEVICE static void
RunGemmWithPipelineSelection(const ADataType* a_ptr,
                             const BDataType* b_ptr,
                             CDataType* c_ptr,
                             void* smem_ptr_0,
                             const GemmKernelArgs<>& kargs,
                             const typename Base::SplitKBatchOffset& splitk_batch_offset,
                             const index_t block_idx_m,
                             const index_t block_idx_n)
{
    // NOTE: this function is static and receives everything it needs through its parameters;
    // it must not refer to the kernel-level descriptor pointer or the block id.

    // Create Gemm tensor views, pad views and tile windows
    const auto& gemm_tensor_views_tuple =
        Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
            a_ptr, b_ptr, {}, c_ptr, kargs, splitk_batch_offset);

    const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
    auto gemm_tile_windows =
        Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
    const auto& a_block_window = gemm_tile_windows.at(Base::I0);
    const auto& b_block_window = gemm_tile_windows.at(Base::I1);
    const auto& d_block_window = gemm_tile_windows.at(Base::I2);

    // Get hot-loop and tail configuration. These are derived on the device from the
    // split-K extent and handed to the pipeline as runtime values.
    const index_t num_loop = __builtin_amdgcn_readfirstlane(
        TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
    const bool has_hot_loop   = GemmPipeline::BlockHasHotloop(num_loop);
    const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);

    // Run GEMM pipeline
    const auto& c_block_tile = GemmPipeline{}.template operator()(
        a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);

    // Run Epilogue Pipeline
    auto& c_block_window = gemm_tile_windows.at(Base::I3);
    EpiloguePipeline{}.template
    operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
        c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
|
||||
|
||||
CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr,
|
||||
index_t block_id,
|
||||
index_t group_count) const
|
||||
{
|
||||
index_t left = 0;
|
||||
index_t right = group_count;
|
||||
index_t group_id = index_t((left + right) >> 1);
|
||||
@@ -173,7 +290,61 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
group_id = index_t((left + right) >> 1);
|
||||
}
|
||||
|
||||
Run(gemm_desc_ptr[group_id]);
|
||||
return group_id;
|
||||
}
|
||||
|
||||
// For non-persistent kernels
|
||||
template <bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
|
||||
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
index_t group_count) const
|
||||
{
|
||||
const index_t block_id = ck_tile::get_block_1d_id();
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
|
||||
cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
|
||||
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
|
||||
const auto& kargs = gemm_desc_ptr[group_id];
|
||||
const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
|
||||
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
|
||||
0,
|
||||
kargs.group_karg.M,
|
||||
kargs.group_karg.N,
|
||||
(block_id - kargs.block_start) % grid_size_2d);
|
||||
Run(kargs, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d);
|
||||
}
|
||||
|
||||
// For persistent kernels
|
||||
template <bool U = UsePersistentKernel,
|
||||
typename = std::enable_if_t<U>,
|
||||
typename = void> // extra template parameter to avoid redefinition
|
||||
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
const index_t group_count) const
|
||||
{
|
||||
const index_t grid_size = ck_tile::get_grid_size();
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
|
||||
cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
index_t block_id = ck_tile::get_block_1d_id(); // initial block_id
|
||||
index_t cum_grid_size = 0;
|
||||
for(index_t group_id = 0; group_id < group_count; ++group_id)
|
||||
{
|
||||
const auto& kargs = gemm_desc_ptr[group_id].group_karg;
|
||||
const auto& k_batch = kargs.k_batch;
|
||||
const auto block_start = cum_grid_size;
|
||||
cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
|
||||
while(block_id < cum_grid_size)
|
||||
{
|
||||
const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
|
||||
0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
|
||||
Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
|
||||
block_id = block_id + grid_size; // advance to next block
|
||||
// NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
|
||||
if(block_id >= cum_grid_size)
|
||||
{
|
||||
break; // exit the loop if all blocks are processed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -80,7 +80,8 @@ struct GemmPipelineAgBgCrImplBase
|
||||
template <typename ADramBlockWindowTmp, typename ALdsTensorView, typename ALdsLoadTileDistr>
|
||||
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const ALdsTensorView& a_lds_block_view,
|
||||
const ALdsLoadTileDistr&) const
|
||||
const ALdsLoadTileDistr&,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
@@ -91,7 +92,7 @@ struct GemmPipelineAgBgCrImplBase
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
a_dram_block_window_tmp.get_window_origin() + offset,
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A LDS tile window for store
|
||||
@@ -112,7 +113,8 @@ struct GemmPipelineAgBgCrImplBase
|
||||
template <typename BDramBlockWindowTmp, typename BLdsTensorView, typename BLdsLoadTileDistr>
|
||||
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BLdsTensorView& b_lds_block_view,
|
||||
const BLdsLoadTileDistr&) const
|
||||
const BLdsLoadTileDistr&,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
@@ -122,7 +124,7 @@ struct GemmPipelineAgBgCrImplBase
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
b_dram_block_window_tmp.get_window_origin() + offset,
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// TODO: Do we really need those two tile windows???
|
||||
|
||||
@@ -20,18 +20,19 @@ namespace ck_tile {
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAgBgCrCompV3
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(BlockHasHotloop(num_loop))
|
||||
{
|
||||
@@ -49,6 +50,50 @@ struct BaseGemmPipelineAgBgCrCompV3
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_number == TailNumber::Full)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Even)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
}
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
// This path should be unreachable in device code if tail_number is valid.
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
// If execution reaches here, it's an invalid combination of arguments.
|
||||
if(has_hot_loop)
|
||||
{
|
||||
throw std::logic_error("Invalid TailNumber: If has_hot_loop is true, tail_number must "
|
||||
"be TailNumber::Full.");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::logic_error("Invalid TailNumber: If has_hot_loop is false, tail_number must "
|
||||
"be TailNumber::Odd or TailNumber::Even.");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Compute optimized pipeline
|
||||
@@ -98,12 +143,14 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
using Base::PrefetchStages;
|
||||
using Base::UsePersistentKernel;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -556,6 +603,42 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
p_smem);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This function runs the pipeline by wrapping it with the tail handler.
|
||||
*
|
||||
* @note This is used by the persistent gemm kernel variants that don't determine
|
||||
* hot loop and tail number on the host side, e.g. grouped gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This function runs the pipeline using compile-time known hot loop and tail number.
|
||||
* @param num_loop The number of loop iterations. This is determined at runtime due to e.g.
|
||||
* SplitK.
|
||||
* @note This is used by the kernel variants that are able to determine
|
||||
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
|
||||
@@ -34,6 +34,46 @@ struct BaseGemmPipelineAgBgCrCompV4
|
||||
return TailNumber::Two;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_number == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
throw std::logic_error("Invalid TailNumber: Only TailNumber::Full and smaller than "
|
||||
"PrefetchStages are supported.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -94,6 +134,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
@@ -572,5 +613,30 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -0,0 +1,379 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed Tensor: register
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAgBgCrCompV5
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
|
||||
{
|
||||
return TailNumber::Empty;
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV5DefaultPolicy>
|
||||
struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV5<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr index_t NumWarps = BlockGemmShape::NumWarps;
|
||||
static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{});
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AgBgCrCompV5", BlockSize,
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
|
||||
{
|
||||
return Policy::template IsTransposeC<Problem>();
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"Data Type conflict on A and B matrix input data type.");
|
||||
|
||||
static_assert(
|
||||
KPerBlock % ((NumWarps / 2) * KTileSize) == 0,
|
||||
"Ping Pong Warps, TileSize and Block Size for K dimensions does not match.");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
index_t warp_id = get_warp_id();
|
||||
index_t operation_id =
|
||||
__builtin_amdgcn_readfirstlane(get_warp_id()); // 0 - Memory read, 1 - block-gemm
|
||||
|
||||
auto a_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock);
|
||||
auto b_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock);
|
||||
|
||||
auto tensor_views =
|
||||
Base::GetABLdsTensorViews(static_cast<void*>(static_cast<char*>(p_smem_0)));
|
||||
auto& a_lds_block = tensor_views.get(number<0>{});
|
||||
auto& b_lds_block = tensor_views.get(number<1>{});
|
||||
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
auto a_windows = Base::GetAWindows(
|
||||
a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr, a_offset);
|
||||
auto& a_copy_dram_window = a_windows.get(number<0>{});
|
||||
auto& a_copy_lds_window = a_windows.get(number<1>{});
|
||||
auto& a_lds_window = a_windows.get(number<2>{});
|
||||
|
||||
auto b_windows = Base::GetBWindows(
|
||||
b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr, b_offset);
|
||||
auto& b_copy_dram_window = b_windows.get(number<0>{});
|
||||
auto& b_copy_lds_window = b_windows.get(number<1>{});
|
||||
auto& b_lds_window = b_windows.get(number<2>{});
|
||||
|
||||
// DRAM window steps.
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock * NumWarps, 0)
|
||||
: make_array(0, KPerBlock * NumWarps);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock * NumWarps, 0)
|
||||
: make_array(0, KPerBlock * NumWarps);
|
||||
|
||||
constexpr auto AGemmTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeABlockDistributionEncode())){};
|
||||
constexpr auto BGemmTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeBBlockDistributionEncode())){};
|
||||
|
||||
using AGemmTile = decltype(make_static_distributed_tensor<ADataType>(AGemmTileDistr));
|
||||
using BGemmTile = decltype(make_static_distributed_tensor<BDataType>(BGemmTileDistr));
|
||||
AGemmTile a_tile_0, a_tile_1;
|
||||
BGemmTile b_tile_0, b_tile_1;
|
||||
|
||||
// Register tile for A and B.
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
ABlockTile a_global_load_tile;
|
||||
BBlockTile b_global_load_tile;
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile_0 = block_gemm.MakeCBlockTile();
|
||||
auto c_block_tile_1 = block_gemm.MakeCBlockTile();
|
||||
|
||||
CDataType* __restrict__ p_c_lds = static_cast<CDataType*>(p_smem_0);
|
||||
auto c_lds_block_0 =
|
||||
make_naive_tensor_view<address_space_enum::lds>(p_c_lds,
|
||||
make_tuple(MPerBlock, NPerBlock),
|
||||
make_tuple(NPerBlock, 1),
|
||||
number<BlockGemm::Traits::KPack>{},
|
||||
number<1>{});
|
||||
auto c_window_0 = make_tile_window(c_lds_block_0,
|
||||
make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
|
||||
{0, 0},
|
||||
c_block_tile_1.get_tile_distribution());
|
||||
|
||||
// initialize C
|
||||
if(warp_id == 0)
|
||||
{
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_1);
|
||||
}
|
||||
|
||||
// define ping, pong steps here as lambda functions.
|
||||
auto MemoryOpsStep = [&](auto idx) {
|
||||
// Memory read half here.
|
||||
Base::GlobalPrefetch(
|
||||
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func);
|
||||
}
|
||||
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func);
|
||||
}
|
||||
|
||||
if(idx == 0)
|
||||
{
|
||||
Base::LocalPrefetch(a_tile_0, a_lds_window);
|
||||
Base::LocalPrefetch(b_tile_0, b_lds_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefetch(a_tile_1, a_lds_window);
|
||||
Base::LocalPrefetch(b_tile_1, b_lds_window);
|
||||
}
|
||||
};
|
||||
|
||||
auto ComputeStep = [&](auto idx) {
|
||||
if(idx == 0)
|
||||
{
|
||||
block_gemm(c_block_tile_0, a_tile_0, b_tile_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_gemm(c_block_tile_1, a_tile_1, b_tile_1);
|
||||
}
|
||||
};
|
||||
|
||||
if(operation_id == 0)
|
||||
{
|
||||
MemoryOpsStep(warp_id);
|
||||
}
|
||||
|
||||
index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop);
|
||||
while(num_compute_steps > 1)
|
||||
{
|
||||
block_sync_lds();
|
||||
operation_id = (operation_id + 1) % NumWaveGroups;
|
||||
|
||||
if(operation_id == 0)
|
||||
{
|
||||
MemoryOpsStep(warp_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
ComputeStep(warp_id);
|
||||
}
|
||||
num_compute_steps -= 1;
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
if(operation_id == 0)
|
||||
{
|
||||
ComputeStep(warp_id);
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
if(warp_id == 1)
|
||||
{
|
||||
store_tile(c_window_0, c_block_tile_1);
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
if(warp_id == 0)
|
||||
{
|
||||
load_tile(c_block_tile_1, c_window_0);
|
||||
|
||||
constexpr auto s_spans = decltype(c_block_tile_0)::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
auto idx2 = make_tuple(idx0, idx1);
|
||||
c_block_tile_0(idx2) += c_block_tile_1(idx2);
|
||||
});
|
||||
});
|
||||
}
|
||||
return c_block_tile_0;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// Default policy for GemmPipelineAGmemBGmemCregComputeV5, except the block gemm method, it shares
|
||||
// the same vector size implementation, SmemSize, Global memory tile distiribution as the
|
||||
// UniversalGemm Pipeline Policy.
|
||||
// Default policy class should not be templated, put template on
|
||||
// member functions instead.
|
||||
struct GemmPipelineAgBgCrCompV5DefaultPolicy
|
||||
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV5DefaultPolicy>
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
AccDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeC()
|
||||
{
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
|
||||
return integer_least_multiple(sizeof(typename Problem::CDataType) * MPerBlock * NPerBlock,
|
||||
16);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
|
||||
constexpr index_t smem_size_c = GetSmemSizeC<Problem>();
|
||||
|
||||
return smem_size_a + smem_size_b >= smem_size_c ? (smem_size_a + smem_size_b)
|
||||
: (smem_size_c);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -52,13 +52,14 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
|
||||
static constexpr index_t LocalPrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = PrefetchStages;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop % PrefetchStages == 1)
|
||||
{
|
||||
@@ -93,6 +94,56 @@ struct BaseGemmPipelineAgBgCrMem
|
||||
return TailNumber::Full;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Wrap the hot_loop dispatch first.
|
||||
auto tail_dispatch = [&](auto tail_num_constant) {
|
||||
if(has_hot_loop)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, tail_num_constant);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_func(bool_constant<false>{}, tail_num_constant);
|
||||
}
|
||||
};
|
||||
|
||||
#define CHECK_TAIL_NUMBER(TAIL_NUMBER, PREFETCH_VALUE) \
|
||||
else if(tail_number == TailNumber::TAIL_NUMBER) \
|
||||
{ \
|
||||
if constexpr(PrefetchStages > PREFETCH_VALUE) \
|
||||
{ \
|
||||
return tail_dispatch(integral_constant<TailNumber, TailNumber::TAIL_NUMBER>{}); \
|
||||
} \
|
||||
}
|
||||
// Handle all the valid cases.
|
||||
if(tail_number == TailNumber::One)
|
||||
{
|
||||
return tail_dispatch(integral_constant<TailNumber, TailNumber::One>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Full)
|
||||
{
|
||||
return tail_dispatch(integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
CHECK_TAIL_NUMBER(Two, 2)
|
||||
CHECK_TAIL_NUMBER(Three, 3)
|
||||
CHECK_TAIL_NUMBER(Four, 4)
|
||||
CHECK_TAIL_NUMBER(Five, 5)
|
||||
CHECK_TAIL_NUMBER(Six, 6)
|
||||
CHECK_TAIL_NUMBER(Seven, 7)
|
||||
#undef CHECK_TAIL_NUMBER
|
||||
|
||||
// We shouldn't get here unless we have a tail number larger than the prefetch stages.
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
throw std::logic_error("Invalid TailNumber: Only TailNumber::Full and smaller than "
|
||||
"PrefetchStages are supported.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Maximum Global Memory throughput pipeline with >=32KB data in flight
|
||||
@@ -137,6 +188,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// Where is the right place for HasHotLoop and TailNum ???
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
@@ -749,6 +801,29 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
|
||||
@@ -47,6 +47,8 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = 16;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
|
||||
4
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
Executable file → Normal file
4
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
Executable file → Normal file
@@ -121,7 +121,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M1 = Problem::VectorSizeA;
|
||||
constexpr index_t M0 = MPerBlock / M1;
|
||||
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
@@ -211,7 +211,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
|
||||
constexpr index_t N1 = Problem::VectorSizeB;
|
||||
constexpr index_t N0 = NPerBlock / N1;
|
||||
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % N1 == 0);
|
||||
|
||||
@@ -14,7 +14,10 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
struct GemmPipelineProblemBase
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
@@ -24,6 +27,8 @@ struct GemmPipelineProblemBase
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Traits::ALayout>;
|
||||
@@ -32,6 +37,8 @@ struct GemmPipelineProblemBase
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
|
||||
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
@@ -40,8 +47,7 @@ struct GemmPipelineProblemBase
|
||||
static constexpr bool kPadN = Traits::kPadN;
|
||||
static constexpr bool kPadK = Traits::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||
|
||||
@@ -114,7 +120,11 @@ struct GemmPipelineProblemBase
|
||||
}
|
||||
|
||||
static constexpr index_t VectorSizeA = []() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(FixedVectorSize)
|
||||
{
|
||||
return VectorSizeA_;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return kPadK ? 1 : GetAlignmentA();
|
||||
}
|
||||
@@ -125,7 +135,11 @@ struct GemmPipelineProblemBase
|
||||
}();
|
||||
|
||||
static constexpr index_t VectorSizeB = []() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
if constexpr(FixedVectorSize)
|
||||
{
|
||||
return VectorSizeB_;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return kPadN ? 1 : GetAlignmentB();
|
||||
}
|
||||
@@ -152,13 +166,19 @@ template <typename ADataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_>;
|
||||
ComputeDataType_,
|
||||
FixedVectorSize_,
|
||||
VectorSizeA_,
|
||||
VectorSizeB_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
@@ -168,7 +188,10 @@ template <typename ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
typename ComputeDataType_ = ADataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
struct UniversalGemmPipelineProblem
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
@@ -178,6 +201,10 @@ struct UniversalGemmPipelineProblem
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeA = VectorSizeA_;
|
||||
static constexpr index_t VectorSizeB = VectorSizeB_;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Traits::ALayout>;
|
||||
@@ -198,6 +225,8 @@ struct UniversalGemmPipelineProblem
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -464,10 +464,12 @@ struct UniversalGemmBasePolicy
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize =
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// Tile: MPerBlock X KPerBlock
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
@@ -476,7 +478,8 @@ struct UniversalGemmBasePolicy
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
getATileAccessPattern()>;
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
// Tile: KPerBlock X MPerBlock
|
||||
@@ -486,20 +489,21 @@ struct UniversalGemmBasePolicy
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
getATileAccessPattern()>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize =
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// Tile: KPerBlock X NPerBlock
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
@@ -508,7 +512,8 @@ struct UniversalGemmBasePolicy
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
getBTileAccessPattern()>;
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
// Tile: NPerBlock X KPerBlock
|
||||
@@ -518,26 +523,26 @@ struct UniversalGemmBasePolicy
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
getBTileAccessPattern()>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
getATileAccessPattern()>;
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
|
||||
}
|
||||
|
||||
@@ -546,16 +551,18 @@ struct UniversalGemmBasePolicy
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
getBTileAccessPattern()>;
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,8 @@ template <bool kPadM_,
|
||||
bool kPadK_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_>
|
||||
typename CLayout_,
|
||||
index_t NumWaveGroups_ = 1>
|
||||
struct TileGemmTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
@@ -28,6 +29,7 @@ struct TileGemmTraits
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
@@ -38,7 +40,9 @@ template <bool kPadM_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool UseStructuredSparsity_ = false>
|
||||
bool UseStructuredSparsity_ = false,
|
||||
bool UsePersistentKernel_ = false,
|
||||
index_t NumWaveGroups_ = 1>
|
||||
struct TileGemmUniversalTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
@@ -53,6 +57,28 @@ struct TileGemmUniversalTraits
|
||||
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
|
||||
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
|
||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool DoubleSmemBuffer_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool UseStructuredSparsity_ = false>
|
||||
using PersistentTileGemmUniversalTraits = TileGemmUniversalTraits<kPadM_,
|
||||
kPadN_,
|
||||
kPadK_,
|
||||
DoubleSmemBuffer_,
|
||||
ALayout_,
|
||||
BLayout_,
|
||||
CLayout_,
|
||||
TransposeC_,
|
||||
UseStructuredSparsity_,
|
||||
true>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -172,7 +172,7 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
||||
#if defined(__gfx950__)
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>>>;
|
||||
#else
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
@@ -204,14 +204,6 @@ using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterat
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
@@ -221,20 +213,28 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
@@ -282,4 +282,19 @@ using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
2,
|
||||
swizzle_factor>>;
|
||||
|
||||
// int8
|
||||
using WarpGemmMfma_i32_32x32x16_i8_i8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_16x16x32_i8_i8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_i32_16x16x32_i8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1092,7 +1092,7 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
|
||||
}
|
||||
else
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
@@ -1116,7 +1116,7 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
@@ -1127,7 +1127,7 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_316x16x32_bf8_bf8(
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
@@ -1251,7 +1251,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
}
|
||||
else
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
@@ -1286,7 +1286,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
@@ -1578,8 +1578,8 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_32x32x8i8(
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
@@ -1609,6 +1609,183 @@ struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_i32_16x16x32_i8
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using CDataType = int32_t;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 8>;
|
||||
using BVecType = ext_vector_t<BDataType, 8>;
|
||||
using CVecType = ext_vector_t<CDataType, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 32;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x32_i8", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx94__) or defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
CVecType c_vec{0};
|
||||
operator()(c_vec, a_vec, b_vec);
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_i32_16x16x64_i8
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using CDataType = int32_t;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 16>;
|
||||
using BVecType = ext_vector_t<BDataType, 16>;
|
||||
using CVecType = ext_vector_t<CDataType, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 64;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 16;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x64_i8", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx95__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
CVecType c_vec{0};
|
||||
operator()(c_vec, a_vec, b_vec);
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_i32_32x32x32_i8
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using CDataType = int32_t;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 16>;
|
||||
using BVecType = ext_vector_t<BDataType, 16>;
|
||||
using CVecType = ext_vector_t<CDataType, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
static constexpr index_t kK = 32;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 32;
|
||||
static constexpr index_t kBNLane = 32;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABKPerLane = 16;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 32;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x32_i8", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx95__)
|
||||
c_vec =
|
||||
__builtin_amdgcn_mfma_i32_32x32x32_i8(a_vec, bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
CVecType c_vec{0};
|
||||
operator()(c_vec, a_vec, b_vec);
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
#undef DISPATCH_MFMA_
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -11,7 +11,7 @@ namespace ck_tile {
|
||||
namespace impl {
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
typename AccType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
@@ -22,6 +22,7 @@ struct WarpGemmMfmaDispatcher;
|
||||
|
||||
// clang-format off
|
||||
// fp16
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
|
||||
@@ -37,10 +38,12 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
|
||||
// fp16 2:4 structural sparsity
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M16N16K32; };
|
||||
|
||||
// bf16
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
|
||||
@@ -56,6 +59,7 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
|
||||
// fp8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
|
||||
@@ -81,12 +85,19 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float,
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; };
|
||||
|
||||
// int8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
|
||||
|
||||
// clang-format on
|
||||
} // namespace impl
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
typename AccType,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
@@ -95,7 +106,7 @@ template <typename AType,
|
||||
bool UseStructuredSparsity = false>
|
||||
using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
|
||||
BType,
|
||||
CType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
|
||||
12
include/ck_tile/ops/grouped_convolution.hpp
Normal file
12
include/ck_tile/ops/grouped_convolution.hpp
Normal file
@@ -0,0 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
@@ -0,0 +1,800 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The Grouped Convolution kernel device arguments.
|
||||
template <typename GroupedConvTraitsType>
|
||||
struct GroupedConvFwdKernelArgs
|
||||
{
|
||||
|
||||
using ConvToGemmFwdTransformer =
|
||||
TransformConvFwdToGemm<GroupedConvTraitsType::NDimSpatial,
|
||||
GroupedConvTraitsType::ConvSpecialization>;
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
|
||||
|
||||
template <
|
||||
typename InLay = typename GroupedConvTraitsType::InLayout,
|
||||
typename WeiLay = typename GroupedConvTraitsType::WeiLayout,
|
||||
typename OutLay = typename GroupedConvTraitsType::OutLayout,
|
||||
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[0])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
GemmM = args.N_ * args.output_spatial_lengths_[0];
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0];
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
ds_ptr[d] = args.ds_ptr[d];
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
a_grid_desc_m_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType::InLayout>();
|
||||
b_grid_desc_n_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType::WeiLayout>();
|
||||
c_grid_desc_m_n =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
group_stride_c = args.K_;
|
||||
}
|
||||
|
||||
template <
|
||||
typename InLay = typename GroupedConvTraitsType::InLayout,
|
||||
typename WeiLay = typename GroupedConvTraitsType::WeiLayout,
|
||||
typename OutLay = typename GroupedConvTraitsType::OutLayout,
|
||||
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[1])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
ds_ptr[d] = args.ds_ptr[d];
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
a_grid_desc_m_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType::InLayout>();
|
||||
b_grid_desc_n_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType::WeiLayout>();
|
||||
c_grid_desc_m_n =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
group_stride_c = args.K_;
|
||||
}
|
||||
|
||||
template <
|
||||
typename InLay = typename GroupedConvTraitsType::InLayout,
|
||||
typename WeiLay = typename GroupedConvTraitsType::WeiLayout,
|
||||
typename OutLay = typename GroupedConvTraitsType::OutLayout,
|
||||
typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
|
||||
std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
|
||||
std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[2])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[2])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.output_spatial_lengths_[2])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[2])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[2])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1]),
|
||||
static_cast<index_t>(args.input_left_pads_[2])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
|
||||
args.output_spatial_lengths_[2];
|
||||
GemmN = args.K_;
|
||||
GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] *
|
||||
args.filter_spatial_lengths_[2];
|
||||
|
||||
in_ptr = args.in_ptr;
|
||||
wei_ptr = args.wei_ptr;
|
||||
for(index_t d = 0; d < NumDTensor; d++)
|
||||
{
|
||||
ds_ptr[d] = args.ds_ptr[d];
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
a_grid_desc_m_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType::InLayout>();
|
||||
b_grid_desc_n_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType::WeiLayout>();
|
||||
c_grid_desc_m_n =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
group_stride_c = args.K_;
|
||||
}
|
||||
|
||||
using AGridDescMK = remove_cvref_t<decltype(
|
||||
ConvToGemmFwdTransformer{}
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType::InLayout>())>;
|
||||
using BGridDescNK = remove_cvref_t<decltype(
|
||||
ConvToGemmFwdTransformer{}
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType::WeiLayout>())>;
|
||||
using CGridDescMN = remove_cvref_t<decltype(
|
||||
ConvToGemmFwdTransformer{}
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType::OutLayout>())>;
|
||||
|
||||
static constexpr index_t NonSpatialDims = 3;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType::NDimSpatial> in_g_n_c_wis_lengths;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType::NDimSpatial> wei_g_k_c_xs_lengths;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType::NDimSpatial> out_g_n_k_wos_lengths;
|
||||
|
||||
array<index_t, GroupedConvTraitsType::NDimSpatial> conv_filter_strides;
|
||||
array<index_t, GroupedConvTraitsType::NDimSpatial> conv_filter_dilations;
|
||||
array<index_t, GroupedConvTraitsType::NDimSpatial> input_left_pads;
|
||||
array<index_t, GroupedConvTraitsType::NDimSpatial> input_right_pads;
|
||||
|
||||
index_t k_batch;
|
||||
index_t GemmM;
|
||||
index_t GemmN;
|
||||
index_t GemmK;
|
||||
|
||||
const void* in_ptr;
|
||||
const void* wei_ptr;
|
||||
std::array<const void*, NumDTensor> ds_ptr;
|
||||
void* out_ptr;
|
||||
|
||||
AGridDescMK a_grid_desc_m_k;
|
||||
BGridDescNK b_grid_desc_n_k;
|
||||
CGridDescMN c_grid_desc_m_n;
|
||||
|
||||
long_index_t group_stride_a;
|
||||
long_index_t group_stride_b;
|
||||
long_index_t group_stride_c;
|
||||
};
|
||||
|
||||
/// @brief The Grouped Convolution Forward kernel template.
|
||||
///
|
||||
/// @paragraph Overview Overview
|
||||
/// This class provides the grouped convolution forward kernel template. By semantic
|
||||
/// division of Implicit GEMM algorithm into following parts we achieve flexible,
|
||||
/// versatile and robust kernel implementation.
|
||||
///
|
||||
/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator()
|
||||
/// "function call operator" which determines the work scope of each workgroup.
|
||||
/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm.
|
||||
/// This is the place where each workgroup is loading data from global memory and
|
||||
/// carrying out dot products.
|
||||
/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation
|
||||
/// responsible for storing results to global memory. This is also the place where
|
||||
/// any additional operator fusion may take place.
|
||||
///
|
||||
/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_
|
||||
/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all
|
||||
/// internal details of those functional parts. You can think of it like both gemm and
|
||||
/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover
|
||||
/// the policy is responsible for definition of all necessary data layouts and thread's
|
||||
/// work distribution.
|
||||
///
|
||||
/// @tparam GroupedConvTraitsType The type of class providing traits for grouped convolution.
|
||||
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into
|
||||
/// the
|
||||
/// output data tile to be calculated. It determines the
|
||||
/// workgroup to data relationship (or in other words - which
|
||||
/// data would be processed and calculated by which workgroup).
|
||||
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
|
||||
/// multiplication. This class should provide implementation of
|
||||
/// data loading from global memory and performing block-wise
|
||||
/// matrix multiplication. You can think of it as a work done by
|
||||
/// single workgroup point of view.
|
||||
/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix
|
||||
/// multiplication implementation. It is responsible for storing
|
||||
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
|
||||
/// the output C tensor in global memory.
|
||||
template <typename GroupedConvTraitsType,
|
||||
typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionForwardKernel
|
||||
{
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType::NDimSpatial;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType::ConvSpecialization;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using GemmALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using GemmBLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using GemmCLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
using InLayout = remove_cvref_t<typename GroupedConvTraitsType::InLayout>;
|
||||
using WeiLayout = remove_cvref_t<typename GroupedConvTraitsType::WeiLayout>;
|
||||
using OutLayout = remove_cvref_t<typename GroupedConvTraitsType::OutLayout>;
|
||||
using DsLayout = remove_cvref_t<typename GroupedConvTraitsType::DsLayout>;
|
||||
|
||||
using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
|
||||
static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
|
||||
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs<GroupedConvTraitsType>;
|
||||
|
||||
// TODO: Enable this
|
||||
static constexpr bool IsSplitKSupported = false;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
|
||||
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
|
||||
/// @brief Builds the kernel's name string.
///
/// The name is produced by concat() from the literal "grouped_convolution_forward", the
/// precision string for (InDataType, WeiDataType) and the GEMM pipeline's own name; '_' is
/// passed as the first argument (presumably the separator - see concat()).
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
    // clang-format off
    auto kernel_name = concat('_', "grouped_convolution_forward", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
    // clang-format on
    return kernel_name;
}
|
||||
|
||||
/// @brief Computes the launch grid for a given host-side problem.
///
/// x: number of output tiles reported by TilePartitioner::GridSize(GemmM, GemmN), where
///    GemmM = N * product(output spatial lengths) and GemmN = K.
/// y: number of groups (G).
/// z: k_batch.
///
/// @param args Host-side problem description.
/// @return dim3 with the three components described above.
CK_TILE_HOST static constexpr auto GridSize(const GroupedConvHostArgs& args)
{
    const auto output_spatial_volume = std::accumulate(args.output_spatial_lengths_.begin(),
                                                       args.output_spatial_lengths_.end(),
                                                       1,
                                                       std::multiplies<index_t>());

    const index_t gemm_m = args.N_ * output_spatial_volume;
    const index_t gemm_n = args.K_;

    return dim3(TilePartitioner::GridSize(gemm_m, gemm_n), args.G_, args.k_batch);
}
|
||||
|
||||
/// @brief Returns the launch block dimensions: a dim3 constructed from KernelBlockSize only.
CK_TILE_HOST static constexpr auto BlockSize()
{
    return dim3(KernelBlockSize);
}
|
||||
|
||||
/// @brief Converts host-side arguments into the kernel argument struct.
/// @param hostArgs Host-side problem description.
/// @return A GroupedConvFwdKernelArgsSpecialized constructed directly from @p hostArgs.
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
MakeKernelArgs(const GroupedConvHostArgs& hostArgs)
{
    using KernelArgs = GroupedConvFwdKernelArgsSpecialized;
    return KernelArgs(hostArgs);
}
|
||||
|
||||
/// @brief Shared-memory size required by the kernel.
/// @return The larger of the sizes reported by GemmPipeline and EpiloguePipeline.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    const auto gemm_smem_size     = GemmPipeline::GetSmemSize();
    const auto epilogue_smem_size = EpiloguePipeline::GetSmemSize();
    return max(gemm_smem_size, epilogue_smem_size);
}
|
||||
|
||||
/// @brief Host-side check whether this kernel instantiation can run the given problem.
///
/// The checks are applied in this order; the first failing one returns false:
///   1. k_batch must be 1 when split-K is disabled (IsSplitKSupported == false) or when the
///      output is fp16/bf16 with an odd epilogue vector size.
///   2. The problem must match the compile-time ConvSpecialization:
///        - Filter1x1Stride1Pad0: every filter length 1, stride 1, both pads 0;
///        - Filter1x1Pad0:        every filter length 1, both pads 0;
///        - Filter3x3:            C == 1 and every filter length 3.
///   3. The input layout must be NWGC / NHWGC / NDHWGC and C must be a multiple of the
///      pipeline's A vector size.
///   4. The weight layout must be GKXC / GKYXC / GKZYXC and C must be a multiple of the
///      pipeline's B vector size.
///   5. The output layout must be NWGK / NHWGK / NDHWGK and K must be a multiple of the
///      epilogue's C vector size.
///
/// @param kargs Kernel arguments produced by MakeKernelArgs().
/// @return true if all checks pass, false otherwise.
CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized& kargs)
{
    namespace ctc = tensor_layout::convolution;

    // --- 1. split-K gate -----------------------------------------------------------------
    constexpr bool kHalfOutputWithOddVector =
        EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
        is_any_of<OutDataType, fp16_t, bf16_t>::value;

    if constexpr(kHalfOutputWithOddVector || !IsSplitKSupported)
    {
        const bool split_k_requested = (kargs.k_batch != 1);
        if(split_k_requested)
        {
            // Only this message is gated on the CK_TILE_LOGGING environment switch.
            if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
            {
                CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
            }
            return false;
        }
    }

    const index_t conv_k = kargs.wei_g_k_c_xs_lengths[number<1>{}];
    const index_t conv_c = kargs.wei_g_k_c_xs_lengths[number<2>{}];

    // --- 2. convolution specialization ---------------------------------------------------
    if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
    {
        // Every spatial dimension must be a 1-wide filter with unit stride and no padding.
        for(index_t dim = 0; dim < NDimSpatial; ++dim)
        {
            const index_t filter_len = kargs.wei_g_k_c_xs_lengths[dim + 3];
            const index_t stride     = kargs.conv_filter_strides[dim];
            const index_t pad_left   = kargs.input_left_pads[dim];
            const index_t pad_right  = kargs.input_right_pads[dim];

            if(filter_len != 1 || stride != 1 || pad_left != 0 || pad_right != 0)
            {
                return false;
            }
        }
    }
    else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
    {
        // Every spatial dimension must be a 1-wide filter with no padding.
        for(index_t dim = 0; dim < NDimSpatial; ++dim)
        {
            const index_t filter_len = kargs.wei_g_k_c_xs_lengths[dim + 3];
            const index_t pad_left   = kargs.input_left_pads[dim];
            const index_t pad_right  = kargs.input_right_pads[dim];

            if(filter_len != 1 || pad_left != 0 || pad_right != 0)
            {
                return false;
            }
        }
    }
    else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
    {
        // This specialization only accepts a single channel per group ...
        if(conv_c != 1)
        {
            return false;
        }
        // ... and a filter length of 3 in every spatial dimension.
        for(index_t dim = 0; dim < NDimSpatial; ++dim)
        {
            const index_t filter_len = kargs.wei_g_k_c_xs_lengths[dim + I3];

            if(filter_len != I3)
            {
                return false;
            }
        }
    }

    // --- 3. input: vector access along C -------------------------------------------------
    constexpr bool kInputLayoutSupported = std::is_same_v<InLayout, ctc::NWGC> ||
                                           std::is_same_v<InLayout, ctc::NHWGC> ||
                                           std::is_same_v<InLayout, ctc::NDHWGC>;
    if constexpr(!kInputLayoutSupported)
    {
        CK_TILE_ERROR("Not supported input layout!");
        return false;
    }
    else
    {
        if(conv_c % GemmPipeline::GetVectorSizeA() != 0)
        {
            CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
            return false;
        }
    }

    // --- 4. weight: vector access along C ------------------------------------------------
    // FIXME: layout
    constexpr bool kWeightLayoutSupported = std::is_same_v<WeiLayout, ctc::GKXC> ||
                                            std::is_same_v<WeiLayout, ctc::GKYXC> ||
                                            std::is_same_v<WeiLayout, ctc::GKZYXC>;
    if constexpr(!kWeightLayoutSupported)
    {
        CK_TILE_ERROR("Not supported weight layout!");
        return false;
    }
    else
    {
        if(conv_c % GemmPipeline::GetVectorSizeB() != 0)
        {
            CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
            return false;
        }
    }

    // --- 5. output: vector access along K ------------------------------------------------
    constexpr bool kOutputLayoutSupported = std::is_same_v<OutLayout, ctc::NWGK> ||
                                            std::is_same_v<OutLayout, ctc::NHWGK> ||
                                            std::is_same_v<OutLayout, ctc::NDHWGK>;
    if constexpr(!kOutputLayoutSupported)
    {
        CK_TILE_ERROR("Not supported output layout!");
        return false;
    }
    else
    {
        if(conv_k % EpiloguePipeline::GetVectorSizeC() != 0)
        {
            CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
            return false;
        }
    }

    return true;
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const InDataType* a_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
OutDataType* c_ptr,
|
||||
const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
|
||||
}();
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
|
||||
}();
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(c_ptr, kargs.c_grid_desc_m_n);
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
|
||||
"Not supported!");
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = views.at(I2);
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I3);
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& c_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
}();
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs Grouped Convolution Forward kernel arguments
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
OutDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const GroupedConvFwdKernelArgsSpecialized& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param kargs Grouped Convolution Forward kernel arguments
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
OutDataType* c_ptr,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const GroupedConvFwdKernelArgsSpecialized& kargs,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0, smem_ptr_1);
|
||||
}
|
||||
|
||||
/// @brief Kernel entry point.
///
/// blockIdx.x selects the output tile through TilePartitioner; blockIdx.y scales the
/// per-group strides to offset the input, weight and output base pointers. The GEMM is
/// then run with one or two LDS buffers depending on GemmPipeline::DoubleSmemBuffer.
///
/// @param kargs Grouped Convolution Forward kernel arguments.
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
{
    // Nothing is launched for atomic_add with an odd C vector size and fp16/bf16 output.
    // NOTE(review): presumably this combination is unsupported by the epilogue - confirm.
    constexpr bool kSkippedConfig =
        EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
        EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
        is_any_of<OutDataType, fp16_t, bf16_t>::value;
    constexpr bool kUseTwoLdsBuffers = (GemmPipeline::DoubleSmemBuffer == true);

    // Output tile owned by this workgroup, converted from tile indices to offsets.
    const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x);
    const auto [iM, iN] =
        TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
    const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
    const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

    // Per-group offsets derived from blockIdx.y.
    const auto blockIdY       = __builtin_amdgcn_readfirstlane(blockIdx.y);
    const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY);
    const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY);
    const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY);

    // Base pointers of this group's input, weight and output tensors.
    const InDataType* a_ptr  = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a;
    const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
    OutDataType* c_ptr       = static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c;

    // allocate LDS
    __shared__ char smem_ptr_0[GetSmemSize()];

    if constexpr(kUseTwoLdsBuffers)
    {
        // Second LDS buffer for the two-buffer pipeline.
        __shared__ char smem_ptr_1[GetSmemSize()];
        if constexpr(!kSkippedConfig)
        {
            RunGemm2LDS(
                a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n);
        }
    }
    else if constexpr(!kSkippedConfig)
    {
        RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);
    }
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Convolution specializations a kernel can be instantiated for.
///
/// NOTE(review): the enumerator names suggest constraints on filter size, stride and
/// padding; the exact conditions are not defined in this header - confirm at use sites.
enum class ConvolutionSpecialization
{
    Default              = 0, // no specialization
    Filter1x1Pad0        = 1,
    Filter1x1Stride1Pad0 = 2,
    Filter3x3            = 3,
};
|
||||
|
||||
/// @brief Returns the enumerator name of a convolution specialization.
///
/// @param s Specialization to describe.
/// @return The enumerator's identifier as a string, or "Unrecognized specialization!"
///         for a value that matches no enumerator.
CK_TILE_HOST std::string getConvSpecializationString(const ConvolutionSpecialization& s)
{
    using Spec = ConvolutionSpecialization;

    if(s == Spec::Default)
    {
        return "Default";
    }
    if(s == Spec::Filter1x1Pad0)
    {
        return "Filter1x1Pad0";
    }
    if(s == Spec::Filter1x1Stride1Pad0)
    {
        return "Filter1x1Stride1Pad0";
    }
    if(s == Spec::Filter3x3)
    {
        return "Filter3x3";
    }
    return "Unrecognized specialization!";
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,74 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The Grouped Conv kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to Grouped Convolution Kernels when creating kernel
|
||||
/// arguments object. It contain all necessary information required to
|
||||
/// build proper kernel argument and launch kernel on GPU.
|
||||
struct GroupedConvHostArgs : public conv::ConvParam
|
||||
{
|
||||
CK_TILE_HOST GroupedConvHostArgs() = delete;
|
||||
CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param,
|
||||
const void* in_ptr_,
|
||||
const void* wei_ptr_,
|
||||
const std::vector<const void*> ds_ptr_,
|
||||
void* out_ptr_,
|
||||
index_t k_batch_)
|
||||
: conv::ConvParam(conv_param),
|
||||
in_ptr(in_ptr_),
|
||||
wei_ptr(wei_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
out_ptr(out_ptr_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* in_ptr;
|
||||
const void* wei_ptr;
|
||||
const std::vector<const void*> ds_ptr;
|
||||
void* out_ptr;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
template <index_t NDimSpatial_,
|
||||
ConvolutionSpecialization ConvSpecialization_,
|
||||
typename InLayout_,
|
||||
typename WeiLayout_,
|
||||
typename DsLayout_,
|
||||
typename OutLayout_>
|
||||
struct GroupedConvTraits
|
||||
{
|
||||
private:
|
||||
static constexpr auto generate_implicit_gemm_layout()
|
||||
{
|
||||
return generate_tuple([](auto) { return ck_tile::tensor_layout::gemm::RowMajor{}; },
|
||||
number<DsLayout_::size()>{});
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
|
||||
using InLayout = InLayout_;
|
||||
using WeiLayout = WeiLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using OutLayout = OutLayout_;
|
||||
using GroupedConvImplicitGemmTraits = TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
static constexpr index_t NumDTensor = DsLayout::size();
|
||||
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,7 +26,7 @@ struct TileImageToColumnShape
|
||||
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
|
||||
static constexpr index_t kKWarpPerBlock = kKPerBlock / kKPerWarp;
|
||||
|
||||
static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kKWarpPerBlock;
|
||||
static constexpr index_t kBlockSize = get_warp_size() * kMWarpPerBlock * kKWarpPerBlock;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -250,7 +250,7 @@ struct BlockNormReduceCrossWarpSync
|
||||
// | w0 | w1 | w2 | w3 | -----> | w0123 |
|
||||
//
|
||||
// -> also store data from every wave into LDS
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
return num_warps * 4 * thread_buf_size * sizeof(float);
|
||||
}
|
||||
|
||||
@@ -276,7 +276,7 @@ struct BlockNormReduceCrossWarpSync
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
const index_t smem_offset = warp_id;
|
||||
|
||||
// skip if nothing to do
|
||||
|
||||
@@ -210,7 +210,7 @@ struct BlockReduce2dCrossWarpSync
|
||||
// | w0 | w1 | w2 | w3 | -----> | w0123 |
|
||||
//
|
||||
// -> also store data from every wave into LDS
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
return num_warps * thread_buf_size * sizeof(DataType);
|
||||
}
|
||||
|
||||
@@ -226,7 +226,7 @@ struct BlockReduce2dCrossWarpSync
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / warpSize;
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
const index_t smem_offset = warp_id;
|
||||
|
||||
// skip if nothing to do
|
||||
|
||||
Reference in New Issue
Block a user