[rocm-libraries] ROCm/rocm-libraries#6089 (commit c876d18)

[CK Tile] Extend type support EightWave pipeline

## Motivation

The EightWave pipeline was designed for 8-bit types. This PR extends its
support to any FP type.

## Technical Details

- Generalize the policy to support any FP type.
- Change the LDS layout to fix bank conflicts. This removes all bank
conflicts in the pipeline itself (checked for all supported types); the only
remaining bank conflicts are related to the Cshuffle epilogue.

## Test Plan

Added GEMM tests for the newly supported types. Note that FP6 is also
supported for MX GEMM, but the PR was reverted, so no tests were added for
it.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Enrico Degregori
2026-06-05 23:54:40 +00:00
committed by assistant-librarian[bot]
parent 054436ca4a
commit 1b4fbd95fd
6 changed files with 167 additions and 101 deletions

View File

@@ -62,7 +62,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp
static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(I1);
static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(I2);
static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock;
static constexpr index_t kflatKPerWarp = BlockGemmShape::flatKPerWarp;
static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * WarpGemm::kN);
@@ -170,9 +170,9 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp
static_assert((MPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I0] &&
KPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I1]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(Preshuffle //
static_assert(Preshuffle
? (NWarps == BsDramBlockWindowTmp{}.get_window_lengths()[I0] &&
kflatKPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1])
kflatKPerWarp == BsDramBlockWindowTmp{}.get_window_lengths()[I1])
: (NPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I0] &&
KPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1]),
"B block window has incorrect lengths for defined BLayout!");

View File

@@ -21,6 +21,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr index_t kDramLoadPackBytes = 128;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
@@ -33,14 +35,17 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
"ALayout must be RowMajor!");
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
"BLayout must be ColumnMajor!");
static_assert(is_any_of<AComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
static_assert(is_any_of<BComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
static_assert(
is_any_of<AComputeDataType, fp8_t, bf8_t, pk_fp4_t, pk_fp6x16_t, fp16_t, bf16_t>::value);
static_assert(
is_any_of<BComputeDataType, fp8_t, bf8_t, pk_fp4_t, pk_fp6x16_t, fp16_t, bf16_t>::value);
static_assert(std::is_same_v<AComputeDataType, BComputeDataType>);
static_assert(std::is_same_v<CDataType, float>);
static constexpr auto WGAccess = std::is_same_v<ComputeDataType, fp8_t>
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single;
static constexpr auto WGAccess =
std::is_same_v<ComputeDataType, fp8_t> || std::is_same_v<ComputeDataType, bf8_t>
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single;
static constexpr auto PackedSize = numeric_traits<ComputeDataType>::PackedSize;
using BlockGemmShape = typename Problem::BlockGemmShape;
@@ -92,10 +97,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
static constexpr index_t KWarps = BlockWarps::at(I2);
static constexpr index_t MIterPerWarp = MWarpTiles / MWarps;
static constexpr index_t NIterPerWarp = NWarpTiles / NWarps;
static constexpr index_t KPerWarp = KPerBlock / KWarps;
static constexpr index_t NPerWarp = NPerBlock / NWarps;
static_assert(NWarps == 2, "NWarps == 2 for ping-pong!");
static_assert(KWarpTiles == KWarps, "Wrong!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t warp_num = BlockSize / warp_size;
@@ -103,27 +106,35 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
static_assert(warp_num * warp_size == BlockSize, "Wrong!");
static_assert(sizeof(ADataType) == sizeof(BDataType), "Wrong!");
static constexpr index_t ElementSize = sizeof(ADataType);
static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize * PackedSize; // 16
static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8
static constexpr index_t K0 = KPerWarp / (K1 * K2);
static_assert(K0 * K1 * K2 == KPerWarp, "Wrong!");
static_assert(K0 == 1, "Wrong!");
static_assert(sizeof(ADataType) == sizeof(ComputeDataType));
static constexpr index_t ElementSize = sizeof(ComputeDataType);
static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize * PackedSize;
// We define kDramLoadPackElems as 128 for fp6 because K2 == 16, so in this way we have correct
// values for K1 (number of contiguous lanes in a row)
static constexpr index_t kDramLoadPackElems =
std::is_same_v<ComputeDataType, pk_fp6x16_t>
? kDramLoadPackBytes
: kDramLoadPackBytes / sizeof(ComputeDataType) * PackedSize;
static constexpr index_t PacksPerLdsRow = std::min(kDramLoadPackElems, KPerBlock) / K2;
static constexpr index_t K1 = PacksPerLdsRow;
static constexpr index_t K0 = KPerBlock / (K1 * K2);
static_assert(K0 * K1 * K2 == KPerBlock, "Wrong!");
static constexpr index_t SwizzleFactor = WarpTileK / static_cast<index_t>(WGAccess) / K2;
CK_TILE_DEVICE static constexpr bool IsPreshuffle() { return Preshuffle; }
CK_TILE_DEVICE static constexpr auto MakeADramTileDistribution()
{
constexpr index_t M2 = warp_size / K1; // 8
constexpr index_t M1 = warp_num; // 8
constexpr index_t M2 = warp_size / K1;
constexpr index_t M1 = warp_num;
constexpr index_t M0 = MPerBlock / M1 / M2;
static_assert(M0 * M1 * M2 == MPerBlock, "wrong!");
return make_static_tile_distribution(
ck_tile::tile_distribution_encoding<
ck_tile::sequence<>,
ck_tile::tuple<ck_tile::sequence<M0, M1, M2>, // [123] 8 8
ck_tile::sequence<K0, K1, K2>>, // 1 8 16
ck_tile::tuple<ck_tile::sequence<M0, M1, M2>, ck_tile::sequence<K0, K1, K2>>,
ck_tile::tuple<ck_tile::sequence<1>, ck_tile::sequence<1, 2>>, // M0 M2,K1
ck_tile::tuple<ck_tile::sequence<1>, ck_tile::sequence<2, 1>>,
ck_tile::sequence<1, 2, 2>, // M0,K0,K2
@@ -134,36 +145,33 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
{
if constexpr(Preshuffle)
{
constexpr index_t K1_ = warp_size; // 64
constexpr index_t K0_ = KPerBlock * WarpTileN / K1_ / K2; // 2
constexpr index_t K1_ = warp_size;
constexpr index_t K0_ = KPerBlock * WarpTileN / K1_ / K2;
static_assert(K0_ * K1_ * K2 == KPerBlock * WarpTileN, "wrong!");
constexpr index_t N1 = warp_num / NWarps / K0_; // 2
constexpr index_t N0 = NPerBlock / WarpTileN / N1 / NWarps; // 4
constexpr index_t N1 = warp_num / NWarps / K0_;
constexpr index_t N0 = NPerBlock / WarpTileN / N1 / NWarps;
static_assert(NWarps * N0 * N1 == NPerBlock / WarpTileN, "wrong!");
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<>,
tuple<sequence<NWarps, N0, N1>, // 2 [4] 2
sequence<K0_, K1_, K2>>, // 2 64 16
tuple<sequence<1, 1, 2>, sequence<2>>, // NWarps,N1,K0 K1
tuple<sequence<0, 2, 0>, sequence<1>>,
sequence<1, 2>, // N0,K2
sequence<1, 2>>{});
tile_distribution_encoding<sequence<>,
tuple<sequence<NWarps, N0, N1>, sequence<K0_, K1_, K2>>,
tuple<sequence<1, 1, 2>, sequence<2>>, // NWarps,N1,K0 K1
tuple<sequence<0, 2, 0>, sequence<1>>,
sequence<1, 2>, // N0,K2
sequence<1, 2>>{});
}
else
{
constexpr index_t N2 = warp_size / K1; // 8
constexpr index_t N1 = warp_num / NWarps; // 4
constexpr index_t N0 = NPerBlock / N1 / N2 / NWarps; // 4
constexpr index_t N2 = warp_size / K1;
constexpr index_t N1 = warp_num / NWarps;
constexpr index_t N0 = NPerBlock / N1 / N2 / NWarps;
static_assert(NWarps * N0 * N1 * N2 == NPerBlock, "wrong!");
return make_static_tile_distribution(
tile_distribution_encoding< //
tile_distribution_encoding<
sequence<>,
tuple<sequence<NWarps, N0, N1, N2>, // 2 [4] 4 8
sequence<K0, K1, K2>>, // 1 8 16
tuple<sequence<NWarps, N0, N1, N2>, sequence<K0, K1, K2>>,
tuple<sequence<1, 1>, sequence<1, 2>>, // NWarps,N1 N2,K1
tuple<sequence<0, 2>, sequence<3, 1>>,
sequence<1, 2, 2>, // N0,K0,K2
@@ -182,7 +190,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
const index_t k_tiles = cols / (KWarps * K1 * K2);
const auto col_lens = make_tuple(k_tiles, number<KWarps>{}, number<K1>{}, number<K2>{});
constexpr index_t M1 = warp_size / static_cast<index_t>(WGAccess) / K1; // 4
constexpr index_t M1 = SwizzleFactor;
const index_t M0 = integer_divide_ceil(rows, M1);
const auto row_lens = make_tuple(M0, number<M1>{});
@@ -233,52 +241,64 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
template <index_t MNPerBlock, index_t warp_groups_>
CK_TILE_DEVICE static constexpr auto MakeABLdsBlockDescriptor_()
{
constexpr index_t M4 = warp_size / static_cast<index_t>(WGAccess) / K1; // 4
constexpr index_t M3 = static_cast<index_t>(WGAccess); // 2
constexpr index_t M2 = WarpTileM / M4 / M3; // 2
constexpr index_t M4 = SwizzleFactor;
constexpr index_t M3 = warp_size / (M4 * K1);
constexpr index_t M2 = WarpTileM / M4 / M3;
constexpr index_t M1 = (warp_num / warp_groups_) / M2;
constexpr index_t M0 = MNPerBlock / M1 / M2 / M3 / M4;
static_assert(M1 * M0 * M2 * M3 * M4 == MNPerBlock, "wrong!");
constexpr index_t PadSize = 16;
constexpr index_t PadSize = SwizzleFactor * K2;
// Padding is needed between waves writing to LDS for a single mfma tile
// Example: instruction 16x128 8 bit
// 2 waves read elements from gmem to lds but they are consumed by a single
// wave when reading from LDS, so we need padding there but not between
// consecutive mfma tiles in LDS
constexpr auto desc_0 = make_naive_tensor_descriptor( //
number_tuple<M2, KWarps, M1, M0, K0, M3, M4, K1, K2>{},
number_tuple<KWarps * M1 * M0 * K0 * M3 * M4 * K1 * K2 + PadSize,
M1 * M0 * K0 * M3 * M4 * K1 * K2,
M0 * K0 * M3 * M4 * K1 * K2,
K0 * M3 * M4 * K1 * K2,
M3 * M4 * K1 * K2,
M4 * K1 * K2,
K1 * K2,
K2,
1>{},
make_tuple(number<M0>{},
number<K0>{},
number<M1>{},
number<M2>{},
number<M3>{},
number<M4>{},
number<K1>{},
number<K2>{}),
make_tuple(number<K0 * M1*(M2 * (M3 * M4 * K1 * K2) + (M2 - 1) * PadSize)>{},
number<M1*(M2 * (M3 * M4 * K1 * K2) + (M2 - 1) * PadSize)>{},
number<M2*(M3 * M4 * K1 * K2) + (M2 - 1) * PadSize>{},
number<M3 * M4 * K1 * K2 + PadSize>{},
number<M4 * K1 * K2>{},
number<K1 * K2>{},
number<K2>{},
number<1>{}),
number<K2>{},
number<1>{});
constexpr auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(number<M2>{}),
make_pass_through_transform(number<KWarps>{}),
make_pass_through_transform(number<M1>{}),
make_pass_through_transform(number<M0>{}),
make_tuple(make_pass_through_transform(number<M0>{}),
make_pass_through_transform(number<K0>{}),
make_pass_through_transform(number<M1>{}),
make_pass_through_transform(number<M2>{}),
make_pass_through_transform(number<M3>{}),
make_xor_transform(make_tuple(number<M4>{}, number<K1>{})),
make_pass_through_transform(number<K2>{})),
container_concat(generate_tuple([](auto i) { return sequence<i>{}; }, number<6>{}),
make_tuple(sequence<6, 7>{}),
make_tuple(sequence<8>{})),
container_concat(generate_tuple([](auto i) { return sequence<i>{}; }, number<6>{}),
make_tuple(sequence<6, 7>{}),
make_tuple(sequence<8>{})));
constexpr auto desc_2 = transform_tensor_descriptor( //
container_concat(generate_tuple([](auto i) { return sequence<i>{}; }, number<5>{}),
make_tuple(sequence<5, 6>{}),
make_tuple(sequence<7>{})),
container_concat(generate_tuple([](auto i) { return sequence<i>{}; }, number<5>{}),
make_tuple(sequence<5, 6>{}),
make_tuple(sequence<7>{})));
constexpr auto desc_2 = transform_tensor_descriptor(
desc_1,
make_tuple(make_merge_transform_v3_division_mod(number_tuple<M0, M1, M2, M3, M4>{}),
make_merge_transform_v3_division_mod(number_tuple<KWarps, K0, K1, K2>{})),
make_tuple(sequence<3, 2, 0, 5, 6>{}, sequence<1, 4, 7, 8>{}),
make_merge_transform_v3_division_mod(number_tuple<K0, K1, K2>{})),
make_tuple(sequence<0, 2, 3, 4, 5>{}, sequence<1, 6, 7>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return desc_2;
}
CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor()
@@ -317,8 +337,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
else
{
constexpr index_t K1_ = warp_size / WarpTileN; // 4
constexpr index_t K0_ = KPerWarp / K1_ / K2; // 2
static_assert(K0_ * K1_ * K2 == KPerWarp, "wrong!");
constexpr index_t K0_ = KPerBlock / K1_ / K2; // 2
static_assert(K0_ * K1_ * K2 == KPerBlock, "wrong!");
constexpr index_t N2 = warp_size / K1_; // 16
constexpr index_t N1 = warp_num / NWarps / K0_; // 2
@@ -342,15 +362,17 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
CK_TILE_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t sizeofType =
std::is_same_v<AComputeDataType, pk_fp6x16_t> ? 16 : sizeof(AComputeDataType);
constexpr index_t desc_size = MakeALdsBlockDescriptor().get_element_space_size();
return integer_least_multiple(sizeof(typename Problem::ADataType) * desc_size / PackedSize,
16);
return integer_least_multiple(sizeofType * desc_size / PackedSize, 16);
}
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t sizeofType =
std::is_same_v<BComputeDataType, pk_fp6x16_t> ? 16 : sizeof(BComputeDataType);
constexpr index_t desc_size = MakeBLdsBlockDescriptor().get_element_space_size();
return integer_least_multiple(sizeof(typename Problem::BDataType) * desc_size / PackedSize,
16);
return integer_least_multiple(sizeofType * desc_size / PackedSize, 16);
}
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
@@ -371,8 +393,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
// TODO: Fix for transpose
constexpr auto wg_attr_num_access = WGAccess;
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
using WarpGemm = WarpGemmDispatcher<typename Problem::AComputeDataType,
typename Problem::BComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
@@ -382,11 +404,12 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
false,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::AComputeDataType,
typename Problem::BComputeDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegEightWavesV1<Problem, BlockGemmPolicy>{};
}

View File

@@ -63,8 +63,11 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase<
DstBlockTile& dst_block_tile,
SrcTileWindow& lds_tile_window) const
{
// swizzle factor limitation
using static_move_ys =
std::conditional_t<std::is_same_v<DataType, pk_fp6x16_t>, false_type, true_type>;
lds_tile_window.set_bottom_tensor_view_data_ptr(smem);
lds_tile_window.load(dst_block_tile, number<-1>{}, true_type{}, true_type{});
lds_tile_window.load(dst_block_tile, number<-1>{}, true_type{}, static_move_ys{});
}
template <typename DataType, typename DstBlockTile, typename SrcTileWindow>
@@ -72,6 +75,9 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase<
DstBlockTile& dst_block_tile,
SrcTileWindow& lds_tile_window) const
{
// swizzle factor limitation
using static_move_ys =
std::conditional_t<std::is_same_v<DataType, pk_fp6x16_t>, false_type, true_type>;
lds_tile_window.set_bottom_tensor_view_data_ptr(smem);
static_for_product<number<NIterPerWarp>, number<KIterPerWarp>>{}(
[&](auto nIter, auto kIter) {
@@ -80,7 +86,7 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase<
dst_block_tile[nIter][kIter],
number<-1>{},
true_type{},
true_type{});
static_move_ys{});
});
}

View File

@@ -29,8 +29,8 @@ struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy
using ComputeDataType = AComputeDataType;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>, "Wrong!");
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>, "Wrong!");
static_assert(is_any_of<AComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
static_assert(is_any_of<BComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
static_assert(is_any_of<AComputeDataType, fp8_t, bf8_t, pk_fp4_t, pk_fp6x16_t>::value);
static_assert(is_any_of<BComputeDataType, fp8_t, bf8_t, pk_fp4_t, pk_fp6x16_t>::value);
static_assert(std::is_same_v<AComputeDataType, BComputeDataType>);
static_assert(std::is_same_v<CDataType, float>);
@@ -79,7 +79,6 @@ struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy
static constexpr index_t KPerWarp = KPerBlock / KWarps;
static constexpr index_t NPerWarp = NPerBlock / NWarps;
static_assert(NWarps == 2, "NWarps == 2 for ping-pong!");
static_assert(KWarpTiles == KWarps, "Wrong!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t warp_num = BlockSize / warp_size;
@@ -92,7 +91,6 @@ struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy
static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8
static constexpr index_t K0 = KPerWarp / (K1 * K2);
static_assert(K0 * K1 * K2 == KPerWarp, "Wrong!");
static_assert(K0 == 1, "Wrong!");
CK_TILE_HOST_DEVICE static constexpr auto GetKStepAQ() { return KPerBlockScale; }
CK_TILE_HOST_DEVICE static constexpr auto GetKStepBQ() { return KPerBlockScale; }

View File

@@ -311,20 +311,52 @@ using CompAsyncConfig16x16x128 = std::tuple<ALayout,
CompAsync>;
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
using CompAsyncEightWavesConfig = std::tuple<ALayout,
BLayout,
CLayout,
InputType, // AType
InputType, // BType
F32, // AccType
F16, // OutputType
I192, // MBlockTileSize
I256, // NBlockTileSize
I128, // KBlockTileSize
I16, // MWarpTileSize
I16, // NWarpTileSize
Intrawave,
CompAsyncEightWaves>;
using CompAsyncEightWavesConfig4Bit = std::tuple<ALayout,
BLayout,
CLayout,
InputType, // AType
InputType, // BType
F32, // AccType
F16, // OutputType
I128, // MBlockTileSize
I256, // NBlockTileSize
I256, // KBlockTileSize
I16, // MWarpTileSize
I16, // NWarpTileSize
Intrawave,
CompAsyncEightWaves>;
// Test configuration tuple for the CompAsyncEightWaves pipeline; it is
// instantiated with F8 and BF8 in KernelTypesCompAsyncEightWaves.
// A and B both use InputType; the accumulator type is F32 and the output type is F16.
// Block tile (M x N x K): I128 x I256 x I128; warp tile (M x N): I16 x I16.
// NOTE(review): the trailing Intrawave / CompAsyncEightWaves entries are presumably
// the scheduler and pipeline selectors read by the test harness -- confirm.
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
using CompAsyncEightWavesConfig8Bit = std::tuple<ALayout,
BLayout,
CLayout,
InputType, // AType
InputType, // BType
F32, // AccType
F16, // OutputType
I128, // MBlockTileSize
I256, // NBlockTileSize
I128, // KBlockTileSize
I16, // MWarpTileSize
I16, // NWarpTileSize
Intrawave,
CompAsyncEightWaves>;
// Test configuration tuple for the CompAsyncEightWaves pipeline; it is
// instantiated with F16 and BF16 in KernelTypesCompAsyncEightWaves.
// A and B both use InputType; the accumulator type is F32 and the output type is F16.
// Block tile (M x N x K): I192 x I256 x I64; warp tile (M x N): I16 x I16.
// NOTE(review): the trailing Intrawave / CompAsyncEightWaves entries are presumably
// the scheduler and pipeline selectors read by the test harness -- confirm.
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
using CompAsyncEightWavesConfig16Bit = std::tuple<ALayout,
BLayout,
CLayout,
InputType, // AType
InputType, // BType
F32, // AccType
F16, // OutputType
I192, // MBlockTileSize
I256, // NBlockTileSize
I64, // KBlockTileSize
I16, // MWarpTileSize
I16, // NWarpTileSize
Intrawave,
CompAsyncEightWaves>;
using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16>,
CompAsyncConfig<Row, Col, Row, F16>,
@@ -339,7 +371,11 @@ using KernelTypesCompAsync16x16x128 = ::testing::Types<CompAsyncConfig16x16x128<
CompAsyncConfig16x16x128<Row, Col, Row, F8>>;
using KernelTypesCompAsyncEightWaves =
::testing::Types<CompAsyncEightWavesConfig<Row, Col, Row, F8>>;
::testing::Types<CompAsyncEightWavesConfig8Bit<Row, Col, Row, F8>,
CompAsyncEightWavesConfig8Bit<Row, Col, Row, BF8>,
CompAsyncEightWavesConfig4Bit<Row, Col, Row, F4>,
CompAsyncEightWavesConfig16Bit<Row, Col, Row, F16>,
CompAsyncEightWavesConfig16Bit<Row, Col, Row, BF16>>;
// clang-format off
using KernelTypesCompV6 = ::testing::Types<

View File

@@ -63,8 +63,11 @@ constexpr ck_tile::index_t get_k_warp_tile()
return 16;
#endif
#else
if constexpr(PipelineType == GemmPipelineType::CompAsyncEightWaves)
if constexpr(PipelineType == GemmPipelineType::CompAsyncEightWaves && sizeof(PrecType) == 1)
return 128;
else if constexpr(PipelineType == GemmPipelineType::CompAsyncEightWaves &&
sizeof(PrecType) == 2)
return 32;
// CompAsyncConfig16x16x128
else if constexpr(PipelineType == GemmPipelineType::CompAsync && M_Warp_Tile == 16)
return 128;