mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-28 18:56:59 +00:00
[rocm-libraries] ROCm/rocm-libraries#6089 (commit c876d18)
[CK Tile] Extend type support EightWave pipeline ## Motivation The EightWave pipeline was designed for 8-bit types. This PR extends support to any FP type. ## Technical Details - Generalize the policy to support any FP type - Change the LDS layout to fix bank conflicts. This removes all bank conflicts in the pipeline (checked for all supported types). The remaining bank conflicts are related to the CShuffle epilogue. ## Test Plan Added GEMM tests with the newly supported types. Note that FP6 is also supported for MX GEMM, but the PR was reverted, so no tests were added for it. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
054436ca4a
commit
1b4fbd95fd
@@ -62,7 +62,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp
|
||||
static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(I1);
|
||||
static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(I2);
|
||||
|
||||
static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock;
|
||||
static constexpr index_t kflatKPerWarp = BlockGemmShape::flatKPerWarp;
|
||||
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * WarpGemm::kM);
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * WarpGemm::kN);
|
||||
@@ -170,9 +170,9 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp
|
||||
static_assert((MPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
KPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I1]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(Preshuffle //
|
||||
static_assert(Preshuffle
|
||||
? (NWarps == BsDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
kflatKPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1])
|
||||
kflatKPerWarp == BsDramBlockWindowTmp{}.get_window_lengths()[I1])
|
||||
: (NPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
KPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
@@ -21,6 +21,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
static constexpr index_t kDramLoadPackBytes = 128;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
@@ -33,14 +35,17 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
"ALayout must be RowMajor!");
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>,
|
||||
"BLayout must be ColumnMajor!");
|
||||
static_assert(is_any_of<AComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
|
||||
static_assert(is_any_of<BComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
|
||||
static_assert(
|
||||
is_any_of<AComputeDataType, fp8_t, bf8_t, pk_fp4_t, pk_fp6x16_t, fp16_t, bf16_t>::value);
|
||||
static_assert(
|
||||
is_any_of<BComputeDataType, fp8_t, bf8_t, pk_fp4_t, pk_fp6x16_t, fp16_t, bf16_t>::value);
|
||||
static_assert(std::is_same_v<AComputeDataType, BComputeDataType>);
|
||||
static_assert(std::is_same_v<CDataType, float>);
|
||||
|
||||
static constexpr auto WGAccess = std::is_same_v<ComputeDataType, fp8_t>
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single;
|
||||
static constexpr auto WGAccess =
|
||||
std::is_same_v<ComputeDataType, fp8_t> || std::is_same_v<ComputeDataType, bf8_t>
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single;
|
||||
static constexpr auto PackedSize = numeric_traits<ComputeDataType>::PackedSize;
|
||||
|
||||
using BlockGemmShape = typename Problem::BlockGemmShape;
|
||||
@@ -92,10 +97,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
static constexpr index_t KWarps = BlockWarps::at(I2);
|
||||
static constexpr index_t MIterPerWarp = MWarpTiles / MWarps;
|
||||
static constexpr index_t NIterPerWarp = NWarpTiles / NWarps;
|
||||
static constexpr index_t KPerWarp = KPerBlock / KWarps;
|
||||
static constexpr index_t NPerWarp = NPerBlock / NWarps;
|
||||
static_assert(NWarps == 2, "NWarps == 2 for ping-pong!");
|
||||
static_assert(KWarpTiles == KWarps, "Wrong!");
|
||||
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t warp_num = BlockSize / warp_size;
|
||||
@@ -103,27 +106,35 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
static_assert(warp_num * warp_size == BlockSize, "Wrong!");
|
||||
|
||||
static_assert(sizeof(ADataType) == sizeof(BDataType), "Wrong!");
|
||||
static constexpr index_t ElementSize = sizeof(ADataType);
|
||||
static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize * PackedSize; // 16
|
||||
static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8
|
||||
static constexpr index_t K0 = KPerWarp / (K1 * K2);
|
||||
static_assert(K0 * K1 * K2 == KPerWarp, "Wrong!");
|
||||
static_assert(K0 == 1, "Wrong!");
|
||||
static_assert(sizeof(ADataType) == sizeof(ComputeDataType));
|
||||
static constexpr index_t ElementSize = sizeof(ComputeDataType);
|
||||
static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize * PackedSize;
|
||||
// We define kDramLoadPackElems as 128 for fp6 because K2 == 16, so in this way we have correct
|
||||
// values for K1 (number of contiguous lanes in a row)
|
||||
static constexpr index_t kDramLoadPackElems =
|
||||
std::is_same_v<ComputeDataType, pk_fp6x16_t>
|
||||
? kDramLoadPackBytes
|
||||
: kDramLoadPackBytes / sizeof(ComputeDataType) * PackedSize;
|
||||
static constexpr index_t PacksPerLdsRow = std::min(kDramLoadPackElems, KPerBlock) / K2;
|
||||
static constexpr index_t K1 = PacksPerLdsRow;
|
||||
static constexpr index_t K0 = KPerBlock / (K1 * K2);
|
||||
static_assert(K0 * K1 * K2 == KPerBlock, "Wrong!");
|
||||
|
||||
static constexpr index_t SwizzleFactor = WarpTileK / static_cast<index_t>(WGAccess) / K2;
|
||||
|
||||
CK_TILE_DEVICE static constexpr bool IsPreshuffle() { return Preshuffle; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
constexpr index_t M2 = warp_size / K1; // 8
|
||||
constexpr index_t M1 = warp_num; // 8
|
||||
constexpr index_t M2 = warp_size / K1;
|
||||
constexpr index_t M1 = warp_num;
|
||||
constexpr index_t M0 = MPerBlock / M1 / M2;
|
||||
static_assert(M0 * M1 * M2 == MPerBlock, "wrong!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
ck_tile::tile_distribution_encoding<
|
||||
ck_tile::sequence<>,
|
||||
ck_tile::tuple<ck_tile::sequence<M0, M1, M2>, // [123] 8 8
|
||||
ck_tile::sequence<K0, K1, K2>>, // 1 8 16
|
||||
ck_tile::tuple<ck_tile::sequence<M0, M1, M2>, ck_tile::sequence<K0, K1, K2>>,
|
||||
ck_tile::tuple<ck_tile::sequence<1>, ck_tile::sequence<1, 2>>, // M0 M2,K1
|
||||
ck_tile::tuple<ck_tile::sequence<1>, ck_tile::sequence<2, 1>>,
|
||||
ck_tile::sequence<1, 2, 2>, // M0,K0,K2
|
||||
@@ -134,36 +145,33 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
{
|
||||
if constexpr(Preshuffle)
|
||||
{
|
||||
constexpr index_t K1_ = warp_size; // 64
|
||||
constexpr index_t K0_ = KPerBlock * WarpTileN / K1_ / K2; // 2
|
||||
constexpr index_t K1_ = warp_size;
|
||||
constexpr index_t K0_ = KPerBlock * WarpTileN / K1_ / K2;
|
||||
static_assert(K0_ * K1_ * K2 == KPerBlock * WarpTileN, "wrong!");
|
||||
|
||||
constexpr index_t N1 = warp_num / NWarps / K0_; // 2
|
||||
constexpr index_t N0 = NPerBlock / WarpTileN / N1 / NWarps; // 4
|
||||
constexpr index_t N1 = warp_num / NWarps / K0_;
|
||||
constexpr index_t N0 = NPerBlock / WarpTileN / N1 / NWarps;
|
||||
static_assert(NWarps * N0 * N1 == NPerBlock / WarpTileN, "wrong!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<>,
|
||||
tuple<sequence<NWarps, N0, N1>, // 2 [4] 2
|
||||
sequence<K0_, K1_, K2>>, // 2 64 16
|
||||
tuple<sequence<1, 1, 2>, sequence<2>>, // NWarps,N1,K0 K1
|
||||
tuple<sequence<0, 2, 0>, sequence<1>>,
|
||||
sequence<1, 2>, // N0,K2
|
||||
sequence<1, 2>>{});
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<NWarps, N0, N1>, sequence<K0_, K1_, K2>>,
|
||||
tuple<sequence<1, 1, 2>, sequence<2>>, // NWarps,N1,K0 K1
|
||||
tuple<sequence<0, 2, 0>, sequence<1>>,
|
||||
sequence<1, 2>, // N0,K2
|
||||
sequence<1, 2>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t N2 = warp_size / K1; // 8
|
||||
constexpr index_t N1 = warp_num / NWarps; // 4
|
||||
constexpr index_t N0 = NPerBlock / N1 / N2 / NWarps; // 4
|
||||
constexpr index_t N2 = warp_size / K1;
|
||||
constexpr index_t N1 = warp_num / NWarps;
|
||||
constexpr index_t N0 = NPerBlock / N1 / N2 / NWarps;
|
||||
static_assert(NWarps * N0 * N1 * N2 == NPerBlock, "wrong!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NWarps, N0, N1, N2>, // 2 [4] 4 8
|
||||
sequence<K0, K1, K2>>, // 1 8 16
|
||||
tuple<sequence<NWarps, N0, N1, N2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1, 1>, sequence<1, 2>>, // NWarps,N1 N2,K1
|
||||
tuple<sequence<0, 2>, sequence<3, 1>>,
|
||||
sequence<1, 2, 2>, // N0,K0,K2
|
||||
@@ -182,7 +190,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
const index_t k_tiles = cols / (KWarps * K1 * K2);
|
||||
const auto col_lens = make_tuple(k_tiles, number<KWarps>{}, number<K1>{}, number<K2>{});
|
||||
|
||||
constexpr index_t M1 = warp_size / static_cast<index_t>(WGAccess) / K1; // 4
|
||||
constexpr index_t M1 = SwizzleFactor;
|
||||
const index_t M0 = integer_divide_ceil(rows, M1);
|
||||
const auto row_lens = make_tuple(M0, number<M1>{});
|
||||
|
||||
@@ -233,52 +241,64 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
template <index_t MNPerBlock, index_t warp_groups_>
|
||||
CK_TILE_DEVICE static constexpr auto MakeABLdsBlockDescriptor_()
|
||||
{
|
||||
constexpr index_t M4 = warp_size / static_cast<index_t>(WGAccess) / K1; // 4
|
||||
constexpr index_t M3 = static_cast<index_t>(WGAccess); // 2
|
||||
constexpr index_t M2 = WarpTileM / M4 / M3; // 2
|
||||
constexpr index_t M4 = SwizzleFactor;
|
||||
constexpr index_t M3 = warp_size / (M4 * K1);
|
||||
constexpr index_t M2 = WarpTileM / M4 / M3;
|
||||
constexpr index_t M1 = (warp_num / warp_groups_) / M2;
|
||||
constexpr index_t M0 = MNPerBlock / M1 / M2 / M3 / M4;
|
||||
|
||||
static_assert(M1 * M0 * M2 * M3 * M4 == MNPerBlock, "wrong!");
|
||||
|
||||
constexpr index_t PadSize = 16;
|
||||
constexpr index_t PadSize = SwizzleFactor * K2;
|
||||
|
||||
// Padding is needed between waves writing to LDS for a single mfma tile
|
||||
// Example: instruction 16x128 8 bit
|
||||
// 2 waves read elements from gmem to lds but they are consumed by a single
|
||||
// wave when reading from LDS, so we need padding there but not between
|
||||
// consecutive mfma tiles in LDS
|
||||
constexpr auto desc_0 = make_naive_tensor_descriptor( //
|
||||
number_tuple<M2, KWarps, M1, M0, K0, M3, M4, K1, K2>{},
|
||||
number_tuple<KWarps * M1 * M0 * K0 * M3 * M4 * K1 * K2 + PadSize,
|
||||
M1 * M0 * K0 * M3 * M4 * K1 * K2,
|
||||
M0 * K0 * M3 * M4 * K1 * K2,
|
||||
K0 * M3 * M4 * K1 * K2,
|
||||
M3 * M4 * K1 * K2,
|
||||
M4 * K1 * K2,
|
||||
K1 * K2,
|
||||
K2,
|
||||
1>{},
|
||||
make_tuple(number<M0>{},
|
||||
number<K0>{},
|
||||
number<M1>{},
|
||||
number<M2>{},
|
||||
number<M3>{},
|
||||
number<M4>{},
|
||||
number<K1>{},
|
||||
number<K2>{}),
|
||||
make_tuple(number<K0 * M1*(M2 * (M3 * M4 * K1 * K2) + (M2 - 1) * PadSize)>{},
|
||||
number<M1*(M2 * (M3 * M4 * K1 * K2) + (M2 - 1) * PadSize)>{},
|
||||
number<M2*(M3 * M4 * K1 * K2) + (M2 - 1) * PadSize>{},
|
||||
number<M3 * M4 * K1 * K2 + PadSize>{},
|
||||
number<M4 * K1 * K2>{},
|
||||
number<K1 * K2>{},
|
||||
number<K2>{},
|
||||
number<1>{}),
|
||||
number<K2>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto desc_1 = transform_tensor_descriptor(
|
||||
desc_0,
|
||||
make_tuple(make_pass_through_transform(number<M2>{}),
|
||||
make_pass_through_transform(number<KWarps>{}),
|
||||
make_pass_through_transform(number<M1>{}),
|
||||
make_pass_through_transform(number<M0>{}),
|
||||
make_tuple(make_pass_through_transform(number<M0>{}),
|
||||
make_pass_through_transform(number<K0>{}),
|
||||
make_pass_through_transform(number<M1>{}),
|
||||
make_pass_through_transform(number<M2>{}),
|
||||
make_pass_through_transform(number<M3>{}),
|
||||
make_xor_transform(make_tuple(number<M4>{}, number<K1>{})),
|
||||
make_pass_through_transform(number<K2>{})),
|
||||
container_concat(generate_tuple([](auto i) { return sequence<i>{}; }, number<6>{}),
|
||||
make_tuple(sequence<6, 7>{}),
|
||||
make_tuple(sequence<8>{})),
|
||||
container_concat(generate_tuple([](auto i) { return sequence<i>{}; }, number<6>{}),
|
||||
make_tuple(sequence<6, 7>{}),
|
||||
make_tuple(sequence<8>{})));
|
||||
constexpr auto desc_2 = transform_tensor_descriptor( //
|
||||
container_concat(generate_tuple([](auto i) { return sequence<i>{}; }, number<5>{}),
|
||||
make_tuple(sequence<5, 6>{}),
|
||||
make_tuple(sequence<7>{})),
|
||||
container_concat(generate_tuple([](auto i) { return sequence<i>{}; }, number<5>{}),
|
||||
make_tuple(sequence<5, 6>{}),
|
||||
make_tuple(sequence<7>{})));
|
||||
|
||||
constexpr auto desc_2 = transform_tensor_descriptor(
|
||||
desc_1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(number_tuple<M0, M1, M2, M3, M4>{}),
|
||||
make_merge_transform_v3_division_mod(number_tuple<KWarps, K0, K1, K2>{})),
|
||||
make_tuple(sequence<3, 2, 0, 5, 6>{}, sequence<1, 4, 7, 8>{}),
|
||||
make_merge_transform_v3_division_mod(number_tuple<K0, K1, K2>{})),
|
||||
make_tuple(sequence<0, 2, 3, 4, 5>{}, sequence<1, 6, 7>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return desc_2;
|
||||
}
|
||||
CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
@@ -317,8 +337,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
else
|
||||
{
|
||||
constexpr index_t K1_ = warp_size / WarpTileN; // 4
|
||||
constexpr index_t K0_ = KPerWarp / K1_ / K2; // 2
|
||||
static_assert(K0_ * K1_ * K2 == KPerWarp, "wrong!");
|
||||
constexpr index_t K0_ = KPerBlock / K1_ / K2; // 2
|
||||
static_assert(K0_ * K1_ * K2 == KPerBlock, "wrong!");
|
||||
|
||||
constexpr index_t N2 = warp_size / K1_; // 16
|
||||
constexpr index_t N1 = warp_num / NWarps / K0_; // 2
|
||||
@@ -342,15 +362,17 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr index_t sizeofType =
|
||||
std::is_same_v<AComputeDataType, pk_fp6x16_t> ? 16 : sizeof(AComputeDataType);
|
||||
constexpr index_t desc_size = MakeALdsBlockDescriptor().get_element_space_size();
|
||||
return integer_least_multiple(sizeof(typename Problem::ADataType) * desc_size / PackedSize,
|
||||
16);
|
||||
return integer_least_multiple(sizeofType * desc_size / PackedSize, 16);
|
||||
}
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr index_t sizeofType =
|
||||
std::is_same_v<BComputeDataType, pk_fp6x16_t> ? 16 : sizeof(BComputeDataType);
|
||||
constexpr index_t desc_size = MakeBLdsBlockDescriptor().get_element_space_size();
|
||||
return integer_least_multiple(sizeof(typename Problem::BDataType) * desc_size / PackedSize,
|
||||
16);
|
||||
return integer_least_multiple(sizeofType * desc_size / PackedSize, 16);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -371,8 +393,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
// TODO: Fix for transpose
|
||||
constexpr auto wg_attr_num_access = WGAccess;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::AComputeDataType,
|
||||
typename Problem::BComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
@@ -382,11 +404,12 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::AComputeDataType,
|
||||
typename Problem::BComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegEightWavesV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
@@ -63,8 +63,11 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase<
|
||||
DstBlockTile& dst_block_tile,
|
||||
SrcTileWindow& lds_tile_window) const
|
||||
{
|
||||
// swizzle factor limitation
|
||||
using static_move_ys =
|
||||
std::conditional_t<std::is_same_v<DataType, pk_fp6x16_t>, false_type, true_type>;
|
||||
lds_tile_window.set_bottom_tensor_view_data_ptr(smem);
|
||||
lds_tile_window.load(dst_block_tile, number<-1>{}, true_type{}, true_type{});
|
||||
lds_tile_window.load(dst_block_tile, number<-1>{}, true_type{}, static_move_ys{});
|
||||
}
|
||||
|
||||
template <typename DataType, typename DstBlockTile, typename SrcTileWindow>
|
||||
@@ -72,6 +75,9 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase<
|
||||
DstBlockTile& dst_block_tile,
|
||||
SrcTileWindow& lds_tile_window) const
|
||||
{
|
||||
// swizzle factor limitation
|
||||
using static_move_ys =
|
||||
std::conditional_t<std::is_same_v<DataType, pk_fp6x16_t>, false_type, true_type>;
|
||||
lds_tile_window.set_bottom_tensor_view_data_ptr(smem);
|
||||
static_for_product<number<NIterPerWarp>, number<KIterPerWarp>>{}(
|
||||
[&](auto nIter, auto kIter) {
|
||||
@@ -80,7 +86,7 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase<
|
||||
dst_block_tile[nIter][kIter],
|
||||
number<-1>{},
|
||||
true_type{},
|
||||
true_type{});
|
||||
static_move_ys{});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -29,8 +29,8 @@ struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
using ComputeDataType = AComputeDataType;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>, "Wrong!");
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>, "Wrong!");
|
||||
static_assert(is_any_of<AComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
|
||||
static_assert(is_any_of<BComputeDataType, fp8_t, bf8_t, pk_fp4_t>::value);
|
||||
static_assert(is_any_of<AComputeDataType, fp8_t, bf8_t, pk_fp4_t, pk_fp6x16_t>::value);
|
||||
static_assert(is_any_of<BComputeDataType, fp8_t, bf8_t, pk_fp4_t, pk_fp6x16_t>::value);
|
||||
static_assert(std::is_same_v<AComputeDataType, BComputeDataType>);
|
||||
static_assert(std::is_same_v<CDataType, float>);
|
||||
|
||||
@@ -79,7 +79,6 @@ struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
static constexpr index_t KPerWarp = KPerBlock / KWarps;
|
||||
static constexpr index_t NPerWarp = NPerBlock / NWarps;
|
||||
static_assert(NWarps == 2, "NWarps == 2 for ping-pong!");
|
||||
static_assert(KWarpTiles == KWarps, "Wrong!");
|
||||
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t warp_num = BlockSize / warp_size;
|
||||
@@ -92,7 +91,6 @@ struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy
|
||||
static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8
|
||||
static constexpr index_t K0 = KPerWarp / (K1 * K2);
|
||||
static_assert(K0 * K1 * K2 == KPerWarp, "Wrong!");
|
||||
static_assert(K0 == 1, "Wrong!");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKStepAQ() { return KPerBlockScale; }
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKStepBQ() { return KPerBlockScale; }
|
||||
|
||||
@@ -311,20 +311,52 @@ using CompAsyncConfig16x16x128 = std::tuple<ALayout,
|
||||
CompAsync>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
|
||||
using CompAsyncEightWavesConfig = std::tuple<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
InputType, // AType
|
||||
InputType, // BType
|
||||
F32, // AccType
|
||||
F16, // OutputType
|
||||
I192, // MBlockTileSize
|
||||
I256, // NBlockTileSize
|
||||
I128, // KBlockTileSize
|
||||
I16, // MWarpTileSize
|
||||
I16, // NWarpTileSize
|
||||
Intrawave,
|
||||
CompAsyncEightWaves>;
|
||||
using CompAsyncEightWavesConfig4Bit = std::tuple<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
InputType, // AType
|
||||
InputType, // BType
|
||||
F32, // AccType
|
||||
F16, // OutputType
|
||||
I128, // MBlockTileSize
|
||||
I256, // NBlockTileSize
|
||||
I256, // KBlockTileSize
|
||||
I16, // MWarpTileSize
|
||||
I16, // NWarpTileSize
|
||||
Intrawave,
|
||||
CompAsyncEightWaves>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
|
||||
using CompAsyncEightWavesConfig8Bit = std::tuple<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
InputType, // AType
|
||||
InputType, // BType
|
||||
F32, // AccType
|
||||
F16, // OutputType
|
||||
I128, // MBlockTileSize
|
||||
I256, // NBlockTileSize
|
||||
I128, // KBlockTileSize
|
||||
I16, // MWarpTileSize
|
||||
I16, // NWarpTileSize
|
||||
Intrawave,
|
||||
CompAsyncEightWaves>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout, typename InputType>
|
||||
using CompAsyncEightWavesConfig16Bit = std::tuple<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
InputType, // AType
|
||||
InputType, // BType
|
||||
F32, // AccType
|
||||
F16, // OutputType
|
||||
I192, // MBlockTileSize
|
||||
I256, // NBlockTileSize
|
||||
I64, // KBlockTileSize
|
||||
I16, // MWarpTileSize
|
||||
I16, // NWarpTileSize
|
||||
Intrawave,
|
||||
CompAsyncEightWaves>;
|
||||
|
||||
using KernelTypesCompAsync = ::testing::Types<CompAsyncConfig<Row, Row, Row, F16>,
|
||||
CompAsyncConfig<Row, Col, Row, F16>,
|
||||
@@ -339,7 +371,11 @@ using KernelTypesCompAsync16x16x128 = ::testing::Types<CompAsyncConfig16x16x128<
|
||||
CompAsyncConfig16x16x128<Row, Col, Row, F8>>;
|
||||
|
||||
using KernelTypesCompAsyncEightWaves =
|
||||
::testing::Types<CompAsyncEightWavesConfig<Row, Col, Row, F8>>;
|
||||
::testing::Types<CompAsyncEightWavesConfig8Bit<Row, Col, Row, F8>,
|
||||
CompAsyncEightWavesConfig8Bit<Row, Col, Row, BF8>,
|
||||
CompAsyncEightWavesConfig4Bit<Row, Col, Row, F4>,
|
||||
CompAsyncEightWavesConfig16Bit<Row, Col, Row, F16>,
|
||||
CompAsyncEightWavesConfig16Bit<Row, Col, Row, BF16>>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypesCompV6 = ::testing::Types<
|
||||
|
||||
@@ -63,8 +63,11 @@ constexpr ck_tile::index_t get_k_warp_tile()
|
||||
return 16;
|
||||
#endif
|
||||
#else
|
||||
if constexpr(PipelineType == GemmPipelineType::CompAsyncEightWaves)
|
||||
if constexpr(PipelineType == GemmPipelineType::CompAsyncEightWaves && sizeof(PrecType) == 1)
|
||||
return 128;
|
||||
else if constexpr(PipelineType == GemmPipelineType::CompAsyncEightWaves &&
|
||||
sizeof(PrecType) == 2)
|
||||
return 32;
|
||||
// CompAsyncConfig16x16x128
|
||||
else if constexpr(PipelineType == GemmPipelineType::CompAsync && M_Warp_Tile == 16)
|
||||
return 128;
|
||||
|
||||
Reference in New Issue
Block a user