add mixed_prec fp16xfp4

This commit is contained in:
Feng Shijie
2025-08-08 20:19:16 +00:00
parent 3dea10a277
commit f788d3d629
9 changed files with 252 additions and 123 deletions

View File

@@ -6,7 +6,6 @@
#include "ck_tile/core.hpp"
template <typename DataType>
struct A16W4_FlatmmConfig32
{
static constexpr ck_tile::index_t M_Tile = 128;
@@ -37,18 +36,16 @@ struct A16W4_FlatmmConfig32
static constexpr bool TiledMMAPermuteN = false; // disable PermuteN when NWarpTile != 16
};
template <typename DataType>
struct A16W4_FlatmmConfig32_950 : public A16W4_FlatmmConfig32<DataType>
struct A16W4_FlatmmConfig32_950 : A16W4_FlatmmConfig32
{
};
// GEMM config with 16x16 warp tile
template <typename DataType>
struct A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
@@ -73,17 +70,16 @@ struct A16W4_FlatmmConfig16
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
static constexpr bool TiledMMAPermuteN = false;
};
template <typename DataType>
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16<DataType>
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr int kBlockPerCu = 1;
static constexpr int N_Repeat = N_Tile / A16W4_FlatmmConfig16<DataType>::N_Warp_Tile /
A16W4_FlatmmConfig16<DataType>::N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
static constexpr int N_Repeat =
N_Tile / A16W4_FlatmmConfig16::N_Warp_Tile / A16W4_FlatmmConfig16::N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};

View File

@@ -107,7 +107,7 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
ck_tile::MixedPrecFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
@@ -160,10 +160,8 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
@@ -329,24 +327,41 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
/// Pre-shuffles a packed-int4 weight matrix from row-major [N, K] order into
/// the blocked layout the flatmm kernel consumes:
///     K -> K0, KLane, KPack   and   N -> N0, NLane
///     (N, K) -> (N0, K0, KLane, NLane, KPack)
/// src/dst are random-access iterators over the packed storage; each element
/// holds PackSize (= 2) 4-bit values, so the packed K extent is K / PackSize.
///
/// @param src  iterator to the source weights, row-major, N rows of K_pk packed elems
/// @param dst  iterator to the destination buffer (same element count as src)
/// @param N    logical N extent (must be a multiple of NXdl)
/// @param K    logical (unpacked) K extent
/// @param NXdl N extent of one warp tile; fixes NLane (KLane = 64 / NLane —
///             64 is presumably the wavefront size; TODO confirm for target arch)
template <class IterSrc, class IterDst>
void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K, int NXdl)
{
    constexpr int PackSize = 2; // two 4-bit values per packed element
    const int KPack = 16;
    const int NLane = NXdl;
    const int KLane = 64 / NLane;
    const int K_pk  = K / PackSize; // packed K extent
    const int K0    = K_pk / (KLane * KPack);

    for(int n = 0; n < N; ++n)
    {
        const int n0 = n / NLane;
        const int n1 = n % NLane;
        for(int k = 0; k < K_pk; ++k)
        {
            // Decompose packed-k into (k0, k1, k2) = (K0, KLane, KPack) coords.
            const int k0    = k / (KLane * KPack);
            const int tempk = k % (KLane * KPack);
            const int k1    = tempk / KPack;
            const int k2    = tempk % KPack;
            // Linearize (n0, k0, k1, n1, k2) with strides
            // (K0*KLane*NLane*KPack, KLane*NLane*KPack, NLane*KPack, KPack, 1).
            const int outputIndex = n0 * KPack * NLane * KLane * K0 +
                                    k0 * KPack * NLane * KLane + k1 * KPack * NLane +
                                    n1 * KPack + k2;
            dst[outputIndex] = src[n * K_pk + k];
        }
    }
}
template <class IterSrc, class IterDst>
void preShuffleScale(const IterSrc src, IterDst dst, int N, int K, int NXdl);
#include "run_mixed_prec_flatmm.inc"
template <template <typename PrecType> typename FlatmmConfig>

View File

@@ -51,13 +51,13 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::HostTensor<AccDataType> weight_dequant_scale(ck_tile::HostTensorDescriptor(
{N / DequantGranularityN, K / DequantGranularityK}, {1, N / DequantGranularityN}));
{K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(weight_dequant_scale);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(weight_dequant_scale);
}
else if(init_method == 1)
{
@@ -66,7 +66,10 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(weight_dequant_scale);
}
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_subbyte_b<FlatmmConfig>(b_origin_host);
ck_tile::HostTensor<BDataType> b_shuffle_host(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
preShuffleWeight(
b_origin_host.begin(), b_shuffle_host.begin(), N, K, FlatmmConfig::N_Warp_Tile);
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
@@ -154,9 +157,6 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
const float max_accumulated_value =
*std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end());
const float rtol = 1e-3;
const float atol = 1e-3;