Use packed_size_v for A/BPackedSize

This commit is contained in:
Ding, Yi
2025-06-03 06:10:07 +00:00
parent 331ccb8ca2
commit 0cbc5e2bdb
9 changed files with 22 additions and 104 deletions

View File

@@ -35,10 +35,8 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
using ComputeTypeB = BDataType;
using AccType = float; // for now only support V_MFMA_SCALE_F32
static constexpr index_t APackedSize =
is_same_v<remove_cvref_t<ComputeTypeA>, f4x2_pk_t> ? 2 : 1;
static constexpr index_t BPackedSize =
is_same_v<remove_cvref_t<ComputeTypeB>, f4x2_pk_t> ? 2 : 1;
static constexpr index_t APackedSize = packed_size_v<ComputeTypeA>;
static constexpr index_t BPackedSize = packed_size_v<ComputeTypeB>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -85,7 +83,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
static constexpr index_t NXdlPack = 2;
static constexpr index_t KXdlPack = 2;
using HotLoopInstList = ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<
using HotLoopInstList = ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< //
BlockSize,
MPerBlock,
NPerBlock,
@@ -101,8 +99,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
MPerXDL,
NPerXDL,
xdlops_gemm.KPerXdlops,
(is_same_v<remove_cvref_t<ComputeTypeA>, f4x2_pk_t> ||
is_same_v<remove_cvref_t<ComputeTypeB>, f4x2_pk_t>)>;
(packed_size_v<ComputeTypeA> > 1 || packed_size_v<ComputeTypeB> > 1)>;
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");

View File

@@ -151,21 +151,8 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle<ALayout,
using Argument = typename GridwiseGemm::Argument;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr index_t APackedSize = packed_size_v<ADataType>;
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
int GetPreShuffleParameters() override { return NPerXDL; }

View File

@@ -151,21 +151,8 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
using Argument = typename GridwiseGemm::Argument;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr index_t APackedSize = packed_size_v<ADataType>;
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
int GetPreShuffleParameters() override { return NPerXDL; }

View File

@@ -182,21 +182,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// TODO: Move this to blockwise pipeline base
// KPack in packed data types for pk A/B
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr index_t APackedSize = packed_size_v<ADataType>;
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,

View File

@@ -182,21 +182,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
// TODO: Move this to blockwise pipeline base
// KPack in packed data types for pk A/B
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr index_t APackedSize = packed_size_v<ADataType>;
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,

View File

@@ -194,21 +194,8 @@ struct GridwiseMoeGemmMX
static constexpr auto NXdlPack = 2;
static constexpr auto KXdlPack = 2;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr index_t APackedSize = packed_size_v<ADataType>;
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
static constexpr bool is_single_rate_mfma = false;
static constexpr auto is_scale_mfma = true;

View File

@@ -199,21 +199,8 @@ struct GridwiseMoeGemmMXBNS
static constexpr auto NXdlPack = 2;
static constexpr auto KXdlPack = 2;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
return 2;
else
return 1;
}();
static constexpr index_t APackedSize = packed_size_v<ADataType>;
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
static constexpr bool is_single_rate_mfma = false;
static constexpr auto is_scale_mfma = true;

View File

@@ -90,8 +90,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
if constexpr((packed_size_v<SrcData>) > 1)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
@@ -100,7 +99,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
static_assert(SrcVectorDim == DstVectorDim,
"Packed data type does not support transpose");
}
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -96,8 +96,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
dst_element_op_(dst_element_op),
gather_offsets_(gather_offsets)
{
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
if constexpr((packed_size_v<SrcData>) > 1)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
@@ -107,7 +106,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
static_assert(SrcVectorDim == DstVectorDim,
"pk_i4_t or f4x2_pk_t does not support transpose");
"Packed data type does not support transpose");
}
}