remove some unnecessary hacky; enable 256x256x256 tilesize

This commit is contained in:
aska-0096
2025-05-09 07:54:28 +00:00
parent b2efb06315
commit bb043a3202
2 changed files with 42 additions and 40 deletions

View File

@@ -154,10 +154,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value / 2>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value / 2>{};
static constexpr auto AK1Number = Number<AK1Value * 2>{};
static constexpr auto BK1Number = Number<BK1Value * 2>{};
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma = false;
@@ -175,8 +175,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
NPerXdl,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk /
2);
is_scale_mfma>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -295,7 +294,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -308,7 +307,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// pad M, but not K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -327,7 +326,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -339,7 +338,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -384,7 +383,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -397,7 +396,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// pad N, but not K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -416,7 +415,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -430,7 +429,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
@@ -765,7 +764,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// in some cases.
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
constexpr index_t LdsSize = 32 * 4 / (KPerBlock / APackedSize) / sizeof(ADataType);
constexpr index_t LdsSize = 32 * 4 / (KPerBlock * sizeof(ADataType) / APackedSize);
constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
@@ -901,7 +900,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
// NLdsLayer * K0 as logical Bank
constexpr index_t LdsSize = 32 * 4 / (KPerBlock / BPackedSize) / sizeof(BDataType);
constexpr index_t LdsSize = 32 * 4 / (KPerBlock * sizeof(BDataType)/ BPackedSize) ;
constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
@@ -1416,8 +1415,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector * 2,
ABlockTransferDstScalarPerVector_AK1 * 2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
@@ -1447,8 +1446,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector * 2,
BBlockTransferDstScalarPerVector_BK1 * 2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
@@ -1467,14 +1466,14 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()/APackedSize);
CK_PRINT<ck::Number<a_block_desc_ak0_m_ak1.GetElementSpaceSize()>>();
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
sizeof(ADataType) /
APackedSize),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
b_block_desc_bk0_n_bk1.GetElementSpaceSize()/BPackedSize);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
@@ -1912,8 +1911,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector * 2,
ABlockTransferDstScalarPerVector_AK1 * 2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
@@ -1943,8 +1942,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector * 2,
BBlockTransferDstScalarPerVector_BK1 * 2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,