mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4294 (commit 6601702)
Cleanup and refactoring related to tile loading

## Proposed changes

Cleanup and refactoring done while implementing mixed precision for fp16/bf16 x fp8.

Key changes:
- Renamed load_interleaved_pk_type.hpp to load_and_convert_tile.hpp and refactored the API to use consistent naming conventions
- Updated load_tile_transpose functions to use output parameters instead of return values, for consistency
- Removed unused variable declarations and simplified type deduction logic
- Defined load_tile_with_elementwise to take tuple types explicitly, for clarity

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally
- [ ] I have added the test to the REGRESSION_TESTS list defined at the top of tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run.
- [x] I have added inline documentation which enables the maintainers to understand the motivation behind the changes
- [ ] I have removed stale documentation which is no longer relevant after this pull request
- [ ] (If this change is user-facing) I have added release notes which provide end users with a brief summary of the improvement from this pull request
- [x] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
0438ab1b79
commit
95dc496d30
@@ -48,19 +48,19 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
|
||||
* and an elementwise function. For each A = A0, A1… AN, the elementwise function
|
||||
* is additionally applied during a single read.
|
||||
*/
|
||||
template <typename TileWindow_,
|
||||
template <typename... TileWindow_,
|
||||
typename ElementWise_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
|
||||
CK_TILE_DEVICE auto load_tile_with_elementwise(const ck_tile::tuple<TileWindow_...>& tile_windows,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
// TODO: Tile windows should works with unknow number of params
|
||||
// Load element_wise API works only when the input typle is a tuple-tyupe
|
||||
return tile_window[number<0>{}].load(
|
||||
tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
// TODO: Tile windows should work with unknown number of params
|
||||
// Load element_wise API works only when the input type is a tuple-type
|
||||
return tile_windows[number<0>{}].load(
|
||||
tile_windows, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
|
||||
@@ -85,12 +85,12 @@ template <typename DistributedTensor_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
|
||||
CK_TILE_DEVICE void load_tile(DistributedTensor_& dst_tile,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -131,7 +131,7 @@ template <typename T,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
CK_TILE_DEVICE void load_tile_raw(T& tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
|
||||
@@ -374,6 +374,7 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
|
||||
* element space size and vector length remain consistent between the input and output
|
||||
* distributions.
|
||||
*
|
||||
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
|
||||
* @tparam BottomTensorView_ The type of the bottom tensor view.
|
||||
* @tparam WindowLengths_ The type representing the window lengths.
|
||||
* @tparam TileDistribution_ The type representing the tile distribution.
|
||||
@@ -381,18 +382,19 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
|
||||
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
|
||||
* the last is SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* @param out_tensor A statically distributed tensor containing the transposed tile
|
||||
* data.
|
||||
* @param tile_window The tile window with static distribution to load and transpose.
|
||||
* @param offset The offset (in elements) added to the base address before
|
||||
* indexing.
|
||||
*
|
||||
* @return A statically distributed tensor containing the transposed tile data.
|
||||
*
|
||||
* @note
|
||||
* - The function uses compile-time checks to ensure the input and output tile distributions
|
||||
* are compatible in terms of element space size and vector length.
|
||||
* - The transpose operation is performed according to the specified Policy.
|
||||
*/
|
||||
template <
|
||||
typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
@@ -402,21 +404,28 @@ template <
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
CK_TILE_DEVICE void load_tile_transpose_with_offset(
|
||||
DistributedTensor_& out_tensor,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window,
|
||||
index_t offset)
|
||||
{
|
||||
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
|
||||
constexpr auto input_distr = TileDistribution_{};
|
||||
constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{};
|
||||
|
||||
// Check that the tile distribution of out_tensor is the expected one for transposed loads.
|
||||
using OutTileDstrEncode = typename OutputTileDistributionTraits<
|
||||
typename TileDistribution_::DstrEncode,
|
||||
typename BottomTensorView_::DataType>::TransposedDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
|
||||
constexpr auto input_distr = TileDistribution_{};
|
||||
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
|
||||
static_assert(std::is_same_v<decltype(make_static_tile_distribution(OutTileDstrEncode{})),
|
||||
remove_cvref_t<decltype(output_distr)>>);
|
||||
|
||||
// Check that the datatype of out_tensor matches that of the bottom tensor view.
|
||||
static_assert(std::is_same_v<typename DistributedTensor_::DataType,
|
||||
typename BottomTensorView_::DataType>);
|
||||
|
||||
constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
|
||||
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
|
||||
@@ -443,8 +452,6 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
number<iAccess>{},
|
||||
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
|
||||
});
|
||||
|
||||
return out_tensor;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -456,6 +463,7 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
* element space size and vector length remain consistent between the input and output
|
||||
* distributions.
|
||||
*
|
||||
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
|
||||
* @tparam BottomTensorView_ The type of the bottom tensor view.
|
||||
* @tparam WindowLengths_ The type representing the window lengths.
|
||||
* @tparam TileDistribution_ The type representing the tile distribution.
|
||||
@@ -463,16 +471,37 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
|
||||
* the last is SFINAE to ensure the tile distribution encoding is valid.
|
||||
*
|
||||
* @param out_tensor A statically distributed tensor containing the transposed tile
|
||||
* data.
|
||||
* @param tile_window The tile window with static distribution to load and transpose.
|
||||
* indexing.
|
||||
*
|
||||
* @return A statically distributed tensor containing the transposed tile data.
|
||||
*
|
||||
* @note
|
||||
* - The function uses compile-time checks to ensure the input and output tile distributions
|
||||
* are compatible in terms of element space size and vector length.
|
||||
* - The transpose operation is performed according to the specified Policy.
|
||||
*/
|
||||
template <
|
||||
typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
|
||||
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
|
||||
typename BottomTensorView_::DataType,
|
||||
Policy>::distr_encoding_valid,
|
||||
Policy>>
|
||||
CK_TILE_DEVICE void
|
||||
load_tile_transpose(DistributedTensor_& out_tensor,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window)
|
||||
{
|
||||
load_tile_transpose_with_offset(out_tensor, tile_window, 0);
|
||||
}
|
||||
|
||||
template <
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
@@ -489,7 +518,15 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
|
||||
TileDistribution_,
|
||||
NumCoord>& __restrict__ tile_window)
|
||||
{
|
||||
return load_tile_transpose_with_offset(tile_window, 0);
|
||||
using OutTileDstrEncode = typename OutputTileDistributionTraits<
|
||||
typename TileDistribution_::DstrEncode,
|
||||
typename BottomTensorView_::DataType>::TransposedDstrEncode;
|
||||
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
|
||||
make_static_tile_distribution(OutTileDstrEncode{}));
|
||||
|
||||
load_tile_transpose_with_offset(out_tensor, tile_window, 0);
|
||||
|
||||
return out_tensor;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -190,11 +190,11 @@ struct tile_window_with_static_distribution
|
||||
* The same thread, during vectorized reading, accesses the same set of
|
||||
* data from A0, A1, A2, … AN.
|
||||
*/
|
||||
template <typename TileWindow_,
|
||||
template <typename... TileWindow_,
|
||||
typename ElementWise_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(const TileWindow_& tile_window,
|
||||
CK_TILE_DEVICE auto load(const ck_tile::tuple<TileWindow_...>& tile_windows,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
@@ -202,7 +202,7 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
|
||||
load(dst_tensor,
|
||||
tile_window,
|
||||
tile_windows,
|
||||
elementwise,
|
||||
number<i_access_unsupport_>{},
|
||||
bool_constant<oob_conditional_check>{});
|
||||
@@ -210,12 +210,12 @@ struct tile_window_with_static_distribution
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
typename TileWindow_,
|
||||
typename... TileWindow_,
|
||||
typename ElementWise_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void load(DistributedTensor& dst_tensor,
|
||||
const TileWindow_& tile_window,
|
||||
const ck_tile::tuple<TileWindow_...>& tile_windows,
|
||||
ElementWise_ elementwise,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
@@ -226,14 +226,14 @@ struct tile_window_with_static_distribution
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
constexpr auto sizeOfTuple = TileWindow_::size();
|
||||
constexpr auto sizeOfTuple = remove_cvref_t<decltype(tile_windows)>::size();
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord =
|
||||
tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
|
||||
tile_windows[number<0>{}].pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord =
|
||||
tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];
|
||||
tile_windows[number<0>{}].pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
@@ -244,7 +244,7 @@ struct tile_window_with_static_distribution
|
||||
// read from bottom tensor
|
||||
const auto idx_vec_value = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return tile_window[number<jj>{}]
|
||||
return tile_windows[number<jj>{}]
|
||||
.get_bottom_tensor_view()
|
||||
.template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
|
||||
#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
|
||||
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -5,22 +5,20 @@
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename SrcDataType, typename DstDataType, index_t UnaryOpSize>
|
||||
struct InterleavedPKTypeLoader
|
||||
struct ConverterLoader
|
||||
{
|
||||
template <typename WarpWindow, typename WarpTile>
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile,
|
||||
const WarpWindow& warp_window)
|
||||
CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src)
|
||||
{
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
|
||||
static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0);
|
||||
constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize;
|
||||
const auto in_dstr_tensors = load_tile(warp_window);
|
||||
const auto tmp = load_tile(src);
|
||||
|
||||
// NOTE: we rely on types packing neatly here
|
||||
using RawSrcType = typename SrcDataType::type;
|
||||
@@ -29,29 +27,28 @@ struct InterleavedPKTypeLoader
|
||||
using SrcVectorType = ext_vector_t<RawSrcType, UnaryOpSize / PackedSize>;
|
||||
using DstVectorType = ext_vector_t<DstDataType, UnaryOpSize>;
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
elementwise_op(warp_tile.get_thread_buffer().template get_as<DstVectorType>()(i),
|
||||
in_dstr_tensors.get_thread_buffer().template get_as<SrcVectorType>()[i]);
|
||||
const element_wise::PassThroughPack8 elementwise_op{};
|
||||
|
||||
elementwise_op(dst.get_thread_buffer().template get_as<DstVectorType>()(i),
|
||||
tmp.get_thread_buffer().template get_as<SrcVectorType>()[i]);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SrcDataType,
|
||||
typename DstDataType,
|
||||
index_t UnaryOpSize,
|
||||
bool LoadTranspose = false,
|
||||
typename WarpTile,
|
||||
typename WarpWindow>
|
||||
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
|
||||
template <index_t UnaryOpSize, bool LoadTranspose = false, typename WarpTile, typename WarpWindow>
|
||||
CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src)
|
||||
{
|
||||
if constexpr(is_packed_type_v<SrcDataType>)
|
||||
using SrcDataType = typename WarpWindow::Base::DataType;
|
||||
using DstDataType = typename WarpTile::DataType;
|
||||
|
||||
if constexpr(is_packed_type_v<SrcDataType> && !is_packed_type_v<DstDataType>)
|
||||
{
|
||||
static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t or pk_fp4_t");
|
||||
InterleavedPKTypeLoader<SrcDataType, DstDataType, UnaryOpSize>::load_interleaved_pk_type(
|
||||
dst, src);
|
||||
ConverterLoader<SrcDataType, DstDataType, UnaryOpSize>::load_interleaved_pk_type(dst, src);
|
||||
}
|
||||
else if constexpr(LoadTranspose)
|
||||
{
|
||||
dst = load_tile_transpose(src);
|
||||
load_tile_transpose(dst, src);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -61,7 +61,7 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -530,7 +530,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
|
||||
|
||||
dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr);
|
||||
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
|
||||
load_tile_transpose(dot_reg_tensor, dot_lds_read_window);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
@@ -634,7 +634,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
|
||||
|
||||
qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr);
|
||||
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
|
||||
load_tile_transpose(qt_reg_tensor, qt_lds_read_window);
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
|
||||
@@ -715,7 +715,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
|
||||
load_tile_transpose(ds_reg_tensor, ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
@@ -728,7 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
|
||||
load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
auto kt_reg_tensor_slice = get_slice_tile( //
|
||||
|
||||
@@ -455,10 +455,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
async_load_tile(q_lds_write_window, q_dram_window);
|
||||
async_load_tile(do_lds_write_window, do_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(0);
|
||||
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
|
||||
q_reg_tensor = load_tile(q_lds_read_window);
|
||||
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
|
||||
do_reg_tensor = load_tile(do_lds_read_window);
|
||||
load_tile_transpose(qt_reg_tensor, qt_lds_read_window);
|
||||
q_reg_tensor = load_tile(q_lds_read_window);
|
||||
load_tile_transpose(dot_reg_tensor, dot_lds_read_window);
|
||||
do_reg_tensor = load_tile(do_lds_read_window);
|
||||
|
||||
lse_block_tile = load_tile(lse_dram_window);
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
@@ -490,9 +490,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
async_load_tile(v_lds_write_window, v_dram_window);
|
||||
move_tile_window(v_dram_window, {kN0, 0});
|
||||
s_waitcnt</*vmcnt=*/0>();
|
||||
k_reg_tensor = load_tile(k_lds_read_window);
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
kt_reg_tensor = load_tile_transpose(kt_lds_read_window);
|
||||
k_reg_tensor = load_tile(k_lds_read_window);
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
load_tile_transpose(kt_reg_tensor, kt_lds_read_window);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
@@ -668,7 +668,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
block_sync_lds();
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
|
||||
load_tile_transpose(ds_reg_tensor, ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
@@ -680,7 +680,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
|
||||
load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
auto kt_reg_tensor_slice = get_slice_tile( //
|
||||
|
||||
@@ -718,7 +718,7 @@ struct BlockFmhaFwdV3Pipeline
|
||||
};
|
||||
|
||||
auto V_lds_load = [&](auto v_lds_read_idx) {
|
||||
kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx));
|
||||
load_tile_transpose(kv_tile.v_tile, v_lds_window_load(v_lds_read_idx));
|
||||
};
|
||||
|
||||
decltype(m) m_old;
|
||||
|
||||
@@ -582,7 +582,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
|
||||
// loop over along the [V]alue Sequence length
|
||||
move_tile_window(v_lds_read_window, {kK1, 0});
|
||||
v_tile = load_tile_transpose(v_lds_read_window);
|
||||
load_tile_transpose(v_tile, v_lds_read_window);
|
||||
});
|
||||
// move back to the origin
|
||||
move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0});
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -79,7 +79,7 @@
|
||||
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
@@ -218,10 +218,8 @@ struct BlockUniversalGemmAsBsCr
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
|
||||
a_block_window);
|
||||
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
|
||||
b_block_window);
|
||||
load_and_convert_tile<UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
|
||||
load_and_convert_tile<UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
@@ -290,9 +288,9 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
static constexpr auto ALdsTileDistr =
|
||||
make_static_tile_distribution(MakeABlockDistributionEncode());
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
make_static_tile_distribution(MakeBBlockDistributionEncode());
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ATypeToUse>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BTypeToUse>(BLdsTileDistr));
|
||||
@@ -349,10 +347,8 @@ struct BlockUniversalGemmAsBsCr
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
|
||||
|
||||
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_, ALoadTranspose>(a_warp_tile_,
|
||||
a_lds_gemm_window);
|
||||
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_, BLoadTranspose>(b_warp_tile_,
|
||||
b_lds_gemm_window);
|
||||
load_and_convert_tile<UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_lds_gemm_window);
|
||||
load_and_convert_tile<UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_lds_gemm_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
|
||||
@@ -79,9 +79,7 @@ struct GemmPipelineAgBgCrImplBase
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
template <typename SrcDataType = void,
|
||||
typename DstDataType = void,
|
||||
index_t UnaryOpSize = 8,
|
||||
template <index_t UnaryOpSize = 8,
|
||||
typename DstBlockTile,
|
||||
typename SrcTileWindow,
|
||||
typename DramTileWindowStep>
|
||||
@@ -89,7 +87,7 @@ struct GemmPipelineAgBgCrImplBase
|
||||
SrcTileWindow& dram_tile_window,
|
||||
const DramTileWindowStep& dram_tile_window_step) const
|
||||
{
|
||||
load_int4_tile<SrcDataType, DstDataType, UnaryOpSize>(dst_block_tile, dram_tile_window);
|
||||
load_and_convert_tile<UnaryOpSize>(dst_block_tile, dram_tile_window);
|
||||
move_tile_window(dram_tile_window, dram_tile_window_step);
|
||||
}
|
||||
|
||||
@@ -124,7 +122,7 @@ struct GemmPipelineAgBgCrImplBase
|
||||
bool_constant<LoadTranspose> = {}) const
|
||||
{
|
||||
if constexpr(LoadTranspose)
|
||||
dst_block_tile = load_tile_transpose(lds_tile_window);
|
||||
load_tile_transpose(dst_block_tile, lds_tile_window);
|
||||
else
|
||||
load_tile(dst_block_tile, lds_tile_window);
|
||||
}
|
||||
@@ -241,12 +239,6 @@ struct GemmPipelineAgBgCrImplBase
|
||||
CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view,
|
||||
const ALdsLoadTileDistr&) const
|
||||
{
|
||||
// with pk_int4_t load transpose the LDS type is always BDataType
|
||||
using ADataTypeLDS =
|
||||
std::conditional_t<std::is_same_v<typename Problem::ADataType, pk_int4_t>,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::ADataType>;
|
||||
|
||||
auto a_lds_shape = []() {
|
||||
if constexpr(is_a_load_tr)
|
||||
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
|
||||
@@ -258,11 +250,16 @@ struct GemmPipelineAgBgCrImplBase
|
||||
|
||||
auto a_lds_load_tile_distr = []() {
|
||||
if constexpr(is_a_load_tr)
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<typename ALdsLoadTileDistr::DstrEncode,
|
||||
ADataTypeLDS>::TransposedDstrEncode{});
|
||||
typename InputTileDistributionTraits<
|
||||
typename ALdsLoadTileDistr::DstrEncode,
|
||||
typename ALdsTensorView::DataType>::TransposedDstrEncode{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ALdsLoadTileDistr{};
|
||||
}
|
||||
}();
|
||||
|
||||
auto a_lds_gemm_window =
|
||||
@@ -333,18 +330,18 @@ struct GemmPipelineAgBgCrImplBase
|
||||
|
||||
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
|
||||
|
||||
using BLdsDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
auto b_lds_load_tile_distr = []() {
|
||||
if constexpr(is_b_load_tr)
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<typename BLdsLoadTileDistr::DstrEncode,
|
||||
BLdsDataType>::TransposedDstrEncode{});
|
||||
|
||||
typename InputTileDistributionTraits<
|
||||
typename BLdsLoadTileDistr::DstrEncode,
|
||||
typename BLdsTensorView::DataType>::TransposedDstrEncode{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return BLdsLoadTileDistr{};
|
||||
}
|
||||
}();
|
||||
|
||||
auto b_lds_gemm_window =
|
||||
|
||||
@@ -127,7 +127,6 @@ struct UniversalGemmBasePolicy
|
||||
using ADataType = OverrideADataType;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = Derived::template GetSmemPackA<Problem>();
|
||||
|
||||
if constexpr(is_a_load_tr<Problem>)
|
||||
{
|
||||
@@ -261,6 +260,7 @@ struct UniversalGemmBasePolicy
|
||||
}
|
||||
else // A is in RowMajor
|
||||
{
|
||||
constexpr index_t KPack = Derived::template GetSmemPackA<Problem>();
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
constexpr uint64_t MinLdsLayer = 1ULL;
|
||||
constexpr auto MLdsLayer =
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
|
||||
|
||||
@@ -627,8 +627,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
// // Prefetch A0
|
||||
Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
|
||||
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// Prefill A0
|
||||
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
|
||||
@@ -652,7 +651,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
do
|
||||
{
|
||||
{
|
||||
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
|
||||
Base::GlobalPrefetch(
|
||||
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
|
||||
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
|
||||
Base::GlobalPrefetch(
|
||||
@@ -666,7 +665,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
HotLoopScheduler();
|
||||
}
|
||||
{
|
||||
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
|
||||
Base::GlobalPrefetch(
|
||||
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
|
||||
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
|
||||
Base::GlobalPrefetch(
|
||||
@@ -687,7 +686,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
{
|
||||
Base::template GlobalPrefetch<BDataType, BTypeToUse, UnaryOpSize_>(
|
||||
Base::GlobalPrefetch(
|
||||
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
|
||||
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
|
||||
block_weight_preshuffle(
|
||||
|
||||
@@ -34,7 +34,7 @@
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -238,7 +238,7 @@ struct BlockGemmWeightPreshuffleABQuantARegBRegCReg : public BlockGemmQuantBase
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
|
||||
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize>(
|
||||
load_and_convert_tile<UnaryOpSize>(
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
a_warp_windows(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
@@ -268,10 +268,8 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
// If A/B datatype were pkint4/pkfp4 it would be converted prior to storing in LDS
|
||||
load_int4_tile<OverrideADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
|
||||
a_warp_tile_, a_block_window);
|
||||
load_int4_tile<OverrideBDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
|
||||
b_warp_tile_, b_block_window);
|
||||
load_and_convert_tile<UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
|
||||
load_and_convert_tile<UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
|
||||
@@ -248,10 +248,8 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
// while ADatatype might not be the same as BDataType at the time of problem
|
||||
// initialization, we can safely use BDataType here because when A would be int4 we will
|
||||
// ensure A is converted to BDataType prior to loading
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
|
||||
a_warp_tile_, a_block_window);
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
|
||||
b_warp_tile_, b_block_window);
|
||||
load_and_convert_tile<UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
|
||||
load_and_convert_tile<UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
@@ -395,10 +393,8 @@ struct AQuantBlockUniversalGemmAsBsCr
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
|
||||
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
|
||||
a_warp_tile_, a_lds_gemm_window);
|
||||
load_int4_tile<BDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
|
||||
b_warp_tile_, b_lds_gemm_window);
|
||||
load_and_convert_tile<UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_lds_gemm_window);
|
||||
load_and_convert_tile<UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_lds_gemm_window);
|
||||
}
|
||||
|
||||
// C += A * B with quantization support
|
||||
|
||||
@@ -239,11 +239,9 @@ struct BQuantBlockUniversalGemmAsBsCr
|
||||
bool_constant<ALoadTranspose> = {},
|
||||
bool_constant<BLoadTranspose> = {})
|
||||
{
|
||||
load_int4_tile<ADataType, ComputeDataType, UnaryOpSize_, ALoadTranspose>(
|
||||
a_warp_tile_, a_block_window);
|
||||
load_and_convert_tile<UnaryOpSize_, ALoadTranspose>(a_warp_tile_, a_block_window);
|
||||
// If B datatype were pkint4 it would be converted prior to storing in LDS
|
||||
load_int4_tile<OverrideBDataType, ComputeDataType, UnaryOpSize_, BLoadTranspose>(
|
||||
b_warp_tile_, b_block_window);
|
||||
load_and_convert_tile<UnaryOpSize_, BLoadTranspose>(b_warp_tile_, b_block_window);
|
||||
}
|
||||
|
||||
// Load from LDS and scale (then the tile can directly be consumed in the block gemm)
|
||||
|
||||
@@ -202,20 +202,16 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
|
||||
CK_TILE_DEVICE static void LoadAndConvertATile(ABlockTile_& a_block_tile,
|
||||
const ADramWindow& a_dram_window)
|
||||
{
|
||||
using DestDataType = typename ABlockTile_::DataType;
|
||||
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
|
||||
load_and_convert_tile<UnaryOpSize>(a_block_tile, a_dram_window);
|
||||
}
|
||||
|
||||
template <typename BDramWindow, typename BBlockTile_>
|
||||
CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile,
|
||||
const BDramWindow& b_dram_window)
|
||||
{
|
||||
using DestDataType = typename BBlockTile_::DataType;
|
||||
using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType;
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
|
||||
load_and_convert_tile<UnaryOpSize>(b_block_tile, b_dram_window);
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
|
||||
@@ -178,10 +178,8 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
ADramWindow& a_dram_window,
|
||||
const DramTileWindowStep& dram_tile_window_step)
|
||||
{
|
||||
using DestDataType = typename ABlockTile_::DataType;
|
||||
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
|
||||
load_and_convert_tile<UnaryOpSize>(a_block_tile, a_dram_window);
|
||||
move_tile_window(a_dram_window, dram_tile_window_step);
|
||||
}
|
||||
|
||||
|
||||
@@ -174,10 +174,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
ADramWindow& a_dram_window,
|
||||
const DramTileWindowStep& dram_tile_window_step)
|
||||
{
|
||||
using DestDataType = typename ABlockTile_::DataType;
|
||||
using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(a_block_tile, a_dram_window);
|
||||
load_and_convert_tile<UnaryOpSize>(a_block_tile, a_dram_window);
|
||||
move_tile_window(a_dram_window, dram_tile_window_step);
|
||||
}
|
||||
|
||||
|
||||
@@ -185,10 +185,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
|
||||
CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile,
|
||||
const BDramWindow& b_dram_window)
|
||||
{
|
||||
using DestDataType = typename BBlockTile_::DataType;
|
||||
using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType;
|
||||
constexpr index_t UnaryOpSize = 8;
|
||||
load_int4_tile<SrcDataType, DestDataType, UnaryOpSize>(b_block_tile, b_dram_window);
|
||||
load_and_convert_tile<UnaryOpSize>(b_block_tile, b_dram_window);
|
||||
}
|
||||
|
||||
template <typename BBlockTile_, typename BDramWindow, typename BDramTileWindowStep>
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp"
|
||||
@@ -373,8 +373,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * flatNPerWarp, kIter * flatKPerWarp});
|
||||
|
||||
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
|
||||
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
|
||||
b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
// move B window to next flat K
|
||||
@@ -413,8 +413,6 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
block_sync_lds();
|
||||
|
||||
// preload A00,A10 from lds
|
||||
using ATypeToUse =
|
||||
mixed_prec_compute_type_from_input_t<ADataType, BDataType, ComputeDataType>;
|
||||
using ATileType =
|
||||
decltype(make_static_distributed_tensor<BTypeToUse>(a_warp_tile_distribution));
|
||||
statically_indexed_array<ATileType, m_preload> a_warp_tensor;
|
||||
@@ -422,7 +420,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
|
||||
load_and_convert_tile<UnaryOpSize_>(
|
||||
a_warp_tensor(loadIter), a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -456,8 +454,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * flatNPerWarp, kIter * flatKPerWarp});
|
||||
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
|
||||
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
|
||||
b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
@@ -468,7 +466,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
|
||||
load_and_convert_tile<UnaryOpSize_>(
|
||||
a_warp_tensor(loadIter), a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
|
||||
@@ -481,8 +479,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * flatNPerWarp, kIter * flatKPerWarp});
|
||||
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
|
||||
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
|
||||
b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
@@ -511,7 +509,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
|
||||
load_and_convert_tile<UnaryOpSize_>(
|
||||
a_warp_tensor(loadIter), a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
iCounter--;
|
||||
@@ -529,8 +527,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * flatNPerWarp, kIter * flatKPerWarp});
|
||||
|
||||
load_int4_tile<BDataType, BTypeToUse, UnaryOpSize_>(
|
||||
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
|
||||
b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
aq_block_tile_2 = load_tile(aq_copy_dram_window);
|
||||
@@ -551,7 +549,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
load_int4_tile<ADataType, ATypeToUse, UnaryOpSize_>(
|
||||
load_and_convert_tile<UnaryOpSize_>(
|
||||
a_warp_tensor(loadIter), a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
|
||||
|
||||
@@ -344,8 +344,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * flatNPerWarp, kIter * flatKPerWarp});
|
||||
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
|
||||
b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
// move B window to next flat K
|
||||
@@ -430,8 +430,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * flatNPerWarp, kIter * flatKPerWarp});
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
|
||||
b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
@@ -467,8 +467,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * flatNPerWarp, kIter * flatKPerWarp});
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_ping(nIter)(kIter),
|
||||
b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
@@ -525,8 +525,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * flatNPerWarp, kIter * flatKPerWarp});
|
||||
|
||||
load_int4_tile<BDataType, ADataType, UnaryOpSize_>(
|
||||
b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
|
||||
load_and_convert_tile<UnaryOpSize_>(b_warp_tensor_pong(nIter)(kIter),
|
||||
b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
bq_block_tile_2 = load_tile(bq_copy_dram_window);
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
|
||||
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
|
||||
#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
|
||||
#include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "ck_tile/ops/pooling/pipeline/pool_problem.hpp"
|
||||
#include "ck_tile/ops/pooling/pipeline/pool_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp"
|
||||
#include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
|
||||
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp"
|
||||
#include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
|
||||
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
Reference in New Issue
Block a user