add mixed_prec fp16xfp4

This commit is contained in:
Feng Shijie
2025-08-08 20:19:16 +00:00
parent 3dea10a277
commit f788d3d629
9 changed files with 252 additions and 123 deletions

View File

@@ -6,7 +6,6 @@
#include "ck_tile/core.hpp"
template <typename DataType>
struct A16W4_FlatmmConfig32
{
static constexpr ck_tile::index_t M_Tile = 128;
@@ -37,18 +36,16 @@ struct A16W4_FlatmmConfig32
static constexpr bool TiledMMAPermuteN = false; // disable PermuteN when NWarpTile != 16
};
template <typename DataType>
struct A16W4_FlatmmConfig32_950 : public A16W4_FlatmmConfig32<DataType>
struct A16W4_FlatmmConfig32_950 : A16W4_FlatmmConfig32
{
};
// GEMM config with 16x16 warp tile
template <typename DataType>
struct A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
@@ -73,17 +70,16 @@ struct A16W4_FlatmmConfig16
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
static constexpr bool TiledMMAPermuteN = false;
};
template <typename DataType>
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16<DataType>
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr int kBlockPerCu = 1;
static constexpr int N_Repeat = N_Tile / A16W4_FlatmmConfig16<DataType>::N_Warp_Tile /
A16W4_FlatmmConfig16<DataType>::N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
static constexpr int N_Repeat =
N_Tile / A16W4_FlatmmConfig16::N_Warp_Tile / A16W4_FlatmmConfig16::N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};

View File

@@ -107,7 +107,7 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
ck_tile::MixedPrecFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
@@ -160,10 +160,8 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
@@ -329,24 +327,41 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
template <typename FlatmmConfig, typename T>
auto shuffle_subbyte_b(const ck_tile::HostTensor<T>& t)
template <class IterSrc, class IterDst>
void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K, int NXdl)
{
constexpr int PackSize = 2;
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K_pk; ++k)
{
int n0 = n / NLane;
int n1 = n % NLane;
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0] / 2;
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
divisor,
FlatmmConfig::K_Warp_Tile / divisor / 2});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K_pk + k];
}
}
}
template <class IterSrc, class IterDst>
void preShuffleScale(const IterSrc src, IterDst dst, int N, int K, int NXdl);
#include "run_mixed_prec_flatmm.inc"
template <template <typename PrecType> typename FlatmmConfig>

View File

@@ -51,13 +51,13 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::HostTensor<AccDataType> weight_dequant_scale(ck_tile::HostTensorDescriptor(
{N / DequantGranularityN, K / DequantGranularityK}, {1, N / DequantGranularityN}));
{K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(weight_dequant_scale);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(weight_dequant_scale);
}
else if(init_method == 1)
{
@@ -66,7 +66,10 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(weight_dequant_scale);
}
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_subbyte_b<FlatmmConfig>(b_origin_host);
ck_tile::HostTensor<BDataType> b_shuffle_host(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
preShuffleWeight(
b_origin_host.begin(), b_shuffle_host.begin(), N, K, FlatmmConfig::N_Warp_Tile);
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
@@ -154,9 +157,6 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
const float max_accumulated_value =
*std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end());
const float rtol = 1e-3;
const float atol = 1e-3;

View File

@@ -1450,7 +1450,9 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
(std::is_same<T, pk_fp4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
"wrong! not implemented");
using rtn_type = thread_buffer<T, N>;

View File

@@ -88,7 +88,12 @@ template <typename T, typename>
struct vector_traits
{
using scalar_type =
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>, int8_t, remove_cvref_t<T>>;
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>,
int8_t,
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_fp4_t>,
uint8_t,
remove_cvref_t<T>>>;
static constexpr index_t vector_size = 1;
};
@@ -96,7 +101,11 @@ struct vector_traits
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
{
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
using scalar_type =
std::conditional_t<std::is_same_v<T, pk_int4_t>,
int8_t,
std::conditional_t<std::is_same_v<T, pk_fp4_t>, uint8_t, T>>;
static constexpr index_t vector_size = N;
};
@@ -237,4 +246,13 @@ using pk_int4x4_t = int8_t __attribute((ext_vector_type(4)));
using pk_int4x8_t = int8_t __attribute((ext_vector_type(8)));
using pk_int4x16_t = int8_t __attribute((ext_vector_type(16)));
using pk_int4x32_t = int8_t __attribute((ext_vector_type(32)));
// pk_fp4_t
// using pk_fp4_t
using pk_fp4x2_t = uint8_t __attribute((ext_vector_type(2)));
using pk_fp4x4_t = uint8_t __attribute((ext_vector_type(4)));
using pk_fp4x8_t = uint8_t __attribute((ext_vector_type(8)));
using pk_fp4x16_t = uint8_t __attribute((ext_vector_type(16)));
using pk_fp4x32_t = uint8_t __attribute((ext_vector_type(32)));
} // namespace ck_tile

View File

@@ -169,6 +169,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
else
v_a = fp32_val.lo;
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
@@ -181,6 +189,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
@@ -265,10 +281,19 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
else
v_a = fp32_val.lo;
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
@@ -277,6 +302,14 @@ __global__ void blockwise_gemm_kernel(ADataType* A,
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);

View File

@@ -39,6 +39,8 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
// Below type is actually accumulation data type - the output of block GEMM.
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr int QuantPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr index_t NumDTensor = DsDataType::size();
static constexpr auto I0 = number<0>();
@@ -89,16 +91,15 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
}
}();
index_t kFlatK =
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
return make_naive_tensor_view<address_space_enum::global>(b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<32>{},
number<1>{});
}();
const auto& ds_tensor_view = generate_tuple(
@@ -307,7 +308,8 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
// Run Epilogue Pipeline
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
if constexpr(false && (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) ||
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0))
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
@@ -346,8 +348,8 @@ struct MixedPrecFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, E
// options
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_flat_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
splitk_batch_offset.b_k_split_offset / QuantPackedSize;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
// allocate LDS

View File

@@ -371,8 +371,39 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
sequence<1>>{});
}
template <typename Problem, int PackSize = 1>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution(number<PackSize> = {})
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ADramDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// constexpr index_t M0 = MPerBlock / (M2 * M1);
// static_assert(M0 * M1 * M2 == MPerBlock,
// "Incorrect M0, M2, M1 configuration! "
// "M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<4>,
tuple<sequence<16>, sequence<4, 4, 8>>,
tuple<sequence<0>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 0>>,
sequence<2>,
sequence<2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
@@ -380,7 +411,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>() / PackSize;
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
@@ -407,6 +438,42 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4BFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = 32;
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
// static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t NRepeat = 1;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
{

View File

@@ -29,7 +29,12 @@ struct MixedPrecFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
TailNum_,
ComputeDataType_>
{
using BlockGemmShape = BlockGemmShape_;
using QuantType = BDataType_;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
static constexpr index_t flatKPerWarp = 128;
};
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
@@ -68,8 +73,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
static constexpr index_t flatKPerWarp = Problem::flatKPerWarp;
static constexpr index_t flatNPerWarp = Problem::flatNPerWarp;
static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
@@ -168,15 +173,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
index_t round_data_inst = (sum_data_inst + mfma_perM_perK - 1) / mfma_perM_perK;
index_t inst_order[NIterPerWarp * 10];
#pragma unroll
for(int idx = 0; idx < NIterPerWarp * 10; idx++)
{
inst_order[idx] = 0;
}
_Pragma("unroll") for(int idx = 0; idx < NIterPerWarp * 10; idx++) { inst_order[idx] = 0; }
index_t index = 0;
#pragma unroll
for(int j = 0; j < max_data_inst; j++)
_Pragma("unroll") for(int j = 0; j < max_data_inst; j++)
{
if(dswrite_perM > j)
{
@@ -195,9 +195,8 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
}
}
// Schedule IGLP
#pragma unroll
for(int j = 0; j < mfma_perM_perK; j++)
// Schedule IGLP
_Pragma("unroll") for(int j = 0; j < mfma_perM_perK; j++)
{
index_t inst_idx = 0;
if(j == 0)
@@ -211,8 +210,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
#pragma unroll
for(int r = 0; r < round_data_inst; r++)
_Pragma("unroll") for(int r = 0; r < round_data_inst; r++)
{
if(r % 2 == 0)
{
@@ -325,11 +323,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// 0 M7N2: 63 - - 8 -
// 0 M7N3: 64 4 - - -
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
_Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
_Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
@@ -390,11 +386,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
CK_TILE_HOST_DEVICE static constexpr auto Last2ndHotLoopScheduler()
{
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
_Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
_Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
@@ -444,11 +438,9 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
CK_TILE_HOST_DEVICE static constexpr auto LastHotLoopScheduler()
{
#pragma unroll
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
_Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
{
#pragma unroll
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
_Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
{
index_t dsread_perM = 0;
index_t dswrite_perM = 0;
@@ -524,18 +516,19 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
auto A_Warp_Dist = PipelinePolicy::template MakeF16xF4_ADramDistribution<Problem>();
// ping-pong window for A LDS
auto a_warp_window_ping_tmp =
make_tile_window(a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
A_Warp_Dist);
auto a_warp_window_pong_tmp =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
A_Warp_Dist);
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
@@ -547,12 +540,14 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
MIterPerWarp>
a_warp_windows_pong;
constexpr int KStridePerIter = 8;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
{mIter * MPerBlockPerIter, kIter * KStridePerIter});
});
});
@@ -561,7 +556,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
{mIter * MPerBlockPerIter, kIter * KStridePerIter});
});
});
@@ -570,9 +565,12 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// Acc register tile
auto c_block_tile = block_flatmm.MakeCBlockTile();
constexpr int XDLPerLoadK = 4;
constexpr int QuantKPerWarp = KIterPerWarp / XDLPerLoadK;
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>(number<2>{});
PipelinePolicy::template MakeFp4BFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
@@ -582,17 +580,17 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// pingpong buffer for B
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
statically_indexed_array<decltype(b_flat_dram_window), QuantKPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), QuantKPerWarp>,
NIterPerWarp>
b_warp_tensor_ping;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), QuantKPerWarp>,
NIterPerWarp>
b_warp_tensor_pong;
@@ -604,7 +602,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
@@ -616,20 +614,6 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// Prefill A0
// if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
// {
// auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
// PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
// shuffle_tile(a_shuffle_tmp, a_block_tile);
// const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
// store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
// }
// else
// {
// store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func,
// a_block_tile));
// }
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
__builtin_amdgcn_sched_barrier(0);
@@ -657,12 +641,23 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
});
__builtin_amdgcn_sched_barrier(0);
auto dequant_B = typename WG::BWarpTensor{};
auto deq_fn = [&](auto& quant_weight_tensor, auto sub_idx) {
constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
static_for<0, ScalarCnt / 2, 1>{}([&](auto i) {
dequant_B.get_thread_buffer().template set_as<fp16x2_t>(
number<i>{},
fp16x2_t(quant_weight_tensor.get_thread_buffer()[sub_idx * ScalarCnt / 2 + i]));
});
};
// MAIN LOOP
index_t iCounter = 0; // (num_loop - 1) / 2;
index_t iCounter = (num_loop - 1) / 2;
while(iCounter > 0)
{
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
@@ -694,10 +689,11 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
cast_tile<ADataType>(b_warp_tensor_ping(nIter)(kIter)));
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -737,7 +733,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
// Next K
// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
@@ -768,10 +764,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_pong(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
cast_tile<ADataType>(b_warp_tensor_pong(nIter)(kIter)));
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -815,7 +811,7 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
if constexpr(TailNum == TailNumber::Even)
{
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, QuantKPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
@@ -842,10 +838,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
cast_tile<ADataType>(b_warp_tensor_ping(nIter)(kIter)));
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -892,10 +888,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_pong(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
cast_tile<ADataType>(b_warp_tensor_pong(nIter)(kIter)));
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -934,10 +930,10 @@ struct MixedPrecFlatmmPipelineAGmemBGmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
deq_fn(b_warp_tensor_ping(nIter)(kIter / number<XDLPerLoadK>{}),
kIter % number<XDLPerLoadK>{});
// warp GEMM
WG{}(c_warp_tensor,
a_warp_tensor(number<AwarpIter>{}),
cast_tile<ADataType>(b_warp_tensor_ping(nIter)(kIter)));
WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(