mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Merge branch 'develop' into vpietila/ckb-fwd-instance-test-improvements
This commit is contained in:
35
README.md
35
README.md
@@ -93,13 +93,44 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
|
||||
want to build the library for a list of different architectures,
|
||||
you should use the `GPU_ARCHS` build argument, for example `GPU_ARCHS=gfx908;gfx1030;gfx1100;gfx942`.
|
||||
|
||||
4. Build the entire CK library:
|
||||
**Convenience script for development builds:**
|
||||
|
||||
Alternatively, you can use the provided convenience script `script/cmake-ck-dev.sh` which automatically
|
||||
configures CK for development with sensible defaults. In the build directory:
|
||||
|
||||
```bash
|
||||
../script/cmake-ck-dev.sh
|
||||
```
|
||||
|
||||
This script:
|
||||
* Cleans CMake cache files before configuring
|
||||
* Sets `BUILD_DEV=ON` for development mode
|
||||
* Defaults to GPU targets: `gfx908;gfx90a;gfx942`
|
||||
* Enables verbose makefile output
|
||||
* Sets additional compiler flags for better error messages
|
||||
|
||||
By default, it considers the parent directory to be the project source directory.
|
||||
|
||||
You can specify the source directory as the first argument.
|
||||
You can specify custom GPU targets (semicolon-separated) as the second argument:
|
||||
|
||||
```bash
|
||||
../script/cmake-ck-dev.sh .. gfx1100
|
||||
```
|
||||
|
||||
Or pass additional cmake arguments:
|
||||
|
||||
```bash
|
||||
../script/cmake-ck-dev.sh .. gfx90a -DCMAKE_BUILD_TYPE=Release
|
||||
```
|
||||
|
||||
5. Build the entire CK library:
|
||||
|
||||
```bash
|
||||
make -j"$(nproc)"
|
||||
```
|
||||
|
||||
5. Install CK:
|
||||
6. Install CK:
|
||||
|
||||
```bash
|
||||
make -j install
|
||||
|
||||
@@ -68,7 +68,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig,
|
||||
Invoker,
|
||||
|
||||
@@ -962,7 +962,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t,
|
||||
|
||||
@@ -12,13 +12,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V6 5
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
@@ -69,7 +62,7 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
@@ -91,9 +84,9 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -111,8 +104,8 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -131,8 +124,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -150,8 +143,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -169,8 +162,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
@@ -190,8 +183,8 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
@@ -213,8 +206,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -232,8 +225,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -252,7 +245,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 2;
|
||||
};
|
||||
|
||||
@@ -272,7 +265,7 @@ struct GemmConfigComputeV6 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V6;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
};
|
||||
|
||||
@@ -291,13 +284,13 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -315,13 +308,13 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -465,11 +458,11 @@ struct DataTypeTraits<ck_tile::int8_t>
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
@@ -478,7 +471,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
@@ -487,7 +480,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
@@ -496,7 +489,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V5>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
@@ -505,7 +498,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V6>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V6>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
|
||||
@@ -514,7 +507,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V6>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
|
||||
@@ -58,7 +58,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
else if(data_type == "fp16i4")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
|
||||
Invoker,
|
||||
@@ -73,7 +73,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
else if(data_type == "fp8i4")
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
Invoker,
|
||||
@@ -88,7 +88,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
else if(data_type == "bf8i4")
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
Invoker,
|
||||
|
||||
@@ -11,10 +11,6 @@
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
|
||||
struct GemmConfigMemory
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
@@ -30,9 +26,9 @@ struct GemmConfigMemory
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3
|
||||
@@ -50,9 +46,9 @@ struct GemmConfigV3
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV4
|
||||
@@ -71,9 +67,9 @@ struct GemmConfigV4
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3_Wmma
|
||||
@@ -91,16 +87,16 @@ struct GemmConfigV3_Wmma
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
@@ -109,7 +105,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
@@ -118,7 +114,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
|
||||
@@ -11,11 +11,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 4
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
@@ -87,7 +82,7 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool Persistent = true;
|
||||
@@ -109,8 +104,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
};
|
||||
@@ -132,8 +127,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
@@ -155,8 +150,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
@@ -178,12 +173,12 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -201,12 +196,12 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr bool kPadK = true;
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr bool kPadK = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -226,8 +221,8 @@ struct GemmConfigComputeV4_Wmma : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
@@ -249,18 +244,18 @@ struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase
|
||||
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
@@ -269,7 +264,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
@@ -278,7 +273,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
@@ -287,7 +282,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
|
||||
@@ -11,10 +11,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
@@ -44,8 +40,8 @@ struct GemmConfigBase
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr bool Preshuffle = false; // currently preshuffle == true is not supported yet
|
||||
static constexpr bool Persistent = false; // currently persistent == true is not supported yet
|
||||
static constexpr bool DoubleSmemBuffer =
|
||||
@@ -67,10 +63,10 @@ struct GemmConfigMemory : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3 : public GemmConfigBase
|
||||
@@ -88,10 +84,10 @@ struct GemmConfigV3 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
struct GemmConfigV4 : public GemmConfigBase
|
||||
{
|
||||
@@ -109,10 +105,10 @@ struct GemmConfigV4 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3_Wmma : public GemmConfigBase
|
||||
@@ -130,16 +126,16 @@ struct GemmConfigV3_Wmma : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
@@ -148,7 +144,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
@@ -157,7 +153,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
|
||||
@@ -7,12 +7,9 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using D0DataType = ck_tile::half_t;
|
||||
@@ -36,9 +33,9 @@ struct GemmConfigMemory
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3
|
||||
@@ -56,9 +53,9 @@ struct GemmConfigV3
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV4
|
||||
@@ -77,9 +74,9 @@ struct GemmConfigV4
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3_Wmma
|
||||
@@ -97,16 +94,16 @@ struct GemmConfigV3_Wmma
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
@@ -115,7 +112,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
@@ -124,7 +121,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
|
||||
@@ -12,11 +12,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
||||
|
||||
struct ConvConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = true;
|
||||
@@ -37,7 +32,7 @@ struct ConvConfigBase
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
@@ -61,9 +56,9 @@ struct ConvConfigMemoryInterwave : public ConvConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -81,8 +76,8 @@ struct ConvConfigMemoryIntrawave : public ConvConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -101,8 +96,8 @@ struct ConvConfigComputeV3 : public ConvConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -120,8 +115,8 @@ struct ConvConfigComputeV3_1 : public ConvConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -139,8 +134,8 @@ struct ConvConfigComputeV3_2 : public ConvConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
@@ -160,8 +155,8 @@ struct ConvConfigComputeV3_WMMA : public ConvConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
@@ -183,8 +178,8 @@ struct ConvConfigComputeV4 : public ConvConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -202,8 +197,8 @@ struct ConvConfigComputeV4_1 : public ConvConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -222,7 +217,7 @@ struct ConvConfigComputeV5 : public ConvConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
};
|
||||
|
||||
@@ -245,8 +240,8 @@ struct ConvConfigComputeV3_merged_groups : public ConvConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr ck_tile::index_t NumGroupsToMerge = 2;
|
||||
};
|
||||
@@ -294,11 +289,11 @@ struct DataTypeTraits<ck_tile::bf16_t>
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
@@ -307,7 +302,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
@@ -316,7 +311,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
@@ -325,7 +320,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V5>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
|
||||
@@ -112,89 +112,87 @@ struct GroupedConvolutionForwardInvoker
|
||||
// =====================================================================
|
||||
// Regular Convolution: Simple, no split-image
|
||||
// =====================================================================
|
||||
const auto Run = [&]<bool EnableSplitImage>(const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
GemmConfig::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
GemmConfig::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<EnableSplitImage,
|
||||
GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
// =====================================================================
|
||||
// Split-K lambda
|
||||
@@ -202,11 +200,11 @@ struct GroupedConvolutionForwardInvoker
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run.template operator()<false>(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
Run.template operator()(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run.template operator()<false>(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
Run.template operator()(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -53,7 +53,10 @@ struct GroupedConvolutionForwardInvoker
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
VectorSizeC,
|
||||
1, /*NumGroupsToMerge*/
|
||||
ck_tile::element_wise::PassThrough,
|
||||
true /*EnableSplitImage*/>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GemmConfig::kPadM,
|
||||
@@ -238,68 +241,64 @@ struct GroupedConvolutionForwardInvoker
|
||||
// =====================================================================
|
||||
// Kernel launch lambda: Uses EnableSplitImage based on layout support
|
||||
// =====================================================================
|
||||
const auto Run = [&]<bool EnableSplitImage>(const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
GemmConfig::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
GemmConfig::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
// Use split-image kernel if layout supports it, otherwise use regular kernel
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<EnableSplitImage,
|
||||
GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
// Use split-image kernel if layout supports it, otherwise use regular kernel
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
|
||||
// Create kargs
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
// Create kargs
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
// Populate split-image metadata ONLY if using split-image kernel
|
||||
if constexpr(EnableSplitImage)
|
||||
{
|
||||
// Populate split-image metadata ONLY if using split-image kernel
|
||||
kargs.num_spatial_pieces = total_pieces;
|
||||
kargs.split_image.total_d = total_d;
|
||||
kargs.split_image.total_h = total_h;
|
||||
@@ -320,41 +319,35 @@ struct GroupedConvolutionForwardInvoker
|
||||
temp_pieces[i].h_size,
|
||||
temp_pieces[i].w_size};
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate grid: use total_blocks for split-image, or normal GridSize for regular
|
||||
const dim3 grids = [&]() {
|
||||
if constexpr(EnableSplitImage)
|
||||
return dim3(total_blocks, kargs.GemmBatch, kargs.n_splits);
|
||||
else
|
||||
return Kernel::GridSize(kargs);
|
||||
}();
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
// Calculate grid: use total_blocks for split-image, or normal GridSize for regular
|
||||
const dim3 grids = dim3(total_blocks, kargs.GemmBatch, kargs.n_splits);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
// =====================================================================
|
||||
// Step 4: Dispatch kernel (split-image or regular based on decision)
|
||||
|
||||
@@ -11,7 +11,9 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "conv_configs.hpp"
|
||||
|
||||
using MemoryOpSet =
|
||||
std::integral_constant<ck_tile::memory_operation_enum, ck_tile::memory_operation_enum::set>;
|
||||
using MemoryOpAtomicAdd = std::integral_constant<ck_tile::memory_operation_enum,
|
||||
|
||||
@@ -7,16 +7,9 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
|
||||
#ifndef CK_TILE_PIPELINE_DEFAULT
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
|
||||
#endif
|
||||
|
||||
using A0DataType = ck_tile::half_t;
|
||||
using A1DataType = ck_tile::half_t;
|
||||
|
||||
@@ -49,9 +42,9 @@ struct GemmConfigMemory
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3
|
||||
@@ -69,9 +62,9 @@ struct GemmConfigV3
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV4
|
||||
@@ -90,9 +83,9 @@ struct GemmConfigV4
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3_Wmma
|
||||
@@ -110,16 +103,16 @@ struct GemmConfigV3_Wmma
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
@@ -128,7 +121,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
@@ -137,7 +130,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
|
||||
@@ -1 +1 @@
|
||||
# Empty placeholder until we add library code.
|
||||
#Empty placeholder until we add library code.
|
||||
|
||||
@@ -425,6 +425,11 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle<AL
|
||||
return false;
|
||||
}
|
||||
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
if constexpr(NXdlPerWave64 > 0)
|
||||
|
||||
@@ -40,14 +40,22 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
// Full K needed for matrix B
|
||||
const index_t Kt = karg.K;
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
|
||||
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
|
||||
const index_t k_id = blockIdx.z * num_k_per_block;
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_b_grid,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared,
|
||||
karg);
|
||||
karg,
|
||||
k_id,
|
||||
Kt);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -74,15 +82,23 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
// Full K needed for matrix B
|
||||
const index_t Kt = karg.K;
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
|
||||
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
|
||||
const index_t k_id = blockIdx.z * num_k_per_block;
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_b_grid,
|
||||
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared_0,
|
||||
p_shared_1,
|
||||
karg);
|
||||
karg,
|
||||
k_id,
|
||||
Kt);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -658,25 +674,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
if constexpr(!PermuteB)
|
||||
{
|
||||
// b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
|
||||
|
||||
b_k_split_offset = blockIdx.z * karg.KRead * NLane / BPackedSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
const int k0_offset = karg.KRead * karg.N;
|
||||
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
@@ -697,7 +694,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t c_reduce_offset;
|
||||
};
|
||||
|
||||
@@ -900,6 +896,11 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
|
||||
@@ -1134,7 +1135,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock)
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const index_t k_id)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
@@ -1226,7 +1228,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
k_id,
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
@@ -1465,10 +1467,12 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
const BDataType* p_b_grid,
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem)
|
||||
const Problem& problem,
|
||||
const index_t k_id,
|
||||
const index_t Kt)
|
||||
{
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bpreshuffled =
|
||||
@@ -1491,7 +1495,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
problem,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bpreshuffled,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock);
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id);
|
||||
}
|
||||
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
@@ -1509,7 +1514,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock)
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const index_t k_id)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
@@ -1606,7 +1612,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
k_id,
|
||||
KPack * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
@@ -1849,10 +1855,12 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
CDataType* p_c_grid,
|
||||
void* p_shared_0,
|
||||
void* p_shared_1,
|
||||
const Problem& problem)
|
||||
const Problem& problem,
|
||||
const index_t k_id,
|
||||
const index_t Kt)
|
||||
{
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
const auto b_grid_desc_bpreshuffled =
|
||||
@@ -1877,7 +1885,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
problem,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bpreshuffled,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock);
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
k_id);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -43,18 +43,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
// Full K needed for matrix B
|
||||
const index_t Kt = karg.K;
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
|
||||
const index_t k_id = blockIdx.z * num_k_per_block;
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_b_grid,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
karg.c_element_op,
|
||||
k_id,
|
||||
Kt);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -79,11 +87,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
// Full K needed for matrix B
|
||||
const index_t Kt = karg.K;
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
|
||||
const index_t k_id = blockIdx.z * num_k_per_block;
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_b_grid,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
p_shared,
|
||||
@@ -91,7 +105,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.c_element_op);
|
||||
karg.c_element_op,
|
||||
k_id,
|
||||
Kt);
|
||||
}
|
||||
#else
|
||||
ignore = karg;
|
||||
@@ -691,16 +707,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
// KPack * NLane * KLane * K0 * N0
|
||||
b_k_split_offset = k_id * karg.KRead * NLane;
|
||||
}
|
||||
|
||||
if(k_id < karg.KBatch - 1)
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
@@ -712,7 +718,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
};
|
||||
|
||||
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -1163,7 +1168,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
CElementwiseOperation c_element_op,
|
||||
const index_t k_id,
|
||||
const index_t Kt)
|
||||
{
|
||||
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
|
||||
Run<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
@@ -1176,7 +1183,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map);
|
||||
block_2_ctile_map,
|
||||
k_id,
|
||||
Kt);
|
||||
}
|
||||
|
||||
template <typename Block2CTileMap,
|
||||
@@ -1192,11 +1201,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
const index_t k_id,
|
||||
const index_t Kt)
|
||||
{
|
||||
ignore = b_element_op;
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
@@ -1293,7 +1304,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
k_id,
|
||||
KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
@@ -1597,7 +1608,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
CElementwiseOperation c_element_op,
|
||||
const index_t k_id,
|
||||
const index_t Kt)
|
||||
{
|
||||
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
|
||||
Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
@@ -1611,7 +1624,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map);
|
||||
block_2_ctile_map,
|
||||
k_id,
|
||||
Kt);
|
||||
}
|
||||
|
||||
template <typename Block2CTileMap,
|
||||
@@ -1628,11 +1643,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
const index_t k_id,
|
||||
const index_t Kt)
|
||||
{
|
||||
ignore = b_element_op;
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
|
||||
|
||||
@@ -1731,7 +1748,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
true>(b_grid_desc_bpreshuffled,
|
||||
make_multi_index(n_block_data_idx_on_grid,
|
||||
get_warp_local_1d_id() % NWave,
|
||||
0,
|
||||
k_id,
|
||||
KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
|
||||
@@ -10,6 +10,26 @@
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#if !defined(CK_TILE_HAS_ROW_NEWBCAST)
|
||||
// row_newbcast (DPP modifier 0x157) support by architecture:
|
||||
// - Not supported: gfx908 (MI100) and older
|
||||
// - Supported: gfx90a (MI200), gfx94x (MI300), and all RDNA architectures
|
||||
|
||||
#if defined(__HIP_DEVICE_COMPILE__) && defined(__HIP_PLATFORM_AMD__)
|
||||
#if defined(__gfx908__) || defined(__gfx906__) || defined(__gfx900__)
|
||||
// Explicitly disable for known unsupported architectures
|
||||
#define CK_TILE_HAS_ROW_NEWBCAST 0
|
||||
#else
|
||||
// Assume support for gfx90a and newer (including all gfx94x and RDNA)
|
||||
// This is safer as new architectures typically maintain backward compatibility
|
||||
#define CK_TILE_HAS_ROW_NEWBCAST 1
|
||||
#endif
|
||||
#else
|
||||
// Conservative default for non-AMD or host compilation
|
||||
#define CK_TILE_HAS_ROW_NEWBCAST 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
|
||||
@@ -380,18 +400,7 @@ struct MoeSortingKernel
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:8
|
||||
#if 0
|
||||
constexpr int bank_mask_0_7 = 0b1100;
|
||||
auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
|
||||
thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_update_dpp(0, /* old value */
|
||||
__builtin_bit_cast(int, thread_data),
|
||||
0x157,
|
||||
row_mask,
|
||||
bank_mask_0_7,
|
||||
bound_ctrl))// row_newbcast:7
|
||||
);
|
||||
#else
|
||||
#if CK_TILE_HAS_ROW_NEWBCAST
|
||||
data_t xxx =__builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
0x157,
|
||||
@@ -401,6 +410,17 @@ struct MoeSortingKernel
|
||||
|
||||
data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
|
||||
thread_data = thread_data - yyy;
|
||||
#else
|
||||
// portable fallback for gfx908 and older: emulate row_newbcast:7 via ds_bpermute
|
||||
// For wave_size == 8 context, we need to broadcast from lane 7 of the 16-lane group
|
||||
int broadcast_src_lane = (__lane_id() & ~15) + 7; // Lane 7 of the 16-lane group
|
||||
int broadcast_addr = broadcast_src_lane << 2; // Convert to byte address
|
||||
int bcast7 = __builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(int, thread_data));
|
||||
|
||||
// Apply subtraction only to odd 8-lane groups (lanes 8-15 of each 16-lane unit)
|
||||
if ((__lane_id() / 8) % 2 != 0) { // Note: != 0, not == 0
|
||||
thread_data = thread_data - __builtin_bit_cast(data_t, bcast7);
|
||||
}
|
||||
#endif
|
||||
|
||||
}
|
||||
@@ -1267,18 +1287,7 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data)
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:8
|
||||
#if 0
|
||||
constexpr int bank_mask_0_7 = 0b1100;
|
||||
auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
|
||||
thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_update_dpp(0, /* old value */
|
||||
__builtin_bit_cast(int, thread_data),
|
||||
0x157,
|
||||
row_mask,
|
||||
bank_mask_0_7,
|
||||
bound_ctrl))// row_newbcast:7
|
||||
);
|
||||
#else
|
||||
#if CK_TILE_HAS_ROW_NEWBCAST
|
||||
data_t xxx =
|
||||
__builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
@@ -1289,6 +1298,19 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data)
|
||||
|
||||
data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
|
||||
thread_data = thread_data - yyy;
|
||||
#else
|
||||
// portable fallback for gfx908 and older: emulate row_newbcast:7 via ds_bpermute
|
||||
// For wave_size == 8 context, we need to broadcast from lane 7 of the 16-lane group
|
||||
int broadcast_src_lane = (__lane_id() & ~15) + 7; // Lane 7 of the 16-lane group
|
||||
int broadcast_addr = broadcast_src_lane << 2; // Convert to byte address
|
||||
int bcast7 =
|
||||
__builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(int, thread_data));
|
||||
|
||||
// Apply subtraction only to odd 8-lane groups (lanes 8-15 of each 16-lane unit)
|
||||
if((__lane_id() / 8) % 2 != 0)
|
||||
{ // Note: != 0, not == 0
|
||||
thread_data = thread_data - __builtin_bit_cast(data_t, bcast7);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(wave_size > 8)
|
||||
|
||||
@@ -55,6 +55,7 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
|
||||
@@ -25,6 +25,10 @@ struct BaseGemmPipelineAgBgCrCompAsync
|
||||
|
||||
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop == 1)
|
||||
{
|
||||
return TailNumber::One;
|
||||
}
|
||||
if(num_loop % PrefetchStages == 1)
|
||||
{
|
||||
return TailNumber::Three;
|
||||
@@ -65,6 +69,11 @@ struct BaseGemmPipelineAgBgCrCompAsync
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return (run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::One>{}));
|
||||
}
|
||||
}
|
||||
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
@@ -485,7 +494,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
}
|
||||
}
|
||||
else
|
||||
else if(TailNum == TailNumber::Two)
|
||||
// 2 block gemms remaining
|
||||
{
|
||||
{
|
||||
@@ -500,6 +509,12 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Prob
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
}
|
||||
}
|
||||
else if(TailNum == TailNumber::One)
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -27,6 +27,10 @@ struct BaseGemmPipelineAgBgCrCompV4
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop == 1)
|
||||
{
|
||||
return TailNumber::One;
|
||||
}
|
||||
if(num_loop % PrefetchStages == 1)
|
||||
{
|
||||
return TailNumber::Three;
|
||||
@@ -67,6 +71,11 @@ struct BaseGemmPipelineAgBgCrCompV4
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return (run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::One>{}));
|
||||
}
|
||||
}
|
||||
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
@@ -621,7 +630,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
else
|
||||
else if(TailNum == TailNumber::Two)
|
||||
{
|
||||
// 2
|
||||
{
|
||||
@@ -641,6 +650,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
else if(TailNum == TailNumber::One)
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
21
include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp
Normal file
21
include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp
Normal file
@@ -0,0 +1,21 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum struct GemmPipeline
|
||||
{
|
||||
COMPUTE_ASYNC,
|
||||
COMPUTE_V3,
|
||||
COMPUTE_V4,
|
||||
COMPUTE_V5,
|
||||
COMPUTE_V6,
|
||||
MEMORY,
|
||||
BASIC_V1,
|
||||
BASIC_V2,
|
||||
PRESHUFFLE_V2
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -434,14 +434,13 @@ struct GroupedConvFwdKernelArgs
|
||||
/// multiplication implementation. It is responsible for storing
|
||||
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
|
||||
/// the output C tensor in global memory.
|
||||
template <bool EnableSplitImage_,
|
||||
typename GroupedConvTraitsType_,
|
||||
template <typename GroupedConvTraitsType_,
|
||||
typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionForwardKernel
|
||||
{
|
||||
static constexpr bool EnableSplitImage = EnableSplitImage_;
|
||||
static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType_::ConvSpecialization;
|
||||
|
||||
@@ -63,7 +63,8 @@ template <index_t NDimSpatial_,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1,
|
||||
index_t NumGroupsToMerge_ = 1,
|
||||
typename CDElementwise_ = PassThrough>
|
||||
typename CDElementwise_ = PassThrough,
|
||||
bool EnableSplitImage_ = false>
|
||||
struct GroupedConvTraits
|
||||
{
|
||||
private:
|
||||
@@ -74,6 +75,7 @@ struct GroupedConvTraits
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr bool EnableSplitImage = EnableSplitImage_;
|
||||
static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_;
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_;
|
||||
|
||||
103
profiler/include/profiler/common.hpp
Normal file
103
profiler/include/profiler/common.hpp
Normal file
@@ -0,0 +1,103 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename DataType, typename ComputeDataType = DataType>
|
||||
inline __host__ __device__ constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 1e-1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 1.5e-1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType, typename ComputeDataType = DataType>
|
||||
inline __host__ __device__ constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 16.1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 8192.1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
@@ -69,19 +69,19 @@ template <typename A0DataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout>
|
||||
bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideE,
|
||||
int n_warmup,
|
||||
int n_iter,
|
||||
uint64_t rotating = 0)
|
||||
bool profile_gemm_blockscale_weightpreshuffle_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideE,
|
||||
int n_warmup,
|
||||
int n_iter,
|
||||
uint64_t rotating = 0)
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
@@ -126,6 +126,26 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification,
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
// Update strides based on tensor properties if they are <= 0
|
||||
auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t {
|
||||
if(current_stride <= 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return tensor.GetStrides()[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return tensor.GetStrides()[1];
|
||||
}
|
||||
}
|
||||
return current_stride;
|
||||
};
|
||||
|
||||
StrideA = get_stride(a0_m_k, ALayout{}, StrideA);
|
||||
StrideB = get_stride(b0_k_n, BLayout{}, StrideB);
|
||||
StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE);
|
||||
|
||||
int total_gemm_needed =
|
||||
a0_m_k.GetElementSpaceSizeInBytes() + b0_k_n.GetElementSpaceSizeInBytes() +
|
||||
a1_m_k.GetElementSpaceSizeInBytes() + b1_k_n.GetElementSpaceSizeInBytes();
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "profiler/common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -112,6 +113,28 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
// Update strides based on tensor properties if they are <= 0
|
||||
auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t {
|
||||
if(current_stride <= 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return tensor.GetStrides()[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return tensor.GetStrides()[1];
|
||||
}
|
||||
}
|
||||
return current_stride;
|
||||
};
|
||||
|
||||
StrideA = get_stride(a_m_k, ALayout{}, StrideA);
|
||||
StrideB = get_stride(b_k_n, BLayout{}, StrideB);
|
||||
StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD0);
|
||||
StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD1);
|
||||
StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE);
|
||||
|
||||
int total_gemm_needed =
|
||||
a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() +
|
||||
d0_m_n.GetElementSpaceSizeInBytes() + d1_m_n.GetElementSpaceSizeInBytes();
|
||||
@@ -133,7 +156,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-1, 1});
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-1, 1});
|
||||
break;
|
||||
default:
|
||||
@@ -282,8 +305,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
|
||||
is_same_v<EDataType, int8_t>))
|
||||
{
|
||||
std::string msg = "Error: Incorrect results!";
|
||||
double rtol = 1e-3;
|
||||
double atol = 5e-2;
|
||||
double rtol = get_rtol<EDataType>();
|
||||
double atol = get_atol<EDataType>();
|
||||
pass = pass & ck::utils::check_err(
|
||||
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "profiler/common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -99,6 +100,26 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification,
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
// Update strides based on tensor properties if they are <= 0
|
||||
auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t {
|
||||
if(current_stride <= 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout), tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return tensor.GetStrides()[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
return tensor.GetStrides()[1];
|
||||
}
|
||||
}
|
||||
return current_stride;
|
||||
};
|
||||
|
||||
StrideA = get_stride(a_m_k, ALayout{}, StrideA);
|
||||
StrideB = get_stride(b_k_n, BLayout{}, StrideB);
|
||||
StrideC = get_stride(c_m_n_host_result, CLayout{}, StrideC);
|
||||
|
||||
std::size_t total_gemm_needed =
|
||||
a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes();
|
||||
int rotating_count = std::max(
|
||||
@@ -317,8 +338,8 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification,
|
||||
is_same_v<CDataType, f8_t>)
|
||||
{
|
||||
std::string msg = "Error: Incorrect results!";
|
||||
double rtol = 1e-1;
|
||||
double atol = 1e-1;
|
||||
double rtol = get_rtol<CDataType>();
|
||||
double atol = get_atol<CDataType>();
|
||||
pass = pass & ck::utils::check_err(
|
||||
c_m_n_device_result, c_m_n_host_result, msg, rtol, atol);
|
||||
}
|
||||
|
||||
@@ -5,92 +5,11 @@
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "profiler/common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename DataType>
|
||||
inline constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 1e-1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 1.5e-1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
inline constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, double>)
|
||||
{
|
||||
return 1e-6;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::half_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
|
||||
{
|
||||
return 5e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int32_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, int8_t>)
|
||||
{
|
||||
return 1e-1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
|
||||
{
|
||||
return 16.1; // 240 and 224 are acceptable
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
|
||||
{
|
||||
return 8192.1; // 57344 and 49152 are acceptable
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
|
||||
@@ -126,19 +126,19 @@ int profile_gemm_blockscale_weighpreshuffle(int argc, char* argv[])
|
||||
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
|
||||
const int DefaultStrideE = ck::is_same_v<ELayout, Row> ? N : M;
|
||||
|
||||
bool pass = ck::profiler::profile_gemm_blockscale_weighpreshuffle_impl<A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
scale_block_m,
|
||||
scale_block_n,
|
||||
scale_block_k,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>(
|
||||
bool pass = ck::profiler::profile_gemm_blockscale_weightpreshuffle_impl<A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
scale_block_m,
|
||||
scale_block_n,
|
||||
scale_block_k,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>(
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
|
||||
@@ -245,10 +245,13 @@ add_subdirectory(conv_util)
|
||||
add_subdirectory(reference_conv_fwd)
|
||||
add_subdirectory(gemm)
|
||||
add_subdirectory(gemm_add)
|
||||
add_subdirectory(gemm_blockscale_wp)
|
||||
add_subdirectory(gemm_layernorm)
|
||||
add_subdirectory(gemm_multi_abd)
|
||||
add_subdirectory(gemm_multiply_multiply_wp)
|
||||
add_subdirectory(gemm_split_k)
|
||||
add_subdirectory(gemm_universal)
|
||||
add_subdirectory(gemm_universal_preshuffle)
|
||||
add_subdirectory(gemm_b_scale)
|
||||
add_subdirectory(gemm_universal_streamk)
|
||||
add_subdirectory(gemm_reduce)
|
||||
|
||||
@@ -10,11 +10,6 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
||||
|
||||
class ArgumentsNotSupportedException : public std::logic_error
|
||||
{
|
||||
public:
|
||||
@@ -56,7 +51,7 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
};
|
||||
|
||||
@@ -76,9 +71,9 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -96,8 +91,8 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -116,8 +111,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -135,8 +130,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -154,8 +149,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
@@ -177,8 +172,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -196,8 +191,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -216,7 +211,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
};
|
||||
|
||||
@@ -235,8 +230,8 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
@@ -401,11 +396,11 @@ struct DataTypeTraits<ck_tile::int8_t>
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
@@ -414,7 +409,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
@@ -423,7 +418,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
@@ -432,7 +427,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V5>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
|
||||
@@ -32,43 +32,35 @@ function(create_individual_gemm_test_target datatype layout config_name trait ti
|
||||
set(target_name "test_gemm_tile_engine_${datatype}_${layout}_${config_name}_${trait}_${tile_config}")
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}")
|
||||
|
||||
# Generated header path for this specific kernel configuration
|
||||
# Generated header path (already created during cmake configuration)
|
||||
set(test_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
|
||||
set(test_params_header "${working_path}/test_params.hpp")
|
||||
|
||||
# Generate kernel header using tile_engine's Python script
|
||||
add_custom_command(
|
||||
OUTPUT ${test_header}
|
||||
COMMAND ${Python3_EXECUTABLE} ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--gpu_target "${GEMM_TEST_GPU_TARGETS}"
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${config_json}
|
||||
--gen_single
|
||||
--kernel_name "test_gemm_${datatype}_${layout}_${trait}_${tile_config}"
|
||||
--tile_config "${tile_config}"
|
||||
--trait_combo "${trait}"
|
||||
DEPENDS ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py ${config_json}
|
||||
COMMENT "Generating test header ${test_header}"
|
||||
VERBATIM
|
||||
)
|
||||
# Verify header exists (should have been generated during cmake configuration)
|
||||
if(NOT EXISTS ${test_header})
|
||||
message(WARNING "Generated header not found: ${test_header}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Verify test parameters header exists
|
||||
if(NOT EXISTS ${test_params_header})
|
||||
message(WARNING "Test parameters header not found: ${test_params_header}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
|
||||
# Create GTest executable for this kernel configuration
|
||||
add_gtest_executable(${target_name}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_simple.cpp
|
||||
)
|
||||
|
||||
# Ensure header is generated before compilation
|
||||
set(header_target "${target_name}_header")
|
||||
add_custom_target(${header_target} DEPENDS ${test_header})
|
||||
add_dependencies(${target_name} ${header_target})
|
||||
|
||||
# Configure GPU architectures for HIP compilation
|
||||
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_TEST_GPU_TARGETS})
|
||||
|
||||
# Define preprocessor macros for generated header location
|
||||
# Define preprocessor macros for generated header location and test parameters
|
||||
target_compile_definitions(${target_name} PRIVATE
|
||||
GEMM_SINGLE_INSTANCE_HPP="${test_header}"
|
||||
GEMM_TEST_PARAMS_HPP="${test_params_header}"
|
||||
)
|
||||
|
||||
# Include directories for headers and dependencies
|
||||
@@ -87,6 +79,11 @@ function(create_individual_gemm_test_target datatype layout config_name trait ti
|
||||
-include ${test_header} # Auto-include generated header
|
||||
)
|
||||
|
||||
# Add FP8 format definitions for proper data type interpretation
|
||||
if(CK_USE_OCP_FP8)
|
||||
target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
message(STATUS " Created test target: ${target_name}")
|
||||
endfunction()
|
||||
|
||||
@@ -107,7 +104,6 @@ function(build_gemm_test_targets datatype layout config_name)
|
||||
# Locate and validate configuration file
|
||||
set(config_filename "${config_name}.json")
|
||||
set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}")
|
||||
message(STATUS " Using test config: ${config_filename}")
|
||||
|
||||
if(NOT EXISTS ${json_blob})
|
||||
message(WARNING "Test config file not found: ${json_blob}")
|
||||
@@ -118,7 +114,6 @@ function(build_gemm_test_targets datatype layout config_name)
|
||||
file(MAKE_DIRECTORY ${working_path})
|
||||
|
||||
# STEP 1: Discovery phase - list all valid kernel configurations
|
||||
message(STATUS " Listing kernel configurations for ${datatype}_${layout}...")
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
@@ -134,32 +129,90 @@ function(build_gemm_test_targets datatype layout config_name)
|
||||
)
|
||||
|
||||
if(NOT ret EQUAL 0)
|
||||
message(WARNING "Failed to list kernels for ${datatype}_${layout}: ${list_error}")
|
||||
message(WARNING "Failed to list kernels for ${datatype}_${layout}_${config_name}: ${list_error}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Validate kernel discovery results
|
||||
if(EXISTS ${working_path}/gemm_kernel_count.txt)
|
||||
file(READ ${working_path}/gemm_kernel_count.txt kernel_count)
|
||||
string(STRIP "${kernel_count}" kernel_count)
|
||||
message(STATUS " Found ${kernel_count} test configurations for ${datatype}_${layout}")
|
||||
else()
|
||||
message(WARNING "Kernel count file not found for ${datatype}_${layout}")
|
||||
# Verify kernel list file was generated
|
||||
if(NOT EXISTS ${working_path}/gemm_kernel_list.txt)
|
||||
message(STATUS "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# STEP 2: Generation phase - create test targets for each discovered kernel
|
||||
if(EXISTS ${working_path}/gemm_kernel_list.txt)
|
||||
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
|
||||
set(test_count 0)
|
||||
foreach(line IN LISTS kernel_lines)
|
||||
# Parse kernel specification format: kernel_name|tile_config|trait_combo
|
||||
string(REPLACE "|" ";" parts "${line}")
|
||||
list(LENGTH parts parts_len)
|
||||
if(parts_len EQUAL 3)
|
||||
list(GET parts 0 kernel_name)
|
||||
list(GET parts 1 tile_config)
|
||||
list(GET parts 2 trait_combo)
|
||||
message(STATUS "Building tests for ${datatype}_${layout}_${config_name}")
|
||||
|
||||
# STEP 2a: Extract test parameters from config
|
||||
set(test_params_file "${working_path}/test_params.hpp")
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/extract_test_params.py
|
||||
--config_file ${json_blob}
|
||||
--output_file ${test_params_file}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
RESULT_VARIABLE extract_ret
|
||||
OUTPUT_VARIABLE extract_output
|
||||
ERROR_VARIABLE extract_error
|
||||
)
|
||||
|
||||
if(NOT extract_ret EQUAL 0)
|
||||
message(WARNING "Failed to extract test parameters for ${datatype}_${layout}: ${extract_error}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# STEP 2b: Header generation phase - generate headers using --gen_single
|
||||
message(STATUS " Generating headers using --gen_single...")
|
||||
|
||||
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
|
||||
set(gen_count 0)
|
||||
|
||||
foreach(line IN LISTS kernel_lines)
|
||||
# Parse kernel specification format: kernel_name|tile_config|trait_combo
|
||||
string(REPLACE "|" ";" parts "${line}")
|
||||
list(LENGTH parts parts_len)
|
||||
if(parts_len EQUAL 3)
|
||||
list(GET parts 0 kernel_name)
|
||||
list(GET parts 1 tile_config)
|
||||
list(GET parts 2 trait_combo)
|
||||
|
||||
# Generate header using --gen_single
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--gpu_target "${GEMM_TEST_GPU_TARGETS}"
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--gen_single
|
||||
--kernel_name "${kernel_name}"
|
||||
--tile_config "${tile_config}"
|
||||
--trait_combo "${trait_combo}"
|
||||
WORKING_DIRECTORY ${TILE_ENGINE_GEMM_DIR}
|
||||
RESULT_VARIABLE gen_ret
|
||||
OUTPUT_VARIABLE gen_output
|
||||
ERROR_VARIABLE gen_error
|
||||
)
|
||||
|
||||
if(NOT gen_ret EQUAL 0)
|
||||
message(WARNING "Failed to generate header for ${kernel_name}: ${gen_error}")
|
||||
else()
|
||||
math(EXPR gen_count "${gen_count} + 1")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
message(STATUS " Generated ${gen_count} headers for ${datatype}_${layout}")
|
||||
|
||||
# STEP 3: Target creation phase - create test targets
|
||||
message(STATUS " Creating test targets...")
|
||||
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
|
||||
set(test_count 0)
|
||||
foreach(line IN LISTS kernel_lines)
|
||||
# Parse kernel specification format: kernel_name|tile_config|trait_combo
|
||||
string(REPLACE "|" ";" parts "${line}")
|
||||
list(LENGTH parts parts_len)
|
||||
if(parts_len EQUAL 3)
|
||||
list(GET parts 0 kernel_name)
|
||||
list(GET parts 1 tile_config)
|
||||
list(GET parts 2 trait_combo)
|
||||
|
||||
# Generate test target for this kernel configuration
|
||||
create_individual_gemm_test_target("${datatype}" "${layout}" "${config_name}" "${trait_combo}" "${tile_config}" "${json_blob}")
|
||||
@@ -167,12 +220,7 @@ function(build_gemm_test_targets datatype layout config_name)
|
||||
endif()
|
||||
endforeach()
|
||||
message(STATUS " Created ${test_count} test targets for ${datatype}_${layout}")
|
||||
else()
|
||||
message(WARNING "Kernel list file not found for ${datatype}_${layout}")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# ============================================================================
|
||||
endfunction()# ============================================================================
|
||||
# MAIN EXECUTION - Test Target Generation
|
||||
# ============================================================================
|
||||
|
||||
@@ -198,42 +246,100 @@ endif()
|
||||
|
||||
message(STATUS "Building GEMM tile engine tests for GPU targets: ${GEMM_TEST_GPU_TARGETS}")
|
||||
|
||||
# ============================================================================
|
||||
# Test Configuration Matrix
|
||||
# ============================================================================
|
||||
# Enable parallel compilation optimizations
|
||||
# Set up job pools for better parallel compilation control
|
||||
set_property(GLOBAL PROPERTY JOB_POOLS
|
||||
compile_heavy=4 # Limit heavy compilations to prevent OOM
|
||||
compile_normal=16 # Allow more parallel normal compilations
|
||||
)
|
||||
|
||||
# Available test configurations (minimal set for fast CI/testing)
|
||||
set(TEST_CONFIGS
|
||||
"simple_test_config"
|
||||
# "medium_tiles_config" # Uncomment for broader testing
|
||||
)
|
||||
|
||||
# Data types for testing (core precision types)
|
||||
set(TEST_DATATYPES "fp16" "bf16")
|
||||
# Extended data type options:
|
||||
# set(TEST_DATATYPES "fp16" "bf16" "fp32" "fp64" "int8")
|
||||
|
||||
# Matrix layouts for testing (row-column-row is most common)
|
||||
set(TEST_LAYOUTS "rcr")
|
||||
# Extended layout options:
|
||||
# set(TEST_LAYOUTS "rcr" "rrr" "ccr" "crr")
|
||||
# Enable compiler cache if available and explicitly requested
|
||||
# Disabled by default due to permission issues in CI environments
|
||||
option(ENABLE_CCACHE_TESTS "Enable ccache for test compilation" OFF)
|
||||
if(ENABLE_CCACHE_TESTS)
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
|
||||
message(STATUS "Using ccache for faster test compilation")
|
||||
else()
|
||||
message(WARNING "ccache requested but not found")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "ccache disabled for tests (use -DENABLE_CCACHE_TESTS=ON to enable)")
|
||||
endif()
|
||||
|
||||
# ============================================================================
|
||||
# Test Target Generation Loop
|
||||
# Test Configuration Matrix - Clean Focused Design
|
||||
# ============================================================================
|
||||
|
||||
foreach(datatype IN LISTS TEST_DATATYPES)
|
||||
foreach(layout IN LISTS TEST_LAYOUTS)
|
||||
foreach(config IN LISTS TEST_CONFIGS)
|
||||
set(CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config}.json")
|
||||
if(EXISTS ${CONFIG_FILE})
|
||||
message(STATUS "Building tests for ${datatype}_${layout}_${config}")
|
||||
build_gemm_test_targets("${datatype}" "${layout}" "${config}")
|
||||
else()
|
||||
message(WARNING "Config file not found: ${CONFIG_FILE}")
|
||||
endif()
|
||||
# All supported data types and layouts for comprehensive testing
|
||||
# Note: fp64 not included (no MFMA hardware support)
|
||||
set(TEST_DATATYPES "fp16;fp8;bf16;fp32")
|
||||
set(TEST_LAYOUTS "rcr;rrr;ccr;crr")
|
||||
|
||||
# ============================================================================
|
||||
# Test Target Generation - Datatype-Specific Categories
|
||||
# ============================================================================
|
||||
|
||||
# 1. SMALL DATATYPES: Test optimized config for small data types (fp8, fp16, bf16)
|
||||
# These data types can use larger warp tiles due to smaller memory footprint
|
||||
set(SMALL_DATATYPE_CONFIG "small_datatype_config")
|
||||
set(SMALL_DATATYPE_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${SMALL_DATATYPE_CONFIG}.json")
|
||||
set(SMALL_DATATYPES "fp8;fp16;bf16")
|
||||
|
||||
if(EXISTS ${SMALL_DATATYPE_CONFIG_FILE})
|
||||
message(STATUS "Processing small datatype config: ${SMALL_DATATYPE_CONFIG} (fp8, fp16, bf16)")
|
||||
foreach(datatype IN LISTS SMALL_DATATYPES)
|
||||
# fp8, fp16, bf16: testing all layouts (rcr, rrr, ccr, crr)
|
||||
foreach(layout IN LISTS TEST_LAYOUTS)
|
||||
build_gemm_test_targets("${datatype}" "${layout}" "${SMALL_DATATYPE_CONFIG}")
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
else()
|
||||
message(WARNING "Small datatype config file not found: ${SMALL_DATATYPE_CONFIG_FILE}")
|
||||
endif()
|
||||
|
||||
message(STATUS "GEMM tile engine tests configured for ${TEST_DATATYPES} with ${TEST_LAYOUTS} layouts using ${TEST_CONFIGS} configurations")
|
||||
# 2. PADDING COVERAGE: Test padding combinations with fixed fp16/rcr configuration
|
||||
# This focuses on padding behavior (pad_m, pad_n, pad_k)
|
||||
set(PADDING_CONFIG "padding_coverage_config")
|
||||
set(PADDING_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${PADDING_CONFIG}.json")
|
||||
|
||||
if(EXISTS ${PADDING_CONFIG_FILE})
|
||||
message(STATUS "Processing padding config: ${PADDING_CONFIG} (fp16/rcr only)")
|
||||
build_gemm_test_targets("fp16" "rcr" "${PADDING_CONFIG}")
|
||||
else()
|
||||
message(WARNING "Padding config file not found: ${PADDING_CONFIG_FILE}")
|
||||
endif()
|
||||
|
||||
# 3. COVERAGE LEVEL: Quick or comprehensive testing
|
||||
# Quick: ~144 kernels with multiple tile sizes and trait combinations
|
||||
# Comprehensive: Several thousand kernels with extensive tile sizes, warp configurations, and all trait combinations
|
||||
set(COVERAGE_LEVEL "quick" CACHE STRING "Coverage level: quick or comprehensive")
|
||||
set_property(CACHE COVERAGE_LEVEL PROPERTY STRINGS "quick" "comprehensive")
|
||||
|
||||
if(COVERAGE_LEVEL STREQUAL "quick")
|
||||
set(COVERAGE_CONFIG "quick_coverage_config")
|
||||
set(COVERAGE_DESC "Quick - approximately 144 kernels with trait combinations")
|
||||
elseif(COVERAGE_LEVEL STREQUAL "comprehensive")
|
||||
set(COVERAGE_CONFIG "comprehensive_coverage_config")
|
||||
set(COVERAGE_DESC "Comprehensive - several thousand kernels with extensive tile and trait coverage")
|
||||
else()
|
||||
message(FATAL_ERROR "Invalid COVERAGE_LEVEL: ${COVERAGE_LEVEL}. Must be 'quick' or 'comprehensive'")
|
||||
endif()
|
||||
|
||||
set(COVERAGE_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${COVERAGE_CONFIG}.json")
|
||||
|
||||
if(EXISTS ${COVERAGE_CONFIG_FILE})
|
||||
message(STATUS "Processing coverage config: ${COVERAGE_LEVEL} - ${COVERAGE_DESC}")
|
||||
build_gemm_test_targets("fp16" "rcr" "${COVERAGE_CONFIG}")
|
||||
else()
|
||||
message(WARNING "Coverage config file not found: ${COVERAGE_CONFIG_FILE}")
|
||||
endif()
|
||||
# ============================================================================
|
||||
|
||||
|
||||
message(STATUS "GEMM tile engine tests configured with datatype-specific design:")
|
||||
message(STATUS " - Small datatypes: fp8/fp16/bf16 (all layouts)")
|
||||
message(STATUS " - Padding coverage with fp16/rcr")
|
||||
message(STATUS " - Coverage level: ${COVERAGE_LEVEL} (~144 kernels quick, several thousand comprehensive)")
|
||||
message(STATUS " Use -DCOVERAGE_LEVEL=comprehensive for extensive testing")
|
||||
|
||||
@@ -17,11 +17,69 @@ JSON Config → tile_engine Python scripts → Generated Headers → Test Execut
|
||||
```
|
||||
|
||||
- **`--list_kernels`**: Get available kernel configurations from JSON
|
||||
- **`--gen_individual`**: Generate all kernel headers in parallel during CMake configuration
|
||||
- **`--gen_single`**: Generate individual kernel header for each configuration
|
||||
- **Same verification**: Uses tile_engine's adaptive error thresholds and reference calculations
|
||||
- **Same patterns**: Follows tile_engine's tensor initialization, stride calculation, and kernel launching
|
||||
|
||||
### Config-Specific Test Parameters
|
||||
|
||||
Each test configuration can specify optimized problem sizes in its JSON file:
|
||||
- **`test_params.problem_sizes`**: Array of `{m, n, k, split_k}` configurations
|
||||
- **CMake extraction**: `extract_test_params.py` generates config-specific test parameter files
|
||||
- **Build integration**: Each test target uses parameters appropriate for its kernel configuration
|
||||
- **Optimized testing**: Different configs test different problem sizes that showcase their strengths
|
||||
|
||||
|
||||
The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure.
|
||||
|
||||
## Test Configurations
|
||||
|
||||
### 1. **Simple Test** (`simple_test_config.json`)
|
||||
- **Purpose**: Basic functionality validation
|
||||
- **Config**: 128x128x64, warp 2x2x1, warp_tile 16x16x16
|
||||
- **Traits**: compv3 + compv4 pipelines
|
||||
- **Coverage**: ~2 kernels per datatype/layout
|
||||
|
||||
### 2. **Small Datatype** (`small_datatype_config.json`)
|
||||
- **Purpose**: Optimized for fp8/fp16/bf16 data types
|
||||
- **Config**: 128x128x32, warp 2x2x1, warp_tile 32x32x16
|
||||
- **Traits**: compv3 pipeline only
|
||||
- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr) for fp8, fp16, bf16
|
||||
|
||||
### 3. **Padding Coverage** (`padding_coverage_config.json`)
|
||||
- **Purpose**: Test padding behavior with all padding flags enabled
|
||||
- **Config**: Fixed 64x64x32, warp 2x2x1, warp_tile 32x32x16
|
||||
- **Padding**: All enabled (pad_m=true, pad_n=true, pad_k=true)
|
||||
- **Problem sizes**: Vector-aligned but not tile-aligned (104×104×56, 200×152×80, 152×200×64)
|
||||
- **Coverage**: 1 kernel configuration testing padding with irregular sizes
|
||||
|
||||
### 4. **Coverage Testing** (Quick or Comprehensive)
|
||||
- **Purpose**: Comprehensive testing across tile sizes, warp configurations, and trait combinations
|
||||
- **Quick** (`quick_coverage_config.json`): Approximately 144 kernels
|
||||
- tile_m/n: [32, 64, 256], tile_k: [16, 32]
|
||||
- warp config: 2×2×1, warp_tile 16×16×16
|
||||
- Traits: 3 pipelines × 2 epilogues × 2 schedulers (persistent=false only)
|
||||
- Focused set testing trait combinations with multiple tile sizes
|
||||
- **Comprehensive** (`comprehensive_coverage_config.json`): Several thousand kernels
|
||||
- tile_m/n: [16-256 step 16]
|
||||
- tile_k: [16, 32, 64]
|
||||
- warp_m/n: [1, 2, 4], warp_tile_m/n: [16, 32], warp_tile_k: [16, 32]
|
||||
- Traits: 3 pipelines × 2 epilogues × 2 schedulers × 2 persistent
|
||||
- Extensive coverage across all tile sizes, warp configurations, and trait combinations
|
||||
- Exact count varies based on validation filtering
|
||||
- **Note**: Use CMake option `-DCOVERAGE_LEVEL=comprehensive` to enable comprehensive testing (default is quick)
|
||||
|
||||
## Data Type Support
|
||||
- ✅ **fp8, fp16, bf16**: Fully supported - all layouts (rcr, rrr, ccr, crr)
|
||||
- ❌ **fp64**: Not supported (hardware MFMA limitation)
|
||||
- ⏳ **fp32, bf8, pk-int4-t**: Not yet supported by gemm_instance_builder (will be added later)
|
||||
|
||||
## Test Result Behavior
|
||||
|
||||
Tests automatically handle unsupported configurations through runtime validation:
|
||||
- **PASSED**: Kernel executed correctly with results within error thresholds ✅
|
||||
- **SKIPPED**: Kernel validation returned "Arguments not supported" (expected for certain problem sizes/configurations) ⚠️
|
||||
- **FAILED**: Actual error or incorrect computation results ❌
|
||||
|
||||
When a kernel's `IsSupportedArgument()` check fails (e.g., due to vector alignment requirements, dimension constraints, or padding limitations), the test is automatically skipped rather than failed. This allows comprehensive testing across various problem sizes while gracefully handling configurations that don't meet specific kernel requirements.
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"problem": {
|
||||
"description": "Comprehensive coverage testing - extensive tile size coverage (16-256, step 16) with multiple warp configurations and all trait combinations. Several thousand kernels."
|
||||
},
|
||||
"test_params": {
|
||||
"problem_sizes": [
|
||||
{"m": 512, "n": 512, "k": 256, "split_k": 1},
|
||||
{"m": 1024, "n": 512, "k": 512, "split_k": 1},
|
||||
{"m": 512, "n": 1024, "k": 512, "split_k": 1},
|
||||
{"m": 1024, "n": 1024, "k": 256, "split_k": 1},
|
||||
{"m": 1024, "n": 1024, "k": 256, "split_k": 2},
|
||||
{"m": 1024, "n": 1024, "k": 256, "split_k": 4}
|
||||
]
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {"values": [16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256]},
|
||||
"tile_n": {"values": [16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256]},
|
||||
"tile_k": {"values": [16, 32, 64]},
|
||||
"warp_m": {"values": [1, 2, 4]},
|
||||
"warp_n": {"values": [1, 2, 4]},
|
||||
"warp_k": {"values": [1]},
|
||||
"warp_tile_m": {"values": [16, 32]},
|
||||
"warp_tile_n": {"values": [16, 32]},
|
||||
"warp_tile_k": {"values": [8, 16, 32, 64, 128]}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {"values": ["mem", "compv3", "compv4"]},
|
||||
"epilogue": {"values": ["default", "cshuffle"]},
|
||||
"scheduler": {"values": ["intrawave", "interwave"]},
|
||||
"pad_m": {"values": [false]},
|
||||
"pad_n": {"values": [false]},
|
||||
"pad_k": {"values": [false]},
|
||||
"persistent": {"values": [true, false]}
|
||||
},
|
||||
"k_block_per_cu": 1,
|
||||
"permute_n": false
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"problem": {
|
||||
"description": "Configuration optimized for large data types (fp32) with smaller warp tiles due to memory constraints"
|
||||
},
|
||||
"test_params": {
|
||||
"problem_sizes": [
|
||||
{"m": 512, "n": 512, "k": 128, "split_k": 1},
|
||||
{"m": 512, "n": 256, "k": 192, "split_k": 1},
|
||||
{"m": 256, "n": 384, "k": 192, "split_k": 1}
|
||||
]
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {"values": [256]},
|
||||
"tile_n": {"values": [128]},
|
||||
"tile_k": {"values": [32]},
|
||||
"warp_m": {"values": [2]},
|
||||
"warp_n": {"values": [2]},
|
||||
"warp_k": {"values": [1]},
|
||||
"warp_tile_m": {"values": [16]},
|
||||
"warp_tile_n": {"values": [16]},
|
||||
"warp_tile_k": {"values": [16]}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {"values": ["compv3"]},
|
||||
"epilogue": {"values": ["default"]},
|
||||
"scheduler": {"values": ["intrawave"]},
|
||||
"pad_m": {"values": [false]},
|
||||
"pad_n": {"values": [false]},
|
||||
"pad_k": {"values": [false]},
|
||||
"persistent": {"values": [false]}
|
||||
},
|
||||
"k_block_per_cu": 1,
|
||||
"permute_n": false
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"problem": {
|
||||
"description": "Padding coverage testing - fixed config with fp16/rcr, varying only padding combinations"
|
||||
},
|
||||
"test_params": {
|
||||
"problem_sizes": [
|
||||
{"m": 104, "n": 104, "k": 56, "split_k": 1},
|
||||
{"m": 200, "n": 152, "k": 80, "split_k": 1},
|
||||
{"m": 152, "n": 200, "k": 64, "split_k": 1}
|
||||
]
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {"values": [64]},
|
||||
"tile_n": {"values": [64]},
|
||||
"tile_k": {"values": [32]},
|
||||
"warp_m": {"values": [2]},
|
||||
"warp_n": {"values": [2]},
|
||||
"warp_k": {"values": [1]},
|
||||
"warp_tile_m": {"values": [32]},
|
||||
"warp_tile_n": {"values": [32]},
|
||||
"warp_tile_k": {"values": [16]}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {"values": ["compv3"]},
|
||||
"epilogue": {"values": ["default"]},
|
||||
"scheduler": {"values": ["intrawave"]},
|
||||
"pad_m": {"values": [true]},
|
||||
"pad_n": {"values": [true]},
|
||||
"pad_k": {"values": [true]},
|
||||
"persistent": {"values": [false]}
|
||||
},
|
||||
"k_block_per_cu": 1,
|
||||
"permute_n": false
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"problem": {
|
||||
"description": "Quick coverage testing - tests multiple tile sizes with all trait combinations (pipelines, epilogues, schedulers). Approximately 144 kernels."
|
||||
},
|
||||
"test_params": {
|
||||
"problem_sizes": [
|
||||
{"m": 512, "n": 1024, "k": 512, "split_k": 1},
|
||||
{"m": 1024, "n": 1024, "k": 256, "split_k": 2},
|
||||
{"m": 1024, "n": 1024, "k": 256, "split_k": 4}
|
||||
]
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {"values": [32, 64, 256]},
|
||||
"tile_n": {"values": [32, 64, 256]},
|
||||
"tile_k": {"values": [16, 32]},
|
||||
"warp_m": {"values": [2]},
|
||||
"warp_n": {"values": [2]},
|
||||
"warp_k": {"values": [1]},
|
||||
"warp_tile_m": {"values": [16]},
|
||||
"warp_tile_n": {"values": [16]},
|
||||
"warp_tile_k": {"values": [16]}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {"values": ["mem", "compv3", "compv4"]},
|
||||
"epilogue": {"values": ["default", "cshuffle"]},
|
||||
"scheduler": {"values": ["intrawave", "interwave"]},
|
||||
"pad_m": {"values": [false]},
|
||||
"pad_n": {"values": [false]},
|
||||
"pad_k": {"values": [false]},
|
||||
"persistent": {"values": [false]}
|
||||
},
|
||||
"k_block_per_cu": 1,
|
||||
"permute_n": false
|
||||
}
|
||||
@@ -1,88 +1,33 @@
|
||||
{
|
||||
"problem": {
|
||||
"description": "Basic functionality validation with moderate problem sizes"
|
||||
},
|
||||
"test_params": {
|
||||
"problem_sizes": [
|
||||
{"m": 256, "n": 256, "k": 128, "split_k": 1},
|
||||
{"m": 512, "n": 256, "k": 256, "split_k": 1},
|
||||
{"m": 256, "n": 512, "k": 256, "split_k": 1}
|
||||
]
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
"tile_m": {"values": [128]},
|
||||
"tile_n": {"values": [128]},
|
||||
"tile_k": {"values": [64]},
|
||||
"warp_m": {"values": [2]},
|
||||
"warp_n": {"values": [2]},
|
||||
"warp_k": {"values": [1]},
|
||||
"warp_tile_m": {"values": [16]},
|
||||
"warp_tile_n": {"values": [16]},
|
||||
"warp_tile_k": {"values": [16]}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"compv4"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
"pipeline": {"values": ["compv3", "compv4"]},
|
||||
"epilogue": {"values": ["default"]},
|
||||
"scheduler": {"values": ["intrawave"]},
|
||||
"pad_m": {"values": [false]},
|
||||
"pad_n": {"values": [false]},
|
||||
"pad_k": {"values": [false]},
|
||||
"persistent": {"values": [false]}
|
||||
},
|
||||
"k_block_per_cu": 1,
|
||||
"permute_n": false
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
{
|
||||
"problem": {
|
||||
"description": "Configuration optimized for small data types (fp8, fp16, bf16) with larger warp tiles"
|
||||
},
|
||||
"test_params": {
|
||||
"problem_sizes": [
|
||||
{"m": 512, "n": 512, "k": 256, "split_k": 1},
|
||||
{"m": 1024, "n": 512, "k": 512, "split_k": 1},
|
||||
{"m": 512, "n": 1024, "k": 512, "split_k": 1},
|
||||
{"m": 1024, "n": 1024, "k": 256, "split_k": 1}
|
||||
]
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {"values": [128]},
|
||||
"tile_n": {"values": [128]},
|
||||
"tile_k": {"values": [32]},
|
||||
"warp_m": {"values": [2]},
|
||||
"warp_n": {"values": [2]},
|
||||
"warp_k": {"values": [1]},
|
||||
"warp_tile_m": {"values": [32]},
|
||||
"warp_tile_n": {"values": [32]},
|
||||
"warp_tile_k": {"values": [16]}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {"values": ["compv3"]},
|
||||
"epilogue": {"values": ["default"]},
|
||||
"scheduler": {"values": ["intrawave"]},
|
||||
"pad_m": {"values": [false]},
|
||||
"pad_n": {"values": [false]},
|
||||
"pad_k": {"values": [false]},
|
||||
"persistent": {"values": [false]}
|
||||
},
|
||||
"k_block_per_cu": 1,
|
||||
"permute_n": false
|
||||
}
|
||||
71
test/ck_tile/gemm_tile_engine/extract_test_params.py
Normal file
71
test/ck_tile/gemm_tile_engine/extract_test_params.py
Normal file
@@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def extract_test_params(config_file, output_file):
|
||||
"""Extract test parameters from config JSON and write to output file"""
|
||||
|
||||
# Read config file
|
||||
with open(config_file, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Extract test parameters
|
||||
test_params = []
|
||||
if "test_params" in config and "problem_sizes" in config["test_params"]:
|
||||
test_params = config["test_params"]["problem_sizes"]
|
||||
else:
|
||||
# Default test parameters if none specified
|
||||
test_params = [
|
||||
{"m": 256, "n": 256, "k": 128, "split_k": 1},
|
||||
{"m": 256, "n": 256, "k": 1024, "split_k": 1},
|
||||
{"m": 256, "n": 512, "k": 512, "split_k": 1},
|
||||
{"m": 512, "n": 256, "k": 512, "split_k": 1},
|
||||
]
|
||||
|
||||
# Write to output file in C++ format
|
||||
output_dir = Path(output_file).parent
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_file, "w") as f:
|
||||
f.write("// Generated test parameters for this configuration\n")
|
||||
f.write("// This file is auto-generated during CMake configuration\n\n")
|
||||
f.write("static const std::vector<GemmTestParams> CONFIG_TEST_PARAMS = {\n")
|
||||
|
||||
for i, params in enumerate(test_params):
|
||||
comma = "," if i < len(test_params) - 1 else ""
|
||||
f.write(
|
||||
f" {{{params['m']}, {params['n']}, {params['k']}, {params['split_k']}}}{comma}\n"
|
||||
)
|
||||
|
||||
f.write("};\n")
|
||||
|
||||
print(
|
||||
f"Extracted {len(test_params)} test parameters from {config_file} -> {output_file}"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extract test parameters from config JSON"
|
||||
)
|
||||
parser.add_argument("--config_file", required=True, help="Input config JSON file")
|
||||
parser.add_argument(
|
||||
"--output_file", required=True, help="Output test parameters file"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.config_file):
|
||||
print(f"Error: Config file not found: {args.config_file}")
|
||||
return 1
|
||||
|
||||
extract_test_params(args.config_file, args.output_file)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
@@ -1,8 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// Unit tests for tile_engine generated GEMM kernels
|
||||
// Tests kernel correctness using tile_engine's verification methodology
|
||||
/**
|
||||
* @file test_gemm_simple.cpp
|
||||
* @brief Unit tests for GEMM kernels generated by gemm_instance_builder
|
||||
*
|
||||
* This test includes kernels generated during CMake configuration by
|
||||
* gemm_instance_builder.py and tests them with problem sizes extracted
|
||||
* from the corresponding JSON configuration files.
|
||||
*/
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <iostream>
|
||||
@@ -68,6 +74,11 @@ struct GemmTestParams
|
||||
int m, n, k, split_k;
|
||||
};
|
||||
|
||||
// Include config-specific test parameters (after GemmTestParams struct is defined)
|
||||
#ifdef GEMM_TEST_PARAMS_HPP
|
||||
#include GEMM_TEST_PARAMS_HPP
|
||||
#endif
|
||||
|
||||
class GemmTileEngineTest : public ::testing::TestWithParam<GemmTestParams>
|
||||
{
|
||||
protected:
|
||||
@@ -185,7 +196,16 @@ TEST_P(GemmTileEngineTest, BasicFunctionality)
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
FAIL() << "Kernel launch failed: " << e.what();
|
||||
std::string error_msg(e.what());
|
||||
// If arguments not supported, skip the test (configuration validation failure, not a bug)
|
||||
if(error_msg.find("Arguments not supported") != std::string::npos)
|
||||
{
|
||||
GTEST_SKIP() << "Configuration not supported: " << e.what();
|
||||
}
|
||||
else
|
||||
{
|
||||
FAIL() << "Kernel launch failed: " << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
// Copy result back from device
|
||||
@@ -208,13 +228,11 @@ TEST_P(GemmTileEngineTest, KernelInfo)
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Define test parameters for GEMM verification
|
||||
// Use config-specific test parameters (included via compile flags)
|
||||
// CONFIG_TEST_PARAMS is defined in the auto-generated test_params.hpp file
|
||||
INSTANTIATE_TEST_SUITE_P(GemmVerification,
|
||||
GemmTileEngineTest,
|
||||
::testing::Values(GemmTestParams{256, 256, 128, 1},
|
||||
GemmTestParams{256, 256, 1024, 1},
|
||||
GemmTestParams{256, 512, 512, 1},
|
||||
GemmTestParams{512, 256, 512, 1}),
|
||||
::testing::ValuesIn(CONFIG_TEST_PARAMS),
|
||||
[](const ::testing::TestParamInfo<GemmTestParams>& param_info) {
|
||||
return std::to_string(param_info.param.m) + "x" +
|
||||
std::to_string(param_info.param.n) + "x" +
|
||||
|
||||
6
test/gemm_blockscale_wp/CMakeLists.txt
Normal file
6
test/gemm_blockscale_wp/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9[45]|gfx12")
|
||||
add_gtest_executable(test_gemm_blockscale_wp_xdl_fp8 test_gemm_blockscale_wp_xdl_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_blockscale_wp_xdl_fp8 PRIVATE utility device_gemm_blockscale_wp_instance)
|
||||
endif()
|
||||
endif()
|
||||
64
test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp
Normal file
64
test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp
Normal file
@@ -0,0 +1,64 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmBlockScaleWP_FP8_MK_NK : public ck::test::TestGemmBlockscaleWPCommon<
|
||||
typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
#if defined(CK_ENABLE_FP8)
|
||||
std::tuple< F8, F32, F8, F32, F8, BF16>
|
||||
#endif
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmBlockScaleWP_FP8_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
TYPED_TEST(TestGemmBlockScaleWP_FP8_MK_NK, Regular0)
|
||||
{
|
||||
std::vector<int> Ms{128, 256, 512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 2048;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmBlockScaleWP_FP8_MK_NK, Regular1)
|
||||
{
|
||||
std::vector<int> Ms{128, 256, 512};
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 4096;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
77
test/gemm_blockscale_wp/test_gemm_common.hpp
Normal file
77
test/gemm_blockscale_wp/test_gemm_common.hpp
Normal file
@@ -0,0 +1,77 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_blockscale_wp_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using F32 = float;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmBlockscaleWPCommon : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = Row;
|
||||
using A0DataType = std::tuple_element_t<2, Tuple>;
|
||||
using A1DataType = std::tuple_element_t<3, Tuple>;
|
||||
using B0DataType = std::tuple_element_t<4, Tuple>;
|
||||
using B1DataType = std::tuple_element_t<5, Tuple>;
|
||||
using ComputeDataType = std::tuple_element_t<6, Tuple>;
|
||||
using CDataType = std::tuple_element_t<7, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1;
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false;
|
||||
static constexpr index_t ScaleBlockM = 1;
|
||||
static constexpr index_t ScaleBlockN = 128;
|
||||
static constexpr index_t ScaleBlockK = 128;
|
||||
|
||||
void Run(const int M, const int N, const int K, int n_warmup = 1, int n_iter = 10)
|
||||
{
|
||||
bool all_success = true;
|
||||
|
||||
int StrideA = std::is_same_v<ALayout, Row> ? K : M;
|
||||
int StrideB = std::is_same_v<BLayout, Row> ? N : K;
|
||||
int StrideC = std::is_same_v<CLayout, Row> ? N : M;
|
||||
|
||||
all_success =
|
||||
all_success &
|
||||
ck::profiler::profile_gemm_blockscale_weightpreshuffle_impl<A0DataType,
|
||||
A1DataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
ComputeDataType,
|
||||
F32,
|
||||
CDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
6
test/gemm_multiply_multiply_wp/CMakeLists.txt
Normal file
6
test/gemm_multiply_multiply_wp/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9[45]|gfx12")
|
||||
add_gtest_executable(test_gemm_multiply_multiply_wp_xdl_fp8 test_gemm_multiply_multiply_wp_xdl_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_multiply_multiply_wp_xdl_fp8 PRIVATE utility device_gemm_multiply_multiply_wp_instance)
|
||||
endif()
|
||||
endif()
|
||||
93
test/gemm_multiply_multiply_wp/test_gemm_common.hpp
Normal file
93
test/gemm_multiply_multiply_wp/test_gemm_common.hpp
Normal file
@@ -0,0 +1,93 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_multiply_multiply_wp_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using F32 = float;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmMultiplyMultiplyWPCommon : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using D0Layout = std::tuple_element_t<2, Tuple>;
|
||||
using D1Layout = std::tuple_element_t<3, Tuple>;
|
||||
using ELayout = Row;
|
||||
using ADataType = std::tuple_element_t<4, Tuple>;
|
||||
using BDataType = std::tuple_element_t<5, Tuple>;
|
||||
using ComputeDataType = std::tuple_element_t<6, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<7, Tuple>;
|
||||
using D1DataType = std::tuple_element_t<8, Tuple>;
|
||||
using EDataType = std::tuple_element_t<9, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1; // decimal value initialization
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override { k_batches_ = {1, 2, 4}; }
|
||||
|
||||
void Run(const int M, const int N, const int K)
|
||||
{
|
||||
for(size_t i = 0; i < k_batches_.size(); i++)
|
||||
{
|
||||
RunSingle(M, N, K, k_batches_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void RunSingle(
|
||||
const int M, const int N, const int K, int kbatch = 1, int n_warmup = 1, int n_iter = 10)
|
||||
{
|
||||
bool all_success = true;
|
||||
|
||||
int StrideA = std::is_same_v<remove_cvref_t<ALayout>, Row> ? K : M;
|
||||
int StrideB = std::is_same_v<remove_cvref_t<BLayout>, Row> ? N : K;
|
||||
int StrideD0 = std::is_same_v<remove_cvref_t<D0Layout>, Row> ? N : M;
|
||||
int StrideD1 = std::is_same_v<remove_cvref_t<D1Layout>, Row> ? N : M;
|
||||
int StrideE = std::is_same_v<ELayout, Row> ? N : M;
|
||||
|
||||
all_success =
|
||||
all_success &
|
||||
ck::profiler::profile_gemm_multiply_multiply_weight_preshuffle_impl<ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
F32,
|
||||
D0DataType,
|
||||
D1DataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
D0Layout,
|
||||
D1Layout,
|
||||
ELayout>(
|
||||
verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideD0,
|
||||
StrideD1,
|
||||
StrideE,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,77 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmMultiplyMultiplyWP_FP8_MK_NK
|
||||
: public ck::test::TestGemmMultiplyMultiplyWPCommon<
|
||||
typename tuple_concat<std::tuple<Row, Col, Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
#if defined(CK_ENABLE_FP8)
|
||||
std::tuple< F8, F8, F8, F32, F32, F16>,
|
||||
std::tuple< F8, F8, F8, F32, F32, BF16>
|
||||
#endif
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmMultiplyMultiplyWP_FP8_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular0)
|
||||
{
|
||||
std::vector<int> Ms{128, 224, 256, 448, 512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 2048;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular1)
|
||||
{
|
||||
std::vector<int> Ms{128, 224, 256, 448, 512};
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 4096;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular2)
|
||||
{
|
||||
std::vector<int> Ms{128, 256, 512};
|
||||
constexpr int N = 448;
|
||||
constexpr int K = 2048;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
6
test/gemm_universal_preshuffle/CMakeLists.txt
Normal file
6
test/gemm_universal_preshuffle/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9[45]|gfx12")
|
||||
add_gtest_executable(test_gemm_universal_preshuffle_xdl_fp8 test_gemm_universal_preshuffle_xdl_fp8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_universal_preshuffle_xdl_fp8 PRIVATE utility device_gemm_universal_preshuffle_instance)
|
||||
endif()
|
||||
endif()
|
||||
79
test/gemm_universal_preshuffle/test_gemm_common.hpp
Normal file
79
test/gemm_universal_preshuffle/test_gemm_common.hpp
Normal file
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_universal_preshuffle_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using F32 = float;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversalPreshuffleCommon : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = Row;
|
||||
using ADataType = std::tuple_element_t<2, Tuple>;
|
||||
using BDataType = std::tuple_element_t<3, Tuple>;
|
||||
using ComputeDataType = std::tuple_element_t<4, Tuple>;
|
||||
using CDataType = std::tuple_element_t<5, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1;
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false;
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override { k_batches_ = {1, 2, 4}; }
|
||||
|
||||
void Run(const int M, const int N, const int K)
|
||||
{
|
||||
for(size_t i = 0; i < k_batches_.size(); i++)
|
||||
{
|
||||
RunSingle(M, N, K, k_batches_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void RunSingle(
|
||||
const int M, const int N, const int K, int kbatch = 1, int n_warmup = 1, int n_iter = 10)
|
||||
{
|
||||
bool all_success = true;
|
||||
|
||||
int StrideA = std::is_same_v<ALayout, Row> ? K : M;
|
||||
int StrideB = std::is_same_v<BLayout, Row> ? N : K;
|
||||
int StrideC = std::is_same_v<CLayout, Row> ? N : M;
|
||||
|
||||
all_success = all_success &
|
||||
ck::profiler::profile_gemm_universal_preshuffle_impl<ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
F32,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,77 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmUniversalPreshuffle_FP8_MK_NK
|
||||
: public ck::test::TestGemmUniversalPreshuffleCommon<
|
||||
typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
#if defined(CK_ENABLE_FP8)
|
||||
std::tuple< F8, F8, F8, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>
|
||||
#endif
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmUniversalPreshuffle_FP8_MK_NK, KernelTypes_MK_NK);
|
||||
|
||||
TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular0)
|
||||
{
|
||||
std::vector<int> Ms{128, 224, 256, 448, 512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 2048;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular1)
|
||||
{
|
||||
std::vector<int> Ms{128, 224, 256, 448, 512};
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 4096;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular2)
|
||||
{
|
||||
std::vector<int> Ms{128, 256, 512};
|
||||
constexpr int N = 448;
|
||||
constexpr int K = 2048;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
Reference in New Issue
Block a user