mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Use packed_size_v for A/BPackedSize
This commit is contained in:
@@ -35,10 +35,8 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
using ComputeTypeB = BDataType;
|
||||
using AccType = float; // for now only support V_MFMA_SCALE_F32
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
is_same_v<remove_cvref_t<ComputeTypeA>, f4x2_pk_t> ? 2 : 1;
|
||||
static constexpr index_t BPackedSize =
|
||||
is_same_v<remove_cvref_t<ComputeTypeB>, f4x2_pk_t> ? 2 : 1;
|
||||
static constexpr index_t APackedSize = packed_size_v<ComputeTypeA>;
|
||||
static constexpr index_t BPackedSize = packed_size_v<ComputeTypeB>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -85,7 +83,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
static constexpr index_t NXdlPack = 2;
|
||||
static constexpr index_t KXdlPack = 2;
|
||||
|
||||
using HotLoopInstList = ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<
|
||||
using HotLoopInstList = ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< //
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
@@ -101,8 +99,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
xdlops_gemm.KPerXdlops,
|
||||
(is_same_v<remove_cvref_t<ComputeTypeA>, f4x2_pk_t> ||
|
||||
is_same_v<remove_cvref_t<ComputeTypeB>, f4x2_pk_t>)>;
|
||||
(packed_size_v<ComputeTypeA> > 1 || packed_size_v<ComputeTypeB> > 1)>;
|
||||
|
||||
static_assert(KPerThread % KPack == 0,
|
||||
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
|
||||
|
||||
@@ -151,21 +151,8 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
static constexpr index_t APackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
static constexpr index_t APackedSize = packed_size_v<ADataType>;
|
||||
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
|
||||
|
||||
@@ -151,21 +151,8 @@ struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
static constexpr index_t APackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
static constexpr index_t APackedSize = packed_size_v<ADataType>;
|
||||
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
|
||||
|
||||
@@ -182,21 +182,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// TODO: Move this to blockwise pipeline base
|
||||
// KPack in packed data types for pk A/B
|
||||
|
||||
static constexpr index_t APackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
static constexpr index_t APackedSize = packed_size_v<ADataType>;
|
||||
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
|
||||
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
|
||||
@@ -182,21 +182,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
|
||||
// TODO: Move this to blockwise pipeline base
|
||||
// KPack in packed data types for pk A/B
|
||||
|
||||
static constexpr index_t APackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
static constexpr index_t APackedSize = packed_size_v<ADataType>;
|
||||
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
|
||||
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
|
||||
@@ -194,21 +194,8 @@ struct GridwiseMoeGemmMX
|
||||
static constexpr auto NXdlPack = 2;
|
||||
static constexpr auto KXdlPack = 2;
|
||||
|
||||
static constexpr index_t APackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
static constexpr index_t APackedSize = packed_size_v<ADataType>;
|
||||
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
|
||||
|
||||
static constexpr bool is_single_rate_mfma = false;
|
||||
static constexpr auto is_scale_mfma = true;
|
||||
|
||||
@@ -199,21 +199,8 @@ struct GridwiseMoeGemmMXBNS
|
||||
static constexpr auto NXdlPack = 2;
|
||||
static constexpr auto KXdlPack = 2;
|
||||
|
||||
static constexpr index_t APackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<ADataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<BDataType>, f4x2_pk_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
static constexpr index_t APackedSize = packed_size_v<ADataType>;
|
||||
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
|
||||
|
||||
static constexpr bool is_single_rate_mfma = false;
|
||||
static constexpr auto is_scale_mfma = true;
|
||||
|
||||
@@ -90,8 +90,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
src_element_op_(src_element_op),
|
||||
dst_element_op_(dst_element_op)
|
||||
{
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
|
||||
if constexpr((packed_size_v<SrcData>) > 1)
|
||||
{
|
||||
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
|
||||
"SrcData != DstData");
|
||||
@@ -100,7 +99,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
|
||||
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
|
||||
|
||||
static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
|
||||
static_assert(SrcVectorDim == DstVectorDim,
|
||||
"Packed data type does not support transpose");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -96,8 +96,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
dst_element_op_(dst_element_op),
|
||||
gather_offsets_(gather_offsets)
|
||||
{
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
|
||||
if constexpr((packed_size_v<SrcData>) > 1)
|
||||
{
|
||||
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
|
||||
"SrcData != DstData");
|
||||
@@ -107,7 +106,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
|
||||
|
||||
static_assert(SrcVectorDim == DstVectorDim,
|
||||
"pk_i4_t or f4x2_pk_t does not support transpose");
|
||||
"Packed data type does not support transpose");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user