mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[HotFix] add config and version files to pass on build info (#856)
* experiment with config file * experiment with version.h config * add more info to version.h * minor updates * minor updates * fix case where DTYPE is not used * large amount of files but minor changes * remove white space * minor changes to add more MACROs * fix cmakedefine01 * fix issue with CK internal conflict * fix define and define value * fix clang-format * fix formatting issue * experiment with cmake * clang format v12 to be consistent with miopen * avoid clang-format for config file
This commit is contained in:
@@ -35,8 +35,8 @@ struct BlockwiseSoftmax
|
||||
static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
|
||||
static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);
|
||||
|
||||
using ThreadSliceDesc_M = decltype(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));
|
||||
using ThreadSliceDesc_M = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));
|
||||
|
||||
using ThreadwiseMaxReduce = typename conditional<
|
||||
IgnoreNaN,
|
||||
|
||||
@@ -588,14 +588,18 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
|
||||
LoopSched>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
|
||||
@@ -378,13 +378,16 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}));
|
||||
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
|
||||
// Argument
|
||||
|
||||
@@ -368,14 +368,18 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
|
||||
LoopSched>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
|
||||
@@ -510,12 +510,15 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using A0GridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(A0GridDesc_M_K{}))>;
|
||||
using B0GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(B0GridDesc_N_K{}))>;
|
||||
using B1GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(B1GridDesc_N_K{}))>;
|
||||
using A0GridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(
|
||||
A0GridDesc_M_K{}))>;
|
||||
using B0GridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(
|
||||
B0GridDesc_N_K{}))>;
|
||||
using B1GridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(
|
||||
B1GridDesc_N_K{}))>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
|
||||
@@ -355,14 +355,18 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
|
||||
LoopSched>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
|
||||
@@ -364,11 +364,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
|
||||
// We have to separate mean var descriptor for gemm and layernorm bacause of different grid
|
||||
// layout(different padding)
|
||||
using GemmMeanVarGridDesc_M_NBlock = decltype(
|
||||
MakeMeanVarDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1, 1));
|
||||
using GemmMeanVarGridDesc_M_NBlock =
|
||||
decltype(MakeMeanVarDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1,
|
||||
1));
|
||||
|
||||
using GemmCountGridDesc_M_NBlock = decltype(
|
||||
MakeCountDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1, 1));
|
||||
using GemmCountGridDesc_M_NBlock =
|
||||
decltype(MakeCountDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1,
|
||||
1));
|
||||
|
||||
using LayernormMeanVarGridDesc_M_NBlock =
|
||||
decltype(MakeMeanVarDescriptor_M_N<Sequence<true, true>,
|
||||
|
||||
@@ -337,10 +337,12 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
|
||||
RThreadTransferDstScalarPerVector_MPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
|
||||
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
|
||||
|
||||
@@ -288,14 +288,18 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
PipelineVer>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
|
||||
@@ -248,10 +248,12 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
|
||||
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
|
||||
|
||||
@@ -400,14 +400,18 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
|
||||
LoopSched>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
struct GroupedContractionBlock2ETileMap
|
||||
{
|
||||
|
||||
@@ -422,10 +422,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{}));
|
||||
using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{}));
|
||||
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}));
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}));
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}));
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
|
||||
@@ -381,8 +381,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
MakeAGridDescriptor_AK0_M_AK1<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
|
||||
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
|
||||
|
||||
@@ -320,8 +320,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
MakeAGridDescriptor_AK0_M_AK1<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
|
||||
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N<CLayout>({}, {}))>;
|
||||
|
||||
@@ -446,8 +446,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
|
||||
}
|
||||
|
||||
using AGridDesc_M_K = remove_cvref_t<decltype(
|
||||
MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
|
||||
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
|
||||
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<DELayout>({}, {}))>;
|
||||
using RGridDesc_M = remove_cvref_t<decltype(MakeRGridDescriptor_M<RLayout>({}, {}))>;
|
||||
@@ -507,10 +507,12 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
RThreadTransferDstScalarPerVector_MPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
|
||||
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
|
||||
|
||||
@@ -245,8 +245,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
using AGridDesc_M_K = remove_cvref_t<decltype(
|
||||
MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
|
||||
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
|
||||
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
|
||||
|
||||
@@ -361,8 +361,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// desc for problem definition
|
||||
using AGridDesc_M_K = remove_cvref_t<decltype(
|
||||
MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
|
||||
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
|
||||
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
|
||||
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
|
||||
@@ -412,14 +412,18 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
LoopSched>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
|
||||
@@ -735,12 +735,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// Check vector load/store requirement
|
||||
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
|
||||
? device_arg.a_mz_kz_strides_[1]
|
||||
: device_arg.a_mz_kz_strides_[0];
|
||||
const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
|
||||
? device_arg.b_nz_kz_strides_[1]
|
||||
: device_arg.b_nz_kz_strides_[0];
|
||||
const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
|
||||
? device_arg.a_mz_kz_strides_[1]
|
||||
: device_arg.a_mz_kz_strides_[0];
|
||||
const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
|
||||
? device_arg.b_nz_kz_strides_[1]
|
||||
: device_arg.b_nz_kz_strides_[0];
|
||||
const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
|
||||
? device_arg.b1_nz_kz_strides_[1]
|
||||
: device_arg.b1_nz_kz_strides_[0];
|
||||
|
||||
@@ -272,14 +272,18 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
|
||||
AGridDesc_M_K{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
|
||||
BGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
struct GroupedGemmBlock2ETileMap
|
||||
{
|
||||
|
||||
@@ -617,10 +617,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using AGridDesc_AKB_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(AGridDesc_M_K{}, 1))>;
|
||||
using BGridDesc_BKB_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(BGridDesc_N_K{}, 1))>;
|
||||
using AGridDesc_AKB_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(
|
||||
AGridDesc_M_K{}, 1))>;
|
||||
using BGridDesc_BKB_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(
|
||||
BGridDesc_N_K{}, 1))>;
|
||||
|
||||
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
|
||||
|
||||
@@ -136,8 +136,8 @@ struct GridwiseMultiblockBatchNormForward
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
using ThreadReduceSrcDesc_M_1 = decltype(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
|
||||
using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
|
||||
|
||||
using ThreadwiseWelford1 =
|
||||
ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
|
||||
|
||||
@@ -118,8 +118,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadReduceSrcDesc_M_1 = decltype(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
|
||||
using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
|
||||
@@ -121,8 +121,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal
|
||||
static constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadReduceSrcDesc_M_1 = decltype(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
|
||||
using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
|
||||
@@ -115,8 +115,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
|
||||
|
||||
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
|
||||
using ThreadReduceSrcDesc_M_1 = decltype(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
|
||||
using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
|
||||
using ThreadReduceDstDesc_M =
|
||||
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
|
||||
|
||||
|
||||
@@ -101,8 +101,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
@@ -346,14 +346,18 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
|
||||
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using DefaultBGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
|
||||
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarGridDesc_M_NBlock{}))>;
|
||||
using CountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
|
||||
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(CountGridDesc_M_NBlock{}))>;
|
||||
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock =
|
||||
remove_cvref_t<decltype(MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
|
||||
MeanVarGridDesc_M_NBlock{}))>;
|
||||
using CountGridDescriptor_MBlock_MPerBlock_NBlock =
|
||||
remove_cvref_t<decltype(MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
|
||||
CountGridDesc_M_NBlock{}))>;
|
||||
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
|
||||
using DefaultBlock2ETileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
@@ -102,8 +102,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
|
||||
template <typename ABlockDesc_AK0_M_AK1>
|
||||
__host__ __device__ static constexpr auto
|
||||
@@ -286,8 +286,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}))>;
|
||||
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
|
||||
|
||||
@@ -446,14 +446,17 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
e1_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(E1GridDesc_M_N{}))>;
|
||||
using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
E1GridDesc_M_N{}))>;
|
||||
|
||||
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
|
||||
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>;
|
||||
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 =
|
||||
remove_cvref_t<decltype(MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
|
||||
D0sGridDesc_M_N{}))>;
|
||||
|
||||
using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(D1sGridDesc_M_N{}))>;
|
||||
using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
D1sGridDesc_M_N{}))>;
|
||||
|
||||
using DefaultBlock2E1TileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2E1TileMap(E1GridDesc_M_N{}))>;
|
||||
|
||||
@@ -114,8 +114,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
|
||||
template <typename ABlockDesc_AK0_M_AK1>
|
||||
__host__ __device__ static constexpr auto
|
||||
@@ -368,12 +368,14 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
|
||||
Number<NumD0Tensor>{});
|
||||
}
|
||||
|
||||
using D0sGridPointer = decltype(MakeD0sGridPointer());
|
||||
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
|
||||
MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>;
|
||||
using D0sGridPointer = decltype(MakeD0sGridPointer());
|
||||
using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 =
|
||||
remove_cvref_t<decltype(MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
|
||||
D0sGridDesc_M_N{}))>;
|
||||
|
||||
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>;
|
||||
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
C1GridDesc_M_N{}))>;
|
||||
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>;
|
||||
|
||||
@@ -113,8 +113,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
|
||||
template <typename ABlockDesc_AK0_M_AK1>
|
||||
__host__ __device__ static constexpr auto
|
||||
@@ -300,8 +300,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}))>;
|
||||
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
|
||||
|
||||
@@ -191,8 +191,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
@@ -346,14 +346,17 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}))>;
|
||||
|
||||
using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C0GridDesc_M_N{}))>;
|
||||
using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
C0GridDesc_M_N{}))>;
|
||||
|
||||
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))>;
|
||||
using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
C1GridDesc_M_N{}))>;
|
||||
|
||||
using ReduceGridDescriptor_MBlock_MPerBlock =
|
||||
remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
|
||||
|
||||
@@ -92,8 +92,8 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
@@ -300,8 +300,9 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
|
||||
using DefaultBGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
// Support 2 dimension in the future. Not only M
|
||||
using RGridDescriptor_MBlock_MPerBlock =
|
||||
|
||||
@@ -346,8 +346,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
|
||||
{
|
||||
@@ -565,10 +565,12 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
|
||||
e_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(EGridDesc_M_N{}, 1, 1))>;
|
||||
using DsGridPointer = decltype(MakeDsGridPointer());
|
||||
|
||||
@@ -89,8 +89,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
|
||||
// denorm test fix, required to work around fp16 mfma issue
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
|
||||
@@ -164,8 +164,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
@@ -318,8 +318,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}))>;
|
||||
|
||||
using ReduceGridDescriptor_MBlock_MPerBlock =
|
||||
remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
|
||||
|
||||
@@ -375,10 +375,12 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
|
||||
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(AGridDesc_M_K{}, 1))>;
|
||||
using DefaultBGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(BGridDesc_N_K{}, 1))>;
|
||||
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
|
||||
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
|
||||
using DefaultBlock2ETileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}, 1))>;
|
||||
|
||||
@@ -138,8 +138,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
|
||||
{
|
||||
@@ -308,8 +308,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}))>;
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
|
||||
|
||||
|
||||
@@ -491,8 +491,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
};
|
||||
|
||||
// FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
|
||||
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
|
||||
@@ -173,8 +173,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
@@ -345,8 +345,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}))>;
|
||||
|
||||
using C0GridDescriptor_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeC0GridDescriptor_NBlock_NPerBlock(C0GridDesc_N{}))>;
|
||||
|
||||
@@ -330,8 +330,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
|
||||
return e_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
}
|
||||
|
||||
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
|
||||
using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}))>;
|
||||
|
||||
using DefaultBlock2ETileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
@@ -259,8 +259,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
|
||||
// denorm test fix, required to work around fp16 mfma issue
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
|
||||
@@ -247,8 +247,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
FloatC* p_c_grid;
|
||||
};
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
|
||||
// denorm test fix, required to work around fp16 mfma issue
|
||||
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
|
||||
|
||||
@@ -110,8 +110,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
|
||||
|
||||
struct Argument : public ck::tensor_operation::device::BaseArgument
|
||||
{
|
||||
|
||||
@@ -139,8 +139,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
@@ -315,8 +315,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
|
||||
remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
remove_cvref_t<
|
||||
decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
CGridDesc_M_N{}))>;
|
||||
|
||||
using DefaultBlock2CTileMap =
|
||||
@@ -634,10 +634,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
|
||||
FloatCShuffle, // typename SrcData,
|
||||
FloatC, // typename DstData,
|
||||
decltype(
|
||||
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
|
||||
5, // index_t VectorDim,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
|
||||
|
||||
@@ -142,8 +142,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
|
||||
{
|
||||
@@ -323,13 +323,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
|
||||
}
|
||||
|
||||
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
|
||||
remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
remove_cvref_t<
|
||||
decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
CGridDesc_M_N{}))>;
|
||||
|
||||
using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
|
||||
remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
remove_cvref_t<
|
||||
decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
C0GridDesc_M_N{}))>;
|
||||
|
||||
using DefaultBlock2CTileMap =
|
||||
@@ -654,12 +654,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
|
||||
FloatC, // typename Src0Data,
|
||||
FloatC, // typename Src1Data,
|
||||
FloatC, // typename DstData,
|
||||
decltype(
|
||||
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
|
||||
5, // index_t VectorDim,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
|
||||
|
||||
@@ -151,8 +151,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
using GridwiseGemmPipe = remove_cvref_t<decltype(
|
||||
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
using GridwiseGemmPipe = remove_cvref_t<
|
||||
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
|
||||
{
|
||||
@@ -331,18 +331,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
|
||||
remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
remove_cvref_t<
|
||||
decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
CGridDesc_M_N{}))>;
|
||||
|
||||
using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
|
||||
remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
remove_cvref_t<
|
||||
decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
C0GridDesc_M_N{}))>;
|
||||
|
||||
using C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
|
||||
remove_cvref_t<decltype(
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
remove_cvref_t<
|
||||
decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
C1GridDesc_M_N{}))>;
|
||||
|
||||
using DefaultBlock2CTileMap =
|
||||
@@ -674,14 +674,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
|
||||
FloatC, // typename Src1Data,
|
||||
FloatC, // typename Src2Data,
|
||||
FloatC, // typename DstData,
|
||||
decltype(
|
||||
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(
|
||||
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
|
||||
Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
|
||||
5, // index_t VectorDim,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
|
||||
|
||||
@@ -78,8 +78,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
|
||||
using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<DimSubBlocks * DimThreadSize>{}, Number<RowSubBlocks * RowVectorSize>{})));
|
||||
|
||||
using ThreadwiseWolfordDescReduce = decltype(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<DimSubBlocks * DimThreadSize>{})));
|
||||
using ThreadwiseWolfordDescReduce = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<DimSubBlocks * DimThreadSize>{})));
|
||||
|
||||
using ThreadwiseWelford =
|
||||
ThreadwiseWelford<AccDataType, ThreadwiseWolfordDesc2D, ThreadwiseWolfordDescReduce>;
|
||||
|
||||
@@ -87,9 +87,9 @@ struct GridwiseNormalizationSplitK1st
|
||||
int left_kPerBlock = math::integer_divide_ceil(k, kGridSize);
|
||||
int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1);
|
||||
int kPerThread = kRightmostBlock < K_BlockTileSize
|
||||
? 0
|
||||
: KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
|
||||
int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize;
|
||||
? 0
|
||||
: KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
|
||||
int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize;
|
||||
|
||||
if(kPerBlockTail > 0)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user