mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Stream-K Gemm Example for fp8 and bf8 (#3041)
* Addition of streamk fp8 example for CK Tile * Adding in bf8 streamk example in CK Tile * Refactoring fp8/bf8 unit tests Refactored the unit tests for fp8/bf8 to utilize the test harness. Implemented smoke tests with layouts: CCR, CRR, RCR, RRR for fp8/bf8. The tests are using 128x128x32 for the tile configuration, as other configurations revealed implementation gaps that are currently being documented.
This commit is contained in:
@@ -1,5 +1,10 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(tile_example_streamk_gemm_basic EXCLUDE_FROM_ALL streamk_gemm_basic.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
target_compile_options(tile_example_streamk_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile streamk gemm tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -28,10 +28,10 @@ args:
|
||||
-stride_b tensor B stride (default:0)
|
||||
-stride_c tensor C stride (default:0)
|
||||
-v validation strategy. 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
|
||||
-prec data type. fp16/bf16 (default:fp16)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-warmup number of iterations before benchmarking the kernel (default:50)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer timing mode. gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-init data initialization strategy. 0:random, 1:linear, 2:constant(1) (default:0)
|
||||
-flush_cache flush the cache before running the kernel (default:true)
|
||||
```
|
||||
```
|
||||
|
||||
@@ -75,6 +75,18 @@ struct DataTypeTraits<ck_tile::bf16_t>
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
@@ -94,7 +106,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
|
||||
@@ -56,7 +56,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::Scheduler>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem>;
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
@@ -187,6 +187,18 @@ int run_gemm_example(int argc, char* argv[])
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
|
||||
@@ -28,3 +28,4 @@ add_subdirectory(38_block_scale_gemm)
|
||||
add_subdirectory(39_copy)
|
||||
add_subdirectory(40_streamk_gemm)
|
||||
add_subdirectory(41_batched_contraction)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user