mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
add mixed_prec fp16xfp4
This commit is contained in:
@@ -39,6 +39,8 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
static constexpr int QuantPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
@@ -89,16 +91,15 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK =
|
||||
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
return make_naive_tensor_view<address_space_enum::global>(b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<32>{},
|
||||
number<1>{});
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
@@ -307,7 +308,8 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
|
||||
if constexpr(false && (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) ||
|
||||
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0))
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}.template
|
||||
@@ -346,8 +348,8 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
|
||||
// options
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_flat_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
|
||||
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
|
||||
splitk_batch_offset.b_k_split_offset / QuantPackedSize;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
|
||||
@@ -371,8 +371,39 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem, int PackSize = 1>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution(number<PackSize> = {})
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ADramDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
// constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
// static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
// "Incorrect M0, M2, M1 configuration! "
|
||||
// "M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<4>,
|
||||
tuple<sequence<16>, sequence<4, 4, 8>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
@@ -380,7 +411,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>() / PackSize;
|
||||
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
@@ -407,6 +438,42 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4BFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KBPerLoad = 32;
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = 1;
|
||||
// static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 1, 2, 2>,
|
||||
sequence<0, 3, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
|
||||
{
|
||||
|
||||
@@ -29,7 +29,12 @@ struct MixedPrecFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
TailNum_,
|
||||
ComputeDataType_>
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShape_;
|
||||
|
||||
using QuantType = BDataType_;
|
||||
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
static constexpr index_t flatKPerWarp = 128;
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
|
||||
@@ -68,8 +73,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
static constexpr index_t flatKPerWarp = Problem::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = Problem::flatNPerWarp;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
|
||||
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
|
||||
@@ -168,15 +173,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
index_t round_data_inst = (sum_data_inst + mfma_perM_perK - 1) / mfma_perM_perK;
|
||||
|
||||
index_t inst_order[NIterPerWarp * 10];
|
||||
#pragma unroll
|
||||
for(int idx = 0; idx < NIterPerWarp * 10; idx++)
|
||||
{
|
||||
inst_order[idx] = 0;
|
||||
}
|
||||
_Pragma("unroll") for(int idx = 0; idx < NIterPerWarp * 10; idx++) { inst_order[idx] = 0; }
|
||||
|
||||
index_t index = 0;
|
||||
#pragma unroll
|
||||
for(int j = 0; j < max_data_inst; j++)
|
||||
_Pragma("unroll") for(int j = 0; j < max_data_inst; j++)
|
||||
{
|
||||
if(dswrite_perM > j)
|
||||
{
|
||||
@@ -195,9 +195,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
}
|
||||
}
|
||||
|
||||
// Schedule IGLP
|
||||
#pragma unroll
|
||||
for(int j = 0; j < mfma_perM_perK; j++)
|
||||
// Schedule IGLP
|
||||
_Pragma("unroll") for(int j = 0; j < mfma_perM_perK; j++)
|
||||
{
|
||||
index_t inst_idx = 0;
|
||||
if(j == 0)
|
||||
@@ -211,8 +210,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
|
||||
#pragma unroll
|
||||
for(int r = 0; r < round_data_inst; r++)
|
||||
_Pragma("unroll") for(int r = 0; r < round_data_inst; r++)
|
||||
{
|
||||
if(r % 2 == 0)
|
||||
{
|
||||
@@ -325,11 +323,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// 0 M7N2: 63 - - 8 -
|
||||
// 0 M7N3: 64 4 - - -
|
||||
|
||||
#pragma unroll
|
||||
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
_Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
#pragma unroll
|
||||
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
_Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
{
|
||||
index_t dsread_perM = 0;
|
||||
index_t dswrite_perM = 0;
|
||||
@@ -390,11 +386,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Last2ndHotLoopScheduler()
|
||||
{
|
||||
#pragma unroll
|
||||
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
_Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
#pragma unroll
|
||||
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
_Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
{
|
||||
index_t dsread_perM = 0;
|
||||
index_t dswrite_perM = 0;
|
||||
@@ -444,11 +438,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto LastHotLoopScheduler()
|
||||
{
|
||||
#pragma unroll
|
||||
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
_Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
#pragma unroll
|
||||
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
_Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
{
|
||||
index_t dsread_perM = 0;
|
||||
index_t dswrite_perM = 0;
|
||||
@@ -524,18 +516,19 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
auto A_Warp_Dist = PipelinePolicy::template MakeF16xF4_ADramDistribution<Problem>();
|
||||
// ping-pong window for A LDS
|
||||
auto a_warp_window_ping_tmp =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
A_Warp_Dist);
|
||||
|
||||
auto a_warp_window_pong_tmp =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
A_Warp_Dist);
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
|
||||
@@ -547,12 +540,14 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
MIterPerWarp>
|
||||
a_warp_windows_pong;
|
||||
|
||||
constexpr int KStridePerIter = 8;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows_ping(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
{mIter * MPerBlockPerIter, kIter * KStridePerIter});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -561,7 +556,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows_pong(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
{mIter * MPerBlockPerIter, kIter * KStridePerIter});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -570,9 +565,12 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_flatmm.MakeCBlockTile();
|
||||
|
||||
constexpr int XDLPerLoadK = 4;
|
||||
constexpr int QuantKPerWarp = KIterPerWarp / XDLPerLoadK;
|
||||
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>(number<2>{});
|
||||
PipelinePolicy::template MakeFp4BFlatDramTileDistribution<Problem>();
|
||||
auto b_flat_dram_window = // tile_window_with_static_distribution
|
||||
make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
@@ -582,17 +580,17 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
// pingpong buffer for B
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
|
||||
statically_indexed_array<decltype(b_flat_dram_window), QuantKPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), QuantKPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_ping;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), QuantKPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
@@ -604,7 +602,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
// prefetch B
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
@@ -616,20 +614,6 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// Prefill A0
|
||||
// if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
|
||||
// {
|
||||
// auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
// PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
|
||||
// shuffle_tile(a_shuffle_tmp, a_block_tile);
|
||||
// const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
|
||||
// store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func,
|
||||
// a_block_tile));
|
||||
// }
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -657,12 +641,23 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
auto dequant_B = typename WG::BWarpTensor{};
|
||||
|
||||
auto deq_fn = [&](auto& quant_weight_tensor, auto sub_idx) {
|
||||
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
|
||||
static_for<0, ScalarCnt / 2, 1>{}([&](auto i) {
|
||||
dequant_B.get_thread_buffer().template set_as<fp16x2_t>(
|
||||
number<i>{},
|
||||
fp16x2_t(quant_weight_tensor.get_thread_buffer()[sub_idx * ScalarCnt / 2 + i]));
|
||||
});
|
||||
};
|
||||
|
||||
// MAIN LOOP
|
||||
index_t iCounter = 0; // (num_loop - 1) / 2;
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
@@ -694,10 +689,11 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
kIter % number<XDLPerLoadK>{});
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
cast_tile<ADataType>(b_warp_tensor_ping(nIter)(kIter)));
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -737,7 +733,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
// Next K
|
||||
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
@@ -768,10 +764,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_pong(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
kIter % number<XDLPerLoadK>{});
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
cast_tile<ADataType>(b_warp_tensor_pong(nIter)(kIter)));
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -815,7 +811,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
// prefetch B(loopK)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
@@ -842,10 +838,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
kIter % number<XDLPerLoadK>{});
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
cast_tile<ADataType>(b_warp_tensor_ping(nIter)(kIter)));
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -892,10 +888,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_pong(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
kIter % number<XDLPerLoadK>{});
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
cast_tile<ADataType>(b_warp_tensor_pong(nIter)(kIter)));
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -934,10 +930,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
|
||||
kIter % number<XDLPerLoadK>{});
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
cast_tile<ADataType>(b_warp_tensor_ping(nIter)(kIter)));
|
||||
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
|
||||
Reference in New Issue
Block a user