Conv:TF32: add more instances - 1 (#2867)

* conv:tf32:add more instances
* add instances of device_grouped_conv_fwd_xdl_f32_comp_instances
* add instances of device_grouped_conv_fwd_xdl_f32_tf32_mem_instances
* add instances of device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances
* remove gnhwc/ngchw/ngcdhw instances
This commit is contained in:
yinglu
2025-09-25 09:27:18 +08:00
committed by GitHub
parent f076f207ce
commit df97a286d5
92 changed files with 4273 additions and 443 deletions

View File

@@ -107,8 +107,11 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
using BComputeDataType =
conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
#else
using AComputeDataType = AComputeDataType_;
using BComputeDataType = BComputeDataType_;
// Element data type is used in LDS and registers. ComputeDataType_ is inside mfma, eg tf32.
using AElementDataType =
conditional_t<is_same_v<AComputeDataType_, ck::tf32_t>, float, AComputeDataType_>;
using BElementDataType =
conditional_t<is_same_v<BComputeDataType_, ck::tf32_t>, float, BComputeDataType_>;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -199,8 +202,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) +
b_block_space_size_aligned * sizeof(BComputeDataType),
return math::max(a_block_space_size_aligned * sizeof(AElementDataType) +
b_block_space_size_aligned * sizeof(BElementDataType),
c_block_size * sizeof(CShuffleDataType));
}
@@ -621,7 +624,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
AsDataType,
Tuple<AComputeDataType>,
Tuple<AElementDataType>,
decltype(as_grid_desc_ak0_m_ak1),
decltype(tie(a_block_desc_ak0_m_ak1)),
AElementwiseOperation,
@@ -649,7 +652,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
BsDataType,
Tuple<BComputeDataType>,
Tuple<BElementDataType>,
decltype(bs_grid_desc_bk0_n_bk1),
decltype(tie(b_block_desc_bk0_n_bk1)),
BElementwiseOperation,
@@ -679,27 +682,28 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
(((is_same<AComputeDataType, half_t>::value ||
is_same<AComputeDataType, bhalf_t>::value) &&
(((is_same<AComputeDataType_, half_t>::value ||
is_same<AComputeDataType_, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<AComputeDataType, f8_t>::value || is_same<AComputeDataType, bf8_t>::value) &&
(is_same<AComputeDataType_, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<AComputeDataType_, f8_t>::value ||
is_same<AComputeDataType_, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
static constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType,
MfmaSelector<AComputeDataType_,
MPerXdl,
NPerXdl,
BComputeDataType,
BComputeDataType_,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
AComputeDataType,
BComputeDataType,
AElementDataType,
BElementDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
@@ -709,8 +713,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
NXdlPerWave,
KPack,
LoopSched,
AComputeDataType,
BComputeDataType>();
AComputeDataType_,
BComputeDataType_>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
@@ -719,10 +723,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<AElementDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
static_cast<BElementDataType*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);

View File

@@ -164,6 +164,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using ElementDataTypeAB = conditional_t<is_same_v<FloatAB, ck::tf32_t>, float, FloatAB>;
__host__ static auto CalculateGridSize(index_t M, index_t N)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
@@ -236,8 +238,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// Argument
struct Argument : public Problem, public tensor_operation::device::BaseArgument
{
__host__ Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
__host__ Argument(const ElementDataTypeAB* p_a_grid_,
const ElementDataTypeAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
@@ -252,8 +254,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
}
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
const ElementDataTypeAB* p_a_grid;
const ElementDataTypeAB* p_b_grid;
FloatC* p_c_grid;
};
@@ -329,7 +331,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto b_block_space_size_aligned =
math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
return (a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(ElementDataTypeAB);
}
template <
@@ -450,8 +453,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
using BlockwiseGemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
FloatABAdjusted,
FloatABAdjusted,
ElementDataTypeAB,
ElementDataTypeAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
@@ -459,7 +462,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
NPerXdl,
MXdlPerWave,
NXdlPerWave,
K1>;
K1,
FloatABAdjusted,
FloatABAdjusted>;
return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
}
@@ -471,8 +476,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
typename AGridDesc_K0_M_K1,
typename BGridDesc_K0_N_K1,
typename CGridDesc_M_N>
__device__ static void Run(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
__device__ static void Run(const ElementDataTypeAB* p_a_grid,
const ElementDataTypeAB* p_b_grid,
FloatC* p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
@@ -533,8 +538,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
Sequence<K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatABAdjusted,
ElementDataTypeAB,
ElementDataTypeAB,
decltype(a_grid_desc_k0_m_k1),
decltype(a_block_desc_k0_m_k1),
ABlockTransferSrcAccessOrder,
@@ -564,8 +569,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
Sequence<K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatABAdjusted,
ElementDataTypeAB,
ElementDataTypeAB,
decltype(b_grid_desc_k0_n_k1),
decltype(b_block_desc_k0_n_k1),
BBlockTransferSrcAccessOrder,
@@ -595,8 +600,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// sanity check
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
FloatABAdjusted,
FloatABAdjusted,
ElementDataTypeAB,
ElementDataTypeAB,
FloatAcc,
decltype(a_block_desc_k0_m_k1),
decltype(b_block_desc_k0_n_k1),
@@ -605,7 +610,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
MXdlPerWave,
NXdlPerWave,
K1,
LoopSched>();
LoopSched,
FloatABAdjusted,
FloatABAdjusted>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
@@ -614,10 +621,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatABAdjusted*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
static_cast<ElementDataTypeAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size_aligned,
static_cast<ElementDataTypeAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_k0_n_k1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);