mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
remove some unnecessary hacky; enable 256x256x256 tilesize
This commit is contained in:
@@ -154,10 +154,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto AK0Number = Number<KPerBlock / AK1Value / 2>{};
|
||||
static constexpr auto BK0Number = Number<KPerBlock / BK1Value / 2>{};
|
||||
static constexpr auto AK1Number = Number<AK1Value * 2>{};
|
||||
static constexpr auto BK1Number = Number<BK1Value * 2>{};
|
||||
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
|
||||
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
|
||||
static constexpr auto AK1Number = Number<AK1Value>{};
|
||||
static constexpr auto BK1Number = Number<BK1Value>{};
|
||||
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma = false;
|
||||
@@ -175,8 +175,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
NPerXdl,
|
||||
ComputeTypeB,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk /
|
||||
2);
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
@@ -295,7 +294,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
|
||||
make_pass_through_transform(MPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -308,7 +307,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// pad M, but not K
|
||||
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
|
||||
a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
|
||||
make_right_pad_transform(M, MPad - M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -327,7 +326,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -339,7 +338,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// not pad M or K
|
||||
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
|
||||
a_grid_desc_mraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0/2, AK1Value*2)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -384,7 +383,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
|
||||
make_pass_through_transform(NPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -397,7 +396,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// pad N, but not K
|
||||
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
|
||||
make_right_pad_transform(N, NPad - N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -416,7 +415,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -430,7 +429,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// not pad N or K
|
||||
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
b_grid_desc_nraw_kraw,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0/2, BK1Value*2)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
@@ -765,7 +764,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// in some cases.
|
||||
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
constexpr index_t LdsSize = 32 * 4 / (KPerBlock / APackedSize) / sizeof(ADataType);
|
||||
constexpr index_t LdsSize = 32 * 4 / (KPerBlock * sizeof(ADataType) / APackedSize);
|
||||
constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
@@ -901,7 +900,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
// NLdsLayer * K0 as logical Bank
|
||||
constexpr index_t LdsSize = 32 * 4 / (KPerBlock / BPackedSize) / sizeof(BDataType);
|
||||
constexpr index_t LdsSize = 32 * 4 / (KPerBlock * sizeof(BDataType)/ BPackedSize) ;
|
||||
constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
|
||||
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
@@ -1416,8 +1415,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector * 2,
|
||||
ABlockTransferDstScalarPerVector_AK1 * 2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
@@ -1447,8 +1446,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
Sequence<0, 1, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector * 2,
|
||||
BBlockTransferDstScalarPerVector_BK1 * 2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
@@ -1467,14 +1466,14 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
|
||||
// Cast after lds
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()/APackedSize);
|
||||
CK_PRINT<ck::Number<a_block_desc_ak0_m_ak1.GetElementSpaceSize()>>();
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
|
||||
sizeof(ADataType) /
|
||||
APackedSize),
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize()/BPackedSize);
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
|
||||
@@ -1912,8 +1911,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector * 2,
|
||||
ABlockTransferDstScalarPerVector_AK1 * 2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
@@ -1943,8 +1942,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
Sequence<0, 1, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector * 2,
|
||||
BBlockTransferDstScalarPerVector_BK1 * 2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
|
||||
Reference in New Issue
Block a user