mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Merge branch 'develop' into tianxing/unified-attention
This commit is contained in:
@@ -158,7 +158,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert(
|
||||
"mx_prec", "fp4xfp4", "data type for activation and weight, support: fp6xfp6, fp8xfp8")
|
||||
"mx_prec", "fp4xfp4", "data type for activation and weight, support: fp4xfp4, fp8xfp8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
|
||||
@@ -75,7 +75,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
HasHotLoop,
|
||||
TailNum>;
|
||||
|
||||
using MXFlatmmPipeline = ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1<MXPipelineProblem>;
|
||||
using MXFlatmmPipeline = ck_tile::MXFlatmmPipelineAGmemBGmemCRegV1<MXPipelineProblem>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,
|
||||
|
||||
@@ -74,7 +74,7 @@ User need to select correct mapping of config for each quant mode:
|
||||
|:--------|:-----:|:-----:|-------|
|
||||
| For selecting AQuant | aquant | gemm_aquant_quantgrouped.cpp| GemmConfigQuantDecode |
|
||||
| For selecting AQuant with Preshuffle quant | aquant | gemm_aquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode |
|
||||
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigBQuantPrefill |
|
||||
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigQuantPrefill |
|
||||
| For selecting BQuant with Preshuffle quant | bquant | gemm_bquant_quantgrouped_preshufflequant.cpp| GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill |
|
||||
| For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp| GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill
|
||||
| For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp |GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
|
||||
|
||||
@@ -6,6 +6,10 @@
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
// GemmConfigQuantPrefill is also supported for aquant grouped quantization
|
||||
// template <typename T>
|
||||
// using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
void aquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
|
||||
@@ -221,7 +221,7 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigBQuantPrefill : public GemmConfigBase
|
||||
struct GemmConfigQuantPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -237,13 +237,13 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigBQuantPrefill<PrecType>
|
||||
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill<PrecType>
|
||||
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigQuantPrefill<PrecType>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
|
||||
Reference in New Issue
Block a user