mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4368 (commit 17f7dfc)
[CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on gfx950 (#4368) ## Motivation Microscaling types (mxfp8 and mxfp4) for fwd qr pipeline ## Technical Details The microscaling is used when quant scale mode is `BlockAttentionQuantScaleEnum::MX` and `Q/K/P/VDataType` are fp8/bf8/fp4. Supported features: * only "qr" pipeline is implemented * hdim 128 and 256 (smaller hdim are not possible due to restrictions of "qr" pipeline, but they can be computed using instances with padding) * both 32x32x64 and 16x16x128 scale MFMAs are supported * Q and K scales are applied in hdim, V scales - in seqlen dimension * column-major V only * batch and group mode * bias, Alibi (tested but no instances by default, just like fp8) * masking etc. Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008 ## Test Plan ``` ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8 ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4 ``` ## Test Result The tests must pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c85c272c39
commit
2312eef6c3
@@ -0,0 +1,374 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Block GEMM with microscaling (MX) quantization scales:
//   A is block distributed tensor (in registers)
//   A scale is block distributed tensor
//   B is block window on shared memory
//   B scale is block distributed tensor
//   C is block distributed tensor (accumulator, in registers)
// It supports only warp gemms with transposed C.
// TargetCMPerLane_ controls how many consecutive elements of matrix C are calculated by each lane.
// The default -1 is clamped up to the warp gemm's natural CMPerLane by the max() below.
template <typename Problem_, typename Policy_, index_t TargetCMPerLane_ = -1>
struct BlockGemmMxARegBSmemCRegV1
{
    using Problem = remove_cvref_t<Problem_>;
    using Policy = remove_cvref_t<Policy_>;
    // Element types of the A/B operands and the C accumulator.
    using ADataType = remove_cvref_t<typename Problem::ADataType>;
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
    using CDataType = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // Per-workgroup tile extents.
    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    // config = (warp gemm instance, MWarp, NWarp), as chosen by the policy.
    static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();

    using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;

    static constexpr index_t MWarp = config.template at<1>();
    static constexpr index_t NWarp = config.template at<2>();

    // Number of warp-gemm iterations each warp performs along M/N/K of the block tile.
    static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
    static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
    static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;

    // Consecutive C elements a lane naturally produces with this warp gemm
    // (product of the two per-lane C sub-dimensions of the MFMA impl).
    static constexpr index_t CMPerLane = WarpGemm::WarpGemmAttribute::Impl::kCM0PerLane *
                                         WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane;
    static constexpr index_t TargetCMPerLane = max(CMPerLane, TargetCMPerLane_);

    static_assert(TargetCMPerLane % CMPerLane == 0);
    // Number of N-iterations packed together so that each lane ends up owning
    // TargetCMPerLane consecutive elements of C.
    static constexpr index_t NIterPack = TargetCMPerLane / CMPerLane;

    // C += A * B
    // a/b scales are per-thread distributed tensors; each warp-gemm call consumes
    // one scale element per operand (see the int32_t casts at the WarpGemm call).
    template <typename CBlockTensor,
              typename ABlockTensorTmp,
              typename AScaleBlockTensorTmp,
              typename BBlockWindowTmp,
              typename BScaleBlockTensorTmp>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ABlockTensorTmp& a_block_tensor_tmp,
                                   const AScaleBlockTensorTmp& a_scale_block_tensor_tmp,
                                   const BBlockWindowTmp& b_block_window_tmp,
                                   const BScaleBlockTensorTmp& b_scale_block_tensor_tmp) const
    {
        // Element types of the inputs must match the Problem's declared types.
        static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
                      std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
                      std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>);

        // Shapes of the supplied tensors/windows must match the block tile.
        static_assert(MPerBlock == ABlockTensorTmp{}.get_lengths()[number<0>{}] &&
                      NPerBlock == BBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      KPerBlock == ABlockTensorTmp{}.get_lengths()[number<1>{}]);

        const index_t iNWarp = get_warp_id() % NWarp;

        // construct A-block-tensor from A-Block-tensor-tmp
        // (re-wraps the caller's thread buffer under this gemm's A distribution)
        auto a_block_tensor = make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(
            MakeABlockTileDistribution());
        a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();

        // Same re-wrap for the A scales.
        auto a_scale_block_tensor =
            make_static_distributed_tensor<remove_cv_t<typename AScaleBlockTensorTmp::DataType>>(
                MakeAScaleBlockTileDistribution());
        a_scale_block_tensor.get_thread_buffer() = a_scale_block_tensor_tmp.get_thread_buffer();

        // Same re-wrap for the B scales.
        auto b_scale_block_tensor =
            make_static_distributed_tensor<remove_cv_t<typename BScaleBlockTensorTmp::DataType>>(
                MakeBScaleBlockTileDistribution());
        b_scale_block_tensor.get_thread_buffer() = b_scale_block_tensor_tmp.get_thread_buffer();

        // Construct B-warp-window
        // Matrix B is shuffled in such a way that each lane calculates TargetCMPerLane consecutive
        // elements of matrix C. See MakeBScaleBlockTileDistribution and MakeCBlockTile that shuffle
        // B scale and C in the same way.
        auto b_warp_window_tmp = [&] {
            using Impl = typename WarpGemm::WarpGemmAttribute::Impl;

            // Decompose N into N0*N1*N2*N3; note the unmerge emits them in order
            // 0,2,1,3 and the merge below recombines as (N0,N2,N1,N3), i.e. N1
            // and N2 are swapped — this is the lane shuffle described above.
            constexpr index_t N3 = Impl::kCM1PerLane;
            constexpr index_t N2 = TargetCMPerLane / N3;
            constexpr index_t N1 = Impl::kCMLane;
            constexpr index_t N0 = NPerBlock / (N1 * N2 * N3);

            const auto b_lds_unmerged = transform_tensor_view(
                b_block_window_tmp.get_bottom_tensor_view(),
                make_tuple(make_unmerge_transform(
                               make_tuple(number<N0>{}, number<N1>{}, number<N2>{}, number<N3>{})),
                           make_pass_through_transform(number<KPerBlock>{})),
                make_tuple(sequence<0>{}, sequence<1>{}),
                make_tuple(sequence<0, 2, 1, 3>{}, sequence<4>{}));

            const auto b_lds_merged = transform_tensor_view(
                b_lds_unmerged,
                make_tuple(make_merge_transform(
                               make_tuple(number<N0>{}, number<N2>{}, number<N1>{}, number<N3>{})),
                           make_pass_through_transform(number<KPerBlock>{})),
                make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));

            // One warp-sized window per warp, offset by this warp's N position.
            return make_tile_window(
                b_lds_merged,
                make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
                b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0},
                make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
        }();

        // check C-block-distribution
        static_assert(
            std::is_same_v<remove_cvref_t<decltype(MakeCBlockTile()
                                                       .get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>,
                           remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>);

        using AWarpDstr = typename WarpGemm::AWarpDstr;
        using CWarpDstr = typename WarpGemm::CWarpDstr;

        using AWarpTensor = typename WarpGemm::AWarpTensor;
        using CWarpTensor = typename WarpGemm::CWarpTensor;

        // Per-warp distributions/tensors for the A and B scale fragments.
        using AScaleWarpDstr =
            remove_cvref_t<decltype(make_static_tile_distribution(MakeAScaleWarpDstrEncoding()))>;
        using AScaleWarpTensor =
            static_distributed_tensor<remove_cv_t<typename AScaleBlockTensorTmp::DataType>,
                                      AScaleWarpDstr>;

        using BScaleWarpDstr =
            remove_cvref_t<decltype(make_static_tile_distribution(MakeBScaleWarpDstrEncoding()))>;
        using BScaleWarpTensor =
            static_distributed_tensor<remove_cv_t<typename BScaleBlockTensorTmp::DataType>,
                                      BScaleWarpDstr>;

        constexpr auto a_warp_y_lengths =
            to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

        constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        constexpr auto a_scale_warp_y_lengths =
            to_sequence(AScaleWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto b_scale_warp_y_lengths =
            to_sequence(BScaleWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

        constexpr auto a_scale_warp_y_index_zeros =
            uniform_sequence_gen_t<AScaleWarpDstr::NDimY, 0>{};
        constexpr auto b_scale_warp_y_index_zeros =
            uniform_sequence_gen_t<BScaleWarpDstr::NDimY, 0>{};

        // hot loop:
        // K outermost, then N, then M; B (and its scale) are loaded once per
        // (k, n) and reused across all M iterations.
        static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                auto b_warp_window = b_warp_window_tmp;
                move_tile_window(
                    b_warp_window,
                    {nIter * (NPerBlock / NIterPerWarp), kIter * (KPerBlock / KIterPerWarp)});
                // read B warp tensor from B Block window
                const auto b_warp_tensor = load_tile(b_warp_window);

                BScaleWarpTensor b_scale_warp_tensor;

                // Slice this (n, k) iteration's B scales; the N index is split
                // into (pack group, in-pack) to mirror the B/C shuffle.
                b_scale_warp_tensor.get_thread_buffer() =
                    b_scale_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<nIter / NIterPack, nIter % NIterPack, kIter>{},
                                        b_scale_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1, 1>{}, b_scale_warp_y_lengths));

                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                    // read A warp tensor from A block tensor
                    AWarpTensor a_warp_tensor;

                    a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                    AScaleWarpTensor a_scale_warp_tensor;

                    a_scale_warp_tensor.get_thread_buffer() =
                        a_scale_block_tensor.get_y_sliced_thread_data(
                            merge_sequences(sequence<mIter, kIter>{}, a_scale_warp_y_index_zeros),
                            merge_sequences(sequence<1, 1>{}, a_scale_warp_y_lengths));

                    // read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter / NIterPack, nIter % NIterPack>{},
                                        c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths));

                    // warp GEMM
                    // Only element [0] of each scale fragment is passed down;
                    // opsel indices are fixed to <0, 0>.
                    WarpGemm{}.template operator()<0, 0>(
                        c_warp_tensor,
                        a_warp_tensor,
                        b_warp_tensor,
                        int32_t(a_scale_warp_tensor.get_thread_buffer()[0]),
                        int32_t(b_scale_warp_tensor.get_thread_buffer()[0]));

                    // write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<mIter, nIter / NIterPack, nIter % NIterPack>{},
                                        c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
    }

    // Block-level distribution for the A operand: (MIter, MWarp) x (KIter)
    // outer encoding embedded with the warp gemm's per-warp A encoding.
    template <index_t MPerBlock_ = MPerBlock, index_t KPerBlock_ = KPerBlock>
    CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution()
    {
        constexpr index_t MIterPerWarp_ = MPerBlock_ / (MWarp * WarpGemm::kM);
        constexpr index_t KIterPerWarp_ = KPerBlock_ / WarpGemm::kK;

        constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<NWarp>,
            tuple<sequence<MIterPerWarp_, MWarp>, sequence<KIterPerWarp_>>,
            tuple<sequence<1, 0>>,
            tuple<sequence<1, 0>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});

        return make_static_tile_distribution(a_block_dstr_encode);
    }

    // Per-warp encoding for A scales: one scale per kScaleGranularity K elements.
    CK_TILE_DEVICE static constexpr auto MakeAScaleWarpDstrEncoding()
    {
        using Impl = typename WarpGemm::WarpGemmAttribute::Impl;

        constexpr index_t AScaleMLane = Impl::kAMLane;
        constexpr index_t ABScaleKLane = Impl::kABKLane;
        constexpr index_t ABScaleKPerLane = Impl::kABKPerLane / Impl::kScaleGranularity;

        return ck_tile::tile_distribution_encoding<
            ck_tile::sequence<>,
            ck_tile::tuple<ck_tile::sequence<AScaleMLane>,
                           ck_tile::sequence<ABScaleKLane, ABScaleKPerLane>>,
            ck_tile::tuple<ck_tile::sequence<2, 1>>,
            ck_tile::tuple<ck_tile::sequence<0, 0>>,
            ck_tile::sequence<2>,
            ck_tile::sequence<1>>{};
    }

    // Per-warp encoding for B scales; mirrors the A-scale encoding with the
    // N lane count in place of the M lane count.
    CK_TILE_DEVICE static constexpr auto MakeBScaleWarpDstrEncoding()
    {
        using Impl = typename WarpGemm::WarpGemmAttribute::Impl;

        constexpr index_t BScaleNLane = Impl::kBNLane;
        constexpr index_t ABScaleKLane = Impl::kABKLane;
        constexpr index_t ABScaleKPerLane = Impl::kABKPerLane / Impl::kScaleGranularity;

        return ck_tile::tile_distribution_encoding<
            ck_tile::sequence<>,
            ck_tile::tuple<ck_tile::sequence<BScaleNLane>,
                           ck_tile::sequence<ABScaleKLane, ABScaleKPerLane>>,
            ck_tile::tuple<ck_tile::sequence<2, 1>>,
            ck_tile::tuple<ck_tile::sequence<0, 0>>,
            ck_tile::sequence<2>,
            ck_tile::sequence<1>>{};
    }

    // Block-level distribution for A scales: same outer shape as A, embedded
    // with the A-scale warp encoding.
    template <index_t MPerBlock_ = MPerBlock, index_t KPerBlock_ = KPerBlock>
    CK_TILE_DEVICE static constexpr auto MakeAScaleBlockTileDistribution()
    {
        constexpr index_t MIterPerWarp_ = MPerBlock_ / (MWarp * WarpGemm::kM);
        constexpr index_t KIterPerWarp_ = KPerBlock_ / WarpGemm::kK;

        constexpr auto a_scale_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<NWarp>,
            tuple<sequence<MIterPerWarp_, MWarp>, sequence<KIterPerWarp_>>,
            tuple<sequence<1, 0>>,
            tuple<sequence<1, 0>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto a_scale_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            a_scale_block_outer_dstr_encoding, MakeAScaleWarpDstrEncoding());

        return make_static_tile_distribution(a_scale_block_dstr_encode);
    }

    // Block-level distribution for B scales. The N dimension is decomposed as
    // (NIter/NIterPack, NWarp, kCMLane, NIterPack, kCM0PerLane, kCM1PerLane) to
    // match the B/C shuffle, so b_scale can be sliced per (pack, in-pack, k).
    template <index_t NPerBlock_ = NPerBlock, index_t KPerBlock_ = KPerBlock>
    CK_TILE_DEVICE static constexpr auto MakeBScaleBlockTileDistribution()
    {
        constexpr index_t NIterPerWarp_ = NPerBlock_ / (NWarp * WarpGemm::kN);
        constexpr index_t KIterPerWarp_ = KPerBlock_ / WarpGemm::kK;

        using Impl = typename WarpGemm::WarpGemmAttribute::Impl;

        constexpr index_t ABScaleKLane = Impl::kABKLane;
        constexpr index_t ABScaleKPerLane = Impl::kABKPerLane / Impl::kScaleGranularity;

        constexpr auto b_scale_block_dstr_encode = ck_tile::tile_distribution_encoding<
            ck_tile::sequence<MWarp>,
            ck_tile::tuple<ck_tile::sequence<NIterPerWarp_ / NIterPack,
                                             NWarp,
                                             Impl::kCMLane,
                                             NIterPack,
                                             Impl::kCM0PerLane,
                                             Impl::kCM1PerLane>,
                           ck_tile::sequence<KIterPerWarp_, ABScaleKLane, ABScaleKPerLane>>,
            ck_tile::tuple<ck_tile::sequence<0, 1>, ck_tile::sequence<2, 1, 1, 1>>,
            ck_tile::tuple<ck_tile::sequence<0, 1>, ck_tile::sequence<1, 4, 2, 5>>,
            ck_tile::sequence<1, 1, 2, 2>,
            ck_tile::sequence<0, 3, 0, 2>>{};

        return make_static_tile_distribution(b_scale_block_dstr_encode);
    }

    // Zero-constructed C accumulator tile with the packed-N distribution
    // matching the shuffled B layout (see operator() above).
    CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
    {
        using Impl = typename WarpGemm::WarpGemmAttribute::Impl;

        constexpr auto c_block_dstr_encode = ck_tile::tile_distribution_encoding<
            ck_tile::sequence<>,
            ck_tile::tuple<ck_tile::sequence<MIterPerWarp, MWarp, Impl::kCNLane>,
                           ck_tile::sequence<NIterPerWarp / NIterPack,
                                             NWarp,
                                             Impl::kCMLane,
                                             NIterPack,
                                             Impl::kCM0PerLane,
                                             Impl::kCM1PerLane>>,
            ck_tile::tuple<ck_tile::sequence<1, 2>, ck_tile::sequence<2, 1>>,
            ck_tile::tuple<ck_tile::sequence<1, 1>, ck_tile::sequence<2, 2>>,
            ck_tile::sequence<1, 2, 2, 2, 2>,
            ck_tile::sequence<0, 0, 3, 4, 5>>{};

        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
        return c_block_tensor;
    }

    // C = A * B
    // Convenience overload: builds a fresh C tile and accumulates into it.
    template <typename ABlockTensorTmp,
              typename AScaleBlockTensorTmp,
              typename BBlockWindowTmp,
              typename BScaleBlockTensorTmp>
    CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
                                   const AScaleBlockTensorTmp& a_scale_block_tensor_tmp,
                                   const BBlockWindowTmp& b_block_window_tmp,
                                   const BScaleBlockTensorTmp& b_scale_block_tensor_tmp) const
    {
        auto c_block_tensor = MakeCBlockTile();
        operator()(c_block_tensor,
                   a_block_tensor_tmp,
                   a_scale_block_tensor_tmp,
                   b_block_window_tmp,
                   b_scale_block_tensor_tmp);
        return c_block_tensor;
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,36 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Custom policy for BlockGemmMxARegBSmemCRegV1: fixes the warp gemm instance
// and the warp layout (MWarps x NWarps x KWarps) at compile time.
template <typename AType_,
          typename BType_,
          typename CType_,
          typename BlockWarps_,
          typename WarpGemm_>
struct BlockGemmMxARegBSmemCRegV1CustomPolicy
{
    // Operand/accumulator element types (held for introspection by users of the policy).
    using AType = remove_cvref_t<AType_>;
    using BType = remove_cvref_t<BType_>;
    using CType = remove_cvref_t<CType_>;

    // Sequence of warp counts per block dimension: (M, N, K).
    using BlockWarps = remove_cvref_t<BlockWarps_>;

    static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
    static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
    static constexpr index_t kKWarps = BlockWarps::at(number<2>{});

    using WarpGemm = remove_cvref_t<WarpGemm_>;

    // Returns (warp gemm instance, MWarps, NWarps) — the tuple shape expected
    // by the block gemm's `config`. Note kKWarps is not part of the tuple.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        return make_tuple(WarpGemm{}, kMWarps, kNWarps);
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -407,6 +407,12 @@ using WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed =
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, bf8_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
// mxfp4 warp gemm (packed fp4 A and B), 16x16x128 MFMA, transposed-C distribution.
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp4_fp4_CTransposed =
    WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<pk_fp4_t, pk_fp4_t>,
        AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
@@ -427,6 +433,36 @@ using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
// 32x32x64 warp gemms with transposed-C distribution for all fp8/bf8 operand
// combinations, plus the packed-fp4 (mxfp4) variant built on the f8f6f4 impl.
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_fp8_CTransposed =
    WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8<WGAttrCtlEnum::Default_>,
        AttrNumAccess>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp8_bf8_CTransposed =
    WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8<WGAttrCtlEnum::Default_>,
        AttrNumAccess>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_bf8_fp8_CTransposed =
    WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8<WGAttrCtlEnum::Default_>,
        AttrNumAccess>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_bf8_bf8_CTransposed =
    WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8<WGAttrCtlEnum::Default_>,
        AttrNumAccess>>;

template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x64_fp4_fp4_CTransposed =
    WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4<pk_fp4_t, pk_fp4_t>,
        AttrNumAccess>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
@@ -446,6 +446,19 @@ struct WarpGemmAttributeMfmaTransposedCDistribution
|
||||
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// Scaled (microscaling) accumulate: c_vec += scale(a_vec) * scale(b_vec).
// The transposed-C distribution is realized by swapping the A and B operands
// (and, correspondingly, their scales and opsel indices) before delegating to Impl.
template <index_t opselA, index_t opselB, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
                               const AVecType& a_vec,
                               const int32_t& a_scale,
                               const BVecType& b_vec,
                               const int32_t& b_scale,
                               bool_constant<post_nop_> = {}) const
{
    // swap A and B
    Impl{}.template operator()<opselB, opselA>(
        c_vec, b_vec, b_scale, a_vec, a_scale, bool_constant<post_nop_>{});
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
@@ -540,6 +553,19 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB
|
||||
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// Scaled (microscaling) accumulate for the SwizzleB variant:
// c_vec += scale(a_vec) * scale(b_vec). As in the non-swizzled transposed-C
// attribute, A and B (with their scales and opsel indices) are swapped before
// delegating to Impl.
template <index_t opselA, index_t opselB, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
                               const AVecType& a_vec,
                               const int32_t& a_scale,
                               const BVecType& b_vec,
                               const int32_t& b_scale,
                               bool_constant<post_nop_> = {}) const
{
    // swap A and B
    Impl{}.template operator()<opselB, opselA>(
        c_vec, b_vec, b_scale, a_vec, a_scale, bool_constant<post_nop_>{});
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
|
||||
@@ -1599,6 +1599,8 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
static constexpr index_t kScaleGranularity = 32;
|
||||
|
||||
// To get unity scale: 2^(kDefaultScale - 127) = 1.0
|
||||
static constexpr index_t kDefaultScale = 0x7F7F7F7F;
|
||||
|
||||
@@ -1683,15 +1685,15 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4
|
||||
};
|
||||
|
||||
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base
|
||||
struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = AType_;
|
||||
using BDataType = BType_;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 32>;
|
||||
using BVecType = ext_vector_t<BDataType, 32>;
|
||||
using AVecType = ext_vector_t<ADataType, 32 / numeric_traits<ADataType>::PackedSize>;
|
||||
using BVecType = ext_vector_t<BDataType, 32 / numeric_traits<BDataType>::PackedSize>;
|
||||
using CVecType = ext_vector_t<CDataType, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
@@ -1711,6 +1713,71 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
static constexpr index_t kScaleGranularity = 32;
|
||||
|
||||
// c_vec += a_vec * b_vec, with per-operand microscaling applied by the
// hardware MFMA. Only implemented on gfx950; on other targets all arguments
// are ignored and c_vec is left untouched.
// opselA/opselB select which byte of a_scale/b_scale the instruction consumes
// (forwarded directly to the builtin — see the gfx950 ISA for exact semantics).
template <index_t opselA, index_t opselB, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
                               const AVecType& a_vec,
                               const int32_t& a_scale,
                               const BVecType& b_vec,
                               const int32_t& b_scale,
                               bool_constant<post_nop_> = {}) const
{
#if defined(__gfx950__)
    // Maps an operand element type to its (format code, register-view vector
    // type) pair: the format code becomes the builtin's cbsz/blgp operand.
    auto dtype2conf = [](auto dtype) {
        if constexpr(std::is_same_v<decltype(dtype), fp8_t>)
            return make_tuple(number<0>{}, int32x8_t{});
        else if constexpr(std::is_same_v<decltype(dtype), bf8_t>)
            return make_tuple(number<1>{}, int32x8_t{});
        else if constexpr(std::is_same_v<decltype(dtype), pk_fp6x16_t>)
            return make_tuple(number<2>{}, pk_fp6x32_t{});
        // else if e3m2 => make_tuple(number<3>{}, int32x6_t{})
        else if constexpr(std::is_same_v<decltype(dtype), pk_fp4_t>)
            return make_tuple(number<4>{}, int32x4_t{});
        else
            static_assert(false, "Unsupported data type for mfma scale");
    };
    auto dtype2code = [&](auto dtype) { return dtype2conf(dtype)(number<0>{}); };
    auto dtype2vec = [&](auto dtype) { return dtype2conf(dtype)(number<1>{}); };
    // The builtin takes 256-bit (8-dword) operands; zero-pad narrower
    // register views (fp4: 4 dwords, fp6: 6 dwords) up to 8 dwords.
    auto arg256 = [&](auto x) {
        if constexpr(sizeof(x) == 16)
            return int32x8_t{x[0], x[1], x[2], x[3], 0, 0, 0, 0};
        else if constexpr(sizeof(x) == 24)
            return int32x8_t{x[0], x[1], x[2], x[3], x[4], x[5], 0, 0};
        else if constexpr(sizeof(x) == 32)
            return x;
        else
            static_assert(false, "Unexpected vector size for mfma scale");
    };

    auto arg_a = bit_cast<decltype(dtype2vec(ADataType{}))>(a_vec);
    auto arg_b = bit_cast<decltype(dtype2vec(BDataType{}))>(b_vec);
    constexpr int cbsz = decltype(dtype2code(ADataType{}))::value;
    constexpr int blgp = decltype(dtype2code(BDataType{}))::value;
    c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
        arg256(arg_a), arg256(arg_b), c_vec, cbsz, blgp, opselA, a_scale, opselB, b_scale);
#else
    // Not compiled for gfx950: silence unused-parameter warnings.
    ck_tile::ignore = c_vec;
    ck_tile::ignore = a_vec;
    ck_tile::ignore = b_vec;
    ck_tile::ignore = a_scale;
    ck_tile::ignore = b_scale;
#endif
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
template <index_t opselA, index_t opselB>
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
|
||||
const int32_t& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale) const
|
||||
{
|
||||
CVecType c_vec{0.f};
|
||||
operator()<opselA, opselB>(c_vec, a_vec, a_scale, b_vec, b_scale);
|
||||
return c_vec;
|
||||
}
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
@@ -1718,67 +1785,31 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
//__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
|
||||
// opsel, scale_b)
|
||||
#if defined(__gfx950__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
operator()<0, 0>(c_vec, a_vec, 0, b_vec, 0);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
return operator()<0, 0>(a_vec, 0, b_vec, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<fp8_t, fp8_t, Ctrl_>;
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4<fp8_t, fp8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<fp8_t, bf8_t, Ctrl_>;
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4<fp8_t, bf8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<bf8_t, fp8_t, Ctrl_>;
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4<bf8_t, fp8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base<bf8_t, bf8_t, Ctrl_>;
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4<bf8_t, bf8_t, Ctrl_>;
|
||||
|
||||
// int8
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
|
||||
@@ -130,6 +130,8 @@ template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 1
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, fp8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed<I>; };
|
||||
|
||||
// mxfp4 (packed fp4) dispatch: 16x16x128, transposed C only.
template<WGAttrNumAccessEnum I> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp4_fp4_CTransposed<I>; };
|
||||
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; };
|
||||
@@ -143,6 +145,13 @@ template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false, false, fal
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<EQuad>; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<EQuad>; };
|
||||
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8_CTransposed<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8_CTransposed<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8_CTransposed<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8_CTransposed<I>; };
|
||||
|
||||
// mxfp4 (packed fp4) dispatch: 32x32x64, transposed C only.
template<WGAttrNumAccessEnum I> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 32, 32, 64, true, false, false, I> { using Type = WarpGemmMfma_f32_32x32x64_fp4_fp4_CTransposed<I>; };
|
||||
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<>; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<EDouble>; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<>; };
|
||||
@@ -152,7 +161,6 @@ template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, true> { using Ty
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8<>; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8<EDouble>; };
|
||||
|
||||
|
||||
//WMMA cases
|
||||
template<bool TransposeC> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f8_f8<TransposeC>; };
|
||||
template<bool TransposeC> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_bf8_bf8<TransposeC>; };
|
||||
|
||||
Reference in New Issue
Block a user