mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Mx fp6 flatmm (#3601)
* add fp6 data-type and support sync/async dwordx3 load/store * clang-format * pre-commit * 1st commit * default mnk pass ut * fix a distribution * fix * fix bdram distr * update * pass ut * improve perf * update * clean code * resolve copilot comment * resolve comment * clang-format --------- Co-authored-by: ZheWang <zhewan@amd.com>
This commit is contained in:
@@ -118,8 +118,9 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
static constexpr index_t KFlatBytesPerBlockPerIter = flatKPerWarp / BPackedSize;
|
||||
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
static constexpr index_t KFlatBytesPerBlockPerIter =
|
||||
flatKPerWarp * sizeof(BDataType) / BPackedSize;
|
||||
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
|
||||
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
@@ -132,8 +133,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
static constexpr index_t KXdlPack = Problem::KXdlPack;
|
||||
static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;
|
||||
|
||||
static constexpr index_t AK1 = 16 /*dwordx4*/ * APackedSize / sizeof(ADataType);
|
||||
static constexpr index_t BK1 = 16 /*dwordx4*/ * BPackedSize / sizeof(BDataType);
|
||||
static constexpr index_t AK1 = std::is_same_v<ADataType, pk_fp6x16_t>
|
||||
? 16
|
||||
: 16 /*dwordx4*/ * APackedSize / sizeof(ADataType);
|
||||
static constexpr index_t BK1 = std::is_same_v<BDataType, pk_fp6x16_t>
|
||||
? 16
|
||||
: 16 /*dwordx4*/ * BPackedSize / sizeof(BDataType);
|
||||
|
||||
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
|
||||
? DsReadPreload
|
||||
@@ -537,24 +542,26 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
|
||||
auto a_store_lds_window_ping = make_tile_window( //
|
||||
a_lds_block_ping,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / APackedSize>{}),
|
||||
make_tuple(number<kMPerBlock>{},
|
||||
number<kKPerBlock / APackedSize * sizeof(ADataType)>{}),
|
||||
{0, 0});
|
||||
auto a_store_lds_window_pong = make_tile_window( //
|
||||
a_lds_block_pong,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / APackedSize>{}),
|
||||
make_tuple(number<kMPerBlock>{},
|
||||
number<kKPerBlock / APackedSize * sizeof(ADataType)>{}),
|
||||
{0, 0});
|
||||
|
||||
// ping-pong window for A LDS
|
||||
auto a_warp_window_ping =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK / APackedSize>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution<Problem>());
|
||||
auto a_warp_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK / APackedSize>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution<Problem>());
|
||||
auto a_warp_window_ping = make_tile_window(
|
||||
a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK / APackedSize * sizeof(ADataType)>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution<Problem>());
|
||||
auto a_warp_window_pong = make_tile_window(
|
||||
a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK / APackedSize * sizeof(ADataType)>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMX_ALDSBytes_TileDistribution<Problem>());
|
||||
|
||||
// B flat DRAM window for load
|
||||
|
||||
@@ -621,7 +628,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
// HEAD
|
||||
// Prefetch A0
|
||||
async_load_tile_(a_store_lds_window_ping, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock / APackedSize});
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock * sizeof(ADataType) / APackedSize});
|
||||
|
||||
// prefetch B
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
@@ -663,7 +670,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
if constexpr(HasHotLoop || TailNum == TailNumber::Even)
|
||||
{
|
||||
async_load_tile_(a_store_lds_window_pong, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock / APackedSize});
|
||||
move_tile_window(a_dram_window, {0, sizeof(ADataType) * kKPerBlock / APackedSize});
|
||||
}
|
||||
// initialize C
|
||||
statically_indexed_array<statically_indexed_array<CWarpTensor, NIterPerWarp>, MIterPerWarp>
|
||||
@@ -683,7 +690,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_ping,
|
||||
tuple<number<mIter * WG::kM>, number<kIter * WG::kK / APackedSize>>{});
|
||||
tuple<number<mIter * WG::kM>,
|
||||
number<kIter * WG::kK * sizeof(ADataType) / APackedSize>>{});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -750,7 +758,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
a_warp_tensor(number<APackIter>{}) = load_tile_with_offset( //
|
||||
a_warp_window_ping,
|
||||
tuple<number<AmIter * WG::kM>,
|
||||
number<AkIter * WG::kK / APackedSize>>{});
|
||||
number<sizeof(ADataType) * AkIter * WG::kK / APackedSize>>{});
|
||||
}
|
||||
});
|
||||
// barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
|
||||
@@ -760,7 +768,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
async_load_tile_(a_store_lds_window_ping, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock / APackedSize});
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock * sizeof(ADataType) / APackedSize});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
@@ -772,7 +780,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_pong,
|
||||
tuple<number<mIter * WG::kM>, number<kIter * WG::kK / APackedSize>>{});
|
||||
tuple<number<mIter * WG::kM>,
|
||||
number<kIter * WG::kK * sizeof(ADataType) / APackedSize>>{});
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
@@ -839,7 +848,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
a_warp_tensor(number<APackIter>{}) = load_tile_with_offset( //
|
||||
a_warp_window_pong,
|
||||
tuple<number<AmIter * WG::kM>,
|
||||
number<AkIter * WG::kK / APackedSize>>{});
|
||||
number<sizeof(ADataType) * AkIter * WG::kK / APackedSize>>{});
|
||||
}
|
||||
});
|
||||
// barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished
|
||||
@@ -849,7 +858,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
|
||||
// Prefetch A(2i+3)
|
||||
async_load_tile_(a_store_lds_window_pong, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock / APackedSize});
|
||||
move_tile_window(a_dram_window, {0, sizeof(ADataType) * kKPerBlock / APackedSize});
|
||||
// move B window to next flat K
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
@@ -860,7 +869,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_ping,
|
||||
tuple<number<mIter * WG::kM>, number<kIter * WG::kK / APackedSize>>{});
|
||||
tuple<number<mIter * WG::kM>,
|
||||
number<kIter * WG::kK * sizeof(ADataType) / APackedSize>>{});
|
||||
});
|
||||
HotLoopScheduler();
|
||||
};
|
||||
@@ -874,7 +884,6 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
iCounter--;
|
||||
} while(iCounter > 0);
|
||||
}
|
||||
|
||||
// TAIL
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
@@ -933,7 +942,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
a_warp_tensor(number<APackIter>{}) = load_tile_with_offset( //
|
||||
a_warp_window_ping,
|
||||
tuple<number<AmIter * WG::kM>,
|
||||
number<AkIter * WG::kK / APackedSize>>{});
|
||||
number<sizeof(ADataType) * AkIter * WG::kK / APackedSize>>{});
|
||||
}
|
||||
});
|
||||
// barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
|
||||
@@ -947,7 +956,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_pong,
|
||||
tuple<number<mIter * WG::kM>, number<kIter * WG::kK / APackedSize>>{});
|
||||
tuple<number<mIter * WG::kM>,
|
||||
number<kIter * WG::kK * sizeof(ADataType) / APackedSize>>{});
|
||||
});
|
||||
|
||||
Last2ndHotLoopScheduler();
|
||||
@@ -977,12 +987,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(n_iter == NIterPerWarp - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<APackIter>{}) =
|
||||
load_tile_with_offset(a_warp_window_pong,
|
||||
tuple<number<AmIter * WG::kM>,
|
||||
number<AkIter * WG::kK / APackedSize>>{});
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<APackIter>{}) = load_tile_with_offset(
|
||||
a_warp_window_pong,
|
||||
tuple<number<AmIter * WG::kM>,
|
||||
number<sizeof(ADataType) * AkIter * WG::kK / APackedSize>>{});
|
||||
}
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
@@ -1014,12 +1024,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(n_iter == NIterPerWarp - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<APackIter>{}) =
|
||||
load_tile_with_offset(a_warp_window_ping,
|
||||
tuple<number<AmIter * WG::kM>,
|
||||
number<AkIter * WG::kK / APackedSize>>{});
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<APackIter>{}) = load_tile_with_offset(
|
||||
a_warp_window_ping,
|
||||
tuple<number<AmIter * WG::kM>,
|
||||
number<sizeof(ADataType) * AkIter * WG::kK / APackedSize>>{});
|
||||
}
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
|
||||
@@ -17,6 +17,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
static constexpr index_t kDramLoadPackBytes = 128;
|
||||
static constexpr index_t DWORDx4 = 16;
|
||||
static constexpr index_t DWORDx3 = 12;
|
||||
|
||||
static constexpr int MXdlPack = 2;
|
||||
static constexpr int NXdlPack = 2;
|
||||
@@ -77,15 +78,16 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_ABytesDramTileDistribution()
|
||||
{
|
||||
constexpr index_t K2 = DWORDx4; // 16 bytes
|
||||
constexpr index_t K1 = kDramLoadPackBytes / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2 * APackedSize); // KPerBlock/256/packsize
|
||||
constexpr index_t K2 = std::is_same_v<ADataType, pk_fp6x16_t> ? DWORDx3 : DWORDx4;
|
||||
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // fp8/fp6/fp4 K1 equal to 8
|
||||
constexpr index_t K0 =
|
||||
KPerBlock / APackedSize * sizeof(ADataType) / (K1 * K2); // KPerBlock/256/packsize
|
||||
|
||||
constexpr index_t M2 = WaveSize / K1; // 8
|
||||
constexpr index_t M1 = BlockSize / WaveSize; // 4
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
|
||||
static_assert(K0 * K1 * K2 * APackedSize == KPerBlock,
|
||||
static_assert(K0 * K1 * K2 == KPerBlock / APackedSize * sizeof(ADataType),
|
||||
"K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
@@ -107,9 +109,9 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
auto&& tensor_view_tmp = window_tmp.get_bottom_tensor_view();
|
||||
const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
|
||||
|
||||
constexpr index_t K2 = DWORDx4; // 16 bytes
|
||||
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
|
||||
const index_t K0 = cols / (K1 * K2 * APackedSize);
|
||||
constexpr index_t K2 = std::is_same_v<ADataType, pk_fp6x16_t> ? DWORDx3 : DWORDx4;
|
||||
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // fp8/fp6/fp4 K1 equal to 8
|
||||
const index_t K0 = cols / (K1 * K2 / sizeof(ADataType) * APackedSize);
|
||||
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
|
||||
|
||||
constexpr index_t M1 = 4; // so that we can use imm offset to load lds
|
||||
@@ -138,19 +140,23 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
|
||||
auto&& byte_tensor_view = make_tensor_view<address_space_enum::global>(byte_ptr, desc);
|
||||
|
||||
auto&& origin_tmp = window_tmp.get_window_origin();
|
||||
auto&& origin_tmp = window_tmp.get_window_origin();
|
||||
constexpr index_t test1 = APackedSize / sizeof(ADataType);
|
||||
return make_tile_window(byte_tensor_view,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
|
||||
{origin_tmp[0], origin_tmp[1] / APackedSize},
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock / test1>{}),
|
||||
{origin_tmp[0], origin_tmp[1] / test1},
|
||||
MakeMX_ABytesDramTileDistribution());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBytesBlockDescriptor()
|
||||
{
|
||||
constexpr index_t K2 = AK1 / APackedSize; // 16
|
||||
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * AK1); // KPerBlock/256
|
||||
static_assert(K0 * K1 * K2 * APackedSize == KPerBlock,
|
||||
constexpr index_t K2 = std::is_same_v<ADataType, pk_fp6x16_t> ? DWORDx3 : AK1 / APackedSize;
|
||||
constexpr index_t K2_Pad = 16;
|
||||
constexpr index_t K1 = kDramLoadPackBytes / DWORDx4; // 8
|
||||
constexpr index_t K0 = std::is_same_v<ADataType, pk_fp6x16_t>
|
||||
? KPerBlock / (K1 * K2 / sizeof(ADataType) * APackedSize)
|
||||
: KPerBlock / (K1 * AK1); // KPerBlock/256
|
||||
static_assert(K0 * K1 * K2 / sizeof(ADataType) * APackedSize == KPerBlock,
|
||||
"K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
constexpr index_t M3 = 4; // so that we can use imm offset to load lds
|
||||
@@ -169,12 +175,12 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
number<M3>{},
|
||||
number<K1>{},
|
||||
number<K2>{}),
|
||||
make_tuple(number<K0*(M1 * (M2 * M3 * K1 * K2) + (M1 - 1) * Pad)>{},
|
||||
number<M1*(M2 * M3 * K1 * K2) + (M1 - 1) * Pad>{},
|
||||
number<M2 * M3 * K1 * K2 + Pad>{},
|
||||
number<M3 * K1 * K2>{},
|
||||
number<K1 * K2>{},
|
||||
number<K2>{},
|
||||
make_tuple(number<K0*(M1 * (M2 * M3 * K1 * K2_Pad) + (M1 - 1) * Pad)>{},
|
||||
number<M1*(M2 * M3 * K1 * K2_Pad) + (M1 - 1) * Pad>{},
|
||||
number<M2 * M3 * K1 * K2_Pad + Pad>{},
|
||||
number<M3 * K1 * K2_Pad>{},
|
||||
number<K1 * K2_Pad>{},
|
||||
number<K2_Pad>{},
|
||||
number<1>{}),
|
||||
number<K2>{},
|
||||
number<1>{});
|
||||
@@ -216,7 +222,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
|
||||
if constexpr(K_Thread == AK1)
|
||||
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<NWarps>,
|
||||
@@ -225,7 +231,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
else
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t>)
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<NWarps>,
|
||||
@@ -235,6 +241,19 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>{});
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp6x16_t>)
|
||||
// K_Lane=4, K_Thread=32
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<NWarps>,
|
||||
tuple<sequence<MWarps, MXdlPack, MPerXdl>,
|
||||
sequence<K_Lane, KPerXdl / (K_Lane * APackedSize), DWORDx3>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<1, 2>>{});
|
||||
else
|
||||
static_assert(false, "unsupported datatype");
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution()
|
||||
@@ -245,17 +264,17 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
|
||||
if constexpr(BK1 == K_Thread)
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWarps, NXdlPack>, // 4 2
|
||||
sequence<K0, K1, BK1 / BPackedSize>>, // 1 64 32
|
||||
sequence<K0, K1, BK1 / BPackedSize>>, // 1 64 16
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>,
|
||||
tuple<sequence<0, 0, 0>, sequence<1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
else
|
||||
else if constexpr(std::is_same_v<BDataType, fp8_t>)
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<WaveRepeat>,
|
||||
@@ -265,6 +284,21 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
tuple<sequence<0, 0, 1>, sequence<2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 3>>{});
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp6x16_t>)
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWarps, NXdlPack>, // 4 2
|
||||
sequence<K0,
|
||||
K1,
|
||||
K_Thread * sizeof(BDataType) / (DWORDx3 * BPackedSize),
|
||||
DWORDx3>>, // 64 1 2 12
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>,
|
||||
tuple<sequence<0, 0, 0>, sequence<1>>,
|
||||
sequence<2, 2>,
|
||||
sequence<2, 3>>{});
|
||||
else
|
||||
static_assert(false, "unsupported datatype");
|
||||
}
|
||||
|
||||
template <typename WindowTmp>
|
||||
@@ -280,21 +314,27 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
|
||||
constexpr auto flat_k_per_block = KPerBlock * M_Warp_Tile;
|
||||
auto&& byte_tensor_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
flat_n, flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{})),
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(flat_n,
|
||||
flat_k / flat_k_per_block,
|
||||
number<flat_k_per_block / BPackedSize * sizeof(BDataType)>{})),
|
||||
make_tuple(make_pass_through_transform(flat_n),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
flat_k / flat_k_per_block, number<flat_k_per_block / BPackedSize>{}))),
|
||||
flat_k / flat_k_per_block,
|
||||
number<flat_k_per_block / BPackedSize * sizeof(BDataType)>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
|
||||
auto&& byte_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(byte_ptr, byte_tensor_desc);
|
||||
auto&& origin_tmp = window_tmp.get_window_origin();
|
||||
auto origin_n = origin_tmp[0];
|
||||
auto origin_k = static_cast<int>(origin_tmp[1] * sizeof(BDataType) / BPackedSize);
|
||||
return make_tile_window(
|
||||
byte_tensor_view,
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp / BPackedSize>{}),
|
||||
{origin_tmp[0], origin_tmp[1] / BPackedSize},
|
||||
make_tuple(number<flatNPerWarp>{},
|
||||
number<flatKPerWarp * sizeof(BDataType) / BPackedSize>{}),
|
||||
{origin_n, origin_k},
|
||||
MakeMX_BFlatBytesDramTileDistribution());
|
||||
}
|
||||
|
||||
@@ -372,7 +412,14 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
return sizeof(ADataType) * MakeMX_ALdsBytesBlockDescriptor().get_element_space_size();
|
||||
if constexpr(!std::is_same_v<ADataType, pk_fp6x16_t>)
|
||||
{
|
||||
return sizeof(ADataType) * MakeMX_ALdsBytesBlockDescriptor().get_element_space_size();
|
||||
}
|
||||
else
|
||||
{
|
||||
return MakeMX_ALdsBytesBlockDescriptor().get_element_space_size();
|
||||
}
|
||||
}
|
||||
|
||||
// Forwards to GetSmemSizeA() and returns its result unchanged; only the A-side
// size reported by GetSmemSizeA() is counted here, nothing else is added.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return GetSmemSizeA(); }
|
||||
|
||||
Reference in New Issue
Block a user