enable fp4 for universal gemm - without any scaling

This commit is contained in:
Sami Remes
2026-02-03 03:10:35 -05:00
parent 4d241289c9
commit b47853d3fe
8 changed files with 205 additions and 113 deletions

View File

@@ -7,6 +7,7 @@
#include <variant>
#include "ck_tile/core.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
@@ -137,6 +138,27 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType>
struct GemmConfigComputeV3_3 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
{
@@ -241,6 +263,28 @@ struct GemmConfigComputeV6 : public GemmConfigBase
static constexpr ck_tile::index_t NumWaveGroups = 1;
};
template <typename PrecType>
struct GemmConfigComputeAsync : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
// static constexpr ck_tile::index_t K_Warp_Tile = 32;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_ASYNC;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool UseStructuredSparsity = false;
};
template <typename PrecType>
struct GemmConfigPreshuffleDecode : public GemmConfigBase
{
@@ -375,6 +419,15 @@ struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
using CDataType = int32_t;
};
template <>
struct GemmTypeConfig<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, ck_tile::half_t>
{
using ADataType = ck_tile::pk_fp4_t;
using BDataType = ck_tile::pk_fp4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <ck_tile::GemmPipeline PipelineId>
struct PipelineTypeTraits;
@@ -423,6 +476,16 @@ struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V6>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_ASYNC>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync<PipelineProblem>;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
{