remove some unnecessary hacky; enable 256x256x256 tilesize

This commit is contained in:
aska-0096
2025-05-09 07:54:28 +00:00
parent b2efb06315
commit bb043a3202
2 changed files with 42 additions and 40 deletions

View File

@@ -23,12 +23,15 @@ using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 128;
constexpr ck::index_t KPerBlock = 256;
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
// AB DataType: f4x2_pk_t
// Mathmatically, all numbers are represented as f4.
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
@@ -45,29 +48,29 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
128, // MPerBlock
128, // NPerBlock
256, // BlockSize: Thread block size
256, // MPerBlock
256, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
32, // AK1
32, // BK1
16, // MPerXDL
16, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
8, // MXdlPerWave
8, // NXdlPerWave
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
32, // ABlockTransferSrcScalarPerVector
32, // ABlockTransferDstScalarPerVector_AK1
false, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
32, // BBlockTransferSrcScalarPerVector
32, // BBlockTransferDstScalarPerVector_BK1
false, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle

View File

@@ -154,10 +154,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value / 2>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value / 2>{};
static constexpr auto AK1Number = Number<AK1Value * 2>{};
static constexpr auto BK1Number = Number<BK1Value * 2>{};
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma = false;
@@ -175,8 +175,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
NPerXdl,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk /
2);
is_scale_mfma>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -295,7 +294,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -308,7 +307,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// pad M, but not K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -327,7 +326,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -339,7 +338,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -384,7 +383,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -397,7 +396,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// pad N, but not K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -416,7 +415,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -430,7 +429,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -765,7 +764,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// in some cases.
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
constexpr index_t LdsSize = 32 * 4 / (KPerBlock / APackedSize) / sizeof(ADataType);
constexpr index_t LdsSize = 32 * 4 / (KPerBlock * sizeof(ADataType) / APackedSize);
constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
@@ -901,7 +900,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
// NLdsLayer * K0 as logical Bank
constexpr index_t LdsSize = 32 * 4 / (KPerBlock / BPackedSize) / sizeof(BDataType);
constexpr index_t LdsSize = 32 * 4 / (KPerBlock * sizeof(BDataType)/ BPackedSize) ;
constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
@@ -1416,8 +1415,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector * 2,
ABlockTransferDstScalarPerVector_AK1 * 2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
@@ -1447,8 +1446,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector * 2,
BBlockTransferDstScalarPerVector_BK1 * 2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
@@ -1467,14 +1466,14 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()/APackedSize);
CK_PRINT<ck::Number<a_block_desc_ak0_m_ak1.GetElementSpaceSize()>>();
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
sizeof(ADataType) /
APackedSize),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
b_block_desc_bk0_n_bk1.GetElementSpaceSize()/BPackedSize);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
@@ -1912,8 +1911,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector * 2,
ABlockTransferDstScalarPerVector_AK1 * 2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
@@ -1943,8 +1942,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector * 2,
BBlockTransferDstScalarPerVector_BK1 * 2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,