fix some compile errors

This commit is contained in:
joye
2025-04-23 05:27:52 -05:00
parent e14a16359f
commit acce2df3bf
5 changed files with 19 additions and 18 deletions

View File

@@ -43,7 +43,7 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con
}
// Param Comb: type_size, block_x & y, warp_x & y, thread_x & y
#define FOREACH_TRANSPOSE_PARAM(F) F(fp16, ck_tile::fp16_t, 16, 32, 16, 32)
#define FOREACH_TRANSPOSE_PARAM(F) F(fp16, ck_tile::fp16_t, 16, 16, 16, 16)
// Macro that defines one static function per line
#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY) \
@@ -61,7 +61,7 @@ float batched_transpose(batched_transpose_trait t,
{
if(t.type == "fp16")
{
return transpose_fn_fp16_16_32_16_32(a, s);
return transpose_fn_fp16_16_16_16_16(a, s);
}
return -1;
}

View File

@@ -19,3 +19,4 @@ add_subdirectory(16_batched_gemm)
add_subdirectory(17_grouped_gemm)
add_subdirectory(18_flatmm)
add_subdirectory(35_batched_transpose)
add_subdirectory(36_transpose)

View File

@@ -31,7 +31,7 @@ struct BatchedTransposeKernel
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
using Type = typename Problem::InputType;
using Type = typename Problem::DataType;
struct BatchedTransposeKargs
{

View File

@@ -13,7 +13,7 @@ struct TransposeTraits
{
static constexpr index_t kLeadDim = kCol;
static constexpr index_t kSecondDim = kRow;
}
};
template <index_t kRow, index_t kCol>
struct TransposeTraits<tensor_layout::gemm::ColumnMajor, kRow, kCol>
@@ -46,13 +46,13 @@ struct TransposePipelineProblem
static constexpr index_t kSecondDimWarps =
TransposeTraits<Layout, kRowWarps_, kColWarps_>::kSecondDim;
static constexpr index_t kLeadDimPerBlock =
TransposeTraits<Layout, kRowPerBlock, kColPerBlock>::kLeadDim;
TransposeTraits<Layout, kRowPerBlock_, kColPerBlock_>::kLeadDim;
static constexpr index_t kSecondDimPerBlock =
TransposeTraits<Layout, kRowPerBlock, kColPerBlock>::kSecondDim;
TransposeTraits<Layout, kRowPerBlock_, kColPerBlock_>::kSecondDim;
static constexpr index_t kLeadDimPerWarp =
TransposeTraits<Layout, kRowPerWarp, kColPerWarp>::kLeadDim;
TransposeTraits<Layout, kRowPerWarp_, kColPerWarp_>::kLeadDim;
static constexpr index_t kSecondDimPerWarp =
TransposeTraits<Layout, kRowPerWarp, kColPerWarp>::kSecondDim;
TransposeTraits<Layout, kRowPerWarp_, kColPerWarp_>::kSecondDim;
};
template <typename Problem_, typename Policy_ = TransposePolicy>
@@ -89,10 +89,10 @@ struct BlockTranspose
static_assert(get_warp_size() == 64, "the warp size is not correct!");
static_assert(kBlockSize == kNumWarpInLeadDim * kNumWarpInSecondDim * get_warp_size(),
"the block size is not correct!");
static_assert(kLeadDimPerWarpInQuadrant * kSecondDimPerWarpInQuadrant * 4 == get_warp_size(),
"the warp size is not correct!");
//static_assert(kLeadDimPerWarpInQuadrant * kSecondDimPerWarpInQuadrant * 4 == get_warp_size(),
// "the warp size is not correct!");
static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize(); }
static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize<Problem>(); }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
@@ -123,17 +123,17 @@ struct BlockTranspose
{0, 0});
auto load_from_lds_window =
make_tile_window(out_lds_block,
make_tile_window(output_lds_block,
make_tuple(number<kSecondDimPerBlock>{}, number<kLeadDimPerBlock>{}),
{0, 0},
Policy::MakeLdsLoadTileDistribution<Problem>());
Policy::template MakeLdsLoadTileDistribution<Problem>());
auto x = load_tile(input_tile_window);
store_tile(copy_to_lds_window, x);
block_sync_lds();
auto y = load_tile_transpose(lds_tile_window);
auto y = load_tile_transpose(load_from_lds_window);
store_tile(output_tile_window, y);
}
};

View File

@@ -51,7 +51,7 @@ struct TransposePolicy
{
return integer_least_multiple(
sizeof(typename Problem::DataType) *
MakeLdsStoreBlockDescriptor<Problem>::get_element_space_size(),
MakeLdsStoreBlockDescriptor<Problem>().get_element_space_size(),
16);
}
@@ -90,7 +90,7 @@ struct TransposePolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor()
{
using Layout = remove_cvref_t<typename Problem::Layout>;
//using Layout = remove_cvref_t<typename Problem::Layout>;
constexpr index_t kLeadDimPerBlock = Problem::kLeadDimPerBlock;
constexpr index_t kSecondDimPerBlock = Problem::kSecondDimPerBlock;
constexpr index_t kVectorSize = 16 / sizeof(typename Problem::DataType);
@@ -123,9 +123,9 @@ struct TransposePolicy
template <typename Problem, typename WarpLevelOuterDistribution_>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadTileDistribution()
{
using Layout = remove_cvref_t<typename Problem::Layout>;
//using Layout = remove_cvref_t<typename Problem::Layout>;
using QuartTransposeTileDistribution =
QuartTransposeTraits<typename Problem::DataType>::TileDistribution;
typename QuartTransposeTraits<typename Problem::DataType>::TileDistribution;
using WarpTransposeTileDistribution =
decltype(detail::make_embed_tile_distribution_encoding(
WarpLevelOuterDistribution_{}, QuartTransposeTileDistribution{}));