mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Addressing (Post Merge) code review comments for PR 1845 (#1883)
* Addressing code review comments.
* Addressing code review comments.
* Reorganized code for better readability.
* add ck_tile gemms for new types in CI
* fix jenkins syntax
* fix script syntax
* Add the test cases back
* Address the review comments
* Address review comments
* clang format
* Solve the merging issues
* Addressed the comments
* clang format
---------
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: 66c5f5b0b6]
This commit is contained in:
10
Jenkinsfile
vendored
10
Jenkinsfile
vendored
@@ -351,12 +351,12 @@ def cmake_build(Map conf=[:]){
|
||||
}
|
||||
if (params.RUN_CK_TILE_GEMM_TESTS){
|
||||
try{
|
||||
archiveArtifacts "perf_tile_gemm_*.log"
|
||||
archiveArtifacts "perf_tile_gemm_**.log"
|
||||
if (arch_type == 1){
|
||||
stash includes: "perf_tile_gemm_**_fp16_gfx90a.log", name: "perf_tile_gemm_log_gfx90a"
|
||||
stash includes: "perf_tile_gemm_**_gfx90a.log", name: "perf_tile_gemm_log_gfx90a"
|
||||
}
|
||||
else if (arch_type == 2){
|
||||
stash includes: "perf_tile_gemm_**_fp16_gfx942.log", name: "perf_tile_gemm_log_gfx942"
|
||||
stash includes: "perf_tile_gemm_**_gfx942.log", name: "perf_tile_gemm_log_gfx942"
|
||||
}
|
||||
}
|
||||
catch(Exception err){
|
||||
@@ -799,8 +799,8 @@ pipeline {
|
||||
description: "Run the ck_tile FMHA tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_CK_TILE_GEMM_TESTS",
|
||||
defaultValue: true,
|
||||
description: "Run the ck_tile GEMM tests (default: ON)")
|
||||
defaultValue: false,
|
||||
description: "Run the ck_tile GEMM tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_INSTANCES_ONLY",
|
||||
defaultValue: false,
|
||||
|
||||
@@ -29,8 +29,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// This part comes from the Codegen
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 128;
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
@@ -54,7 +54,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
@@ -99,45 +101,99 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices when "
|
||||
"BPrecType is ck_tile::pk_int4_t!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ struct GemmTypeConfig<ck_tile::half_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf16_t>
|
||||
struct GemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
@@ -123,7 +123,7 @@ struct GemmTypeConfig<ck_tile::bf16_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::fp8_t>
|
||||
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
@@ -132,7 +132,7 @@ struct GemmTypeConfig<ck_tile::fp8_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf8_t>
|
||||
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
|
||||
0
example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh
Normal file → Executable file
0
example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh
Normal file → Executable file
0
example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh
Normal file → Executable file
0
example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh
Normal file → Executable file
0
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
Normal file → Executable file
0
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
Normal file → Executable file
0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh
Normal file → Executable file
0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh
Normal file → Executable file
0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
Normal file → Executable file
0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
Normal file → Executable file
0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
Normal file → Executable file
0
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
Normal file → Executable file
@@ -32,14 +32,11 @@ function print_log_header(){
|
||||
}
|
||||
|
||||
# run verification tests
|
||||
example/ck_tile/03_gemm/script/smoke_test_basic.sh
|
||||
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
|
||||
|
||||
# run performance benchmarks
|
||||
export gemm_basic_log="perf_tile_gemm_basic_fp16_$GPU_arch.log"
|
||||
print_log_header $gemm_basic_log $env_type $branch $host_name
|
||||
example/ck_tile/03_gemm/script/benchmark_basic.sh 2>&1 | tee -a $gemm_basic_log
|
||||
|
||||
export gemm_mem_pipeline_log="perf_tile_gemm_mem_pipeline_fp16_$GPU_arch.log"
|
||||
print_log_header $gemm_mem_pipeline_log $env_type $branch $host_name
|
||||
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh 2>&1 | tee -a $gemm_mem_pipeline_log
|
||||
for dtype in fp16 bf16 fp8 bf8; do
|
||||
export gemm_log="perf_tile_gemm_mem_pipeline_${dtype}_${GPU_arch}.log"
|
||||
print_log_header $gemm_log $env_type $branch $host_name
|
||||
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_$dtype.sh 2>&1 | tee -a $gemm_log
|
||||
done
|
||||
|
||||
@@ -76,7 +76,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
@@ -121,6 +123,16 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
@@ -205,11 +217,29 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Num K loop must be larger than number of prefetech stages."
|
||||
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
Run(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
Run(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Num K loop must be larger than number of prefetech stages."
|
||||
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
@@ -217,133 +247,113 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices when "
|
||||
"BPrecType is ck_tile::pk_int4_t!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::half_t>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
run_gemm_example(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Caught runtime error: " << e.what() << '\n';
|
||||
// Return a non-zero code to indicate failure
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
@@ -361,7 +361,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
{
|
||||
if constexpr(N == 2)
|
||||
{
|
||||
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), bit_cast<bf16x2_t>(x));
|
||||
atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
|
||||
@@ -523,7 +523,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
|
||||
int exponent = (x & 0x7F) >> SrcT_mant;
|
||||
if constexpr(is_fnuz)
|
||||
{
|
||||
if(x == 0x80)
|
||||
if((x & 0xff) == 0x80)
|
||||
{
|
||||
return fNaN;
|
||||
}
|
||||
|
||||
@@ -9,7 +9,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename AccDataType_,
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
typename CLayout_,
|
||||
index_t kBlockSize_,
|
||||
@@ -23,6 +25,8 @@ template <typename AccDataType_,
|
||||
bool isCTransposed_>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using CLayout = remove_cvref_t<CLayout_>;
|
||||
@@ -40,9 +44,13 @@ struct CShuffleEpilogueProblem
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct CShuffleEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ODataType, BDataType>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
|
||||
@@ -56,8 +64,8 @@ struct CShuffleEpilogue
|
||||
static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
|
||||
static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
|
||||
|
||||
using WG = WarpGemmMfmaDispatcher<ODataType,
|
||||
ODataType,
|
||||
using WG = WarpGemmMfmaDispatcher<ADataType,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
kMPerXdl,
|
||||
kNPerXdl,
|
||||
@@ -77,7 +85,6 @@ struct CShuffleEpilogue
|
||||
*
|
||||
* @return The vector store size for C tensor.
|
||||
*/
|
||||
template <typename ODataType>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
|
||||
{
|
||||
constexpr index_t MaxVectorStoreSize = 16;
|
||||
@@ -143,7 +150,7 @@ struct CShuffleEpilogue
|
||||
TileDistributionEncodingPattern2D<kBlockSize,
|
||||
kMPerIteration,
|
||||
kNPerIteration,
|
||||
GetVectorSizeC<ODataType>(),
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked>;
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ struct GemmKernel
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
|
||||
{
|
||||
if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
|
||||
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
@@ -275,7 +275,7 @@ struct GemmKernel
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
|
||||
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
@@ -295,7 +295,7 @@ struct GemmKernel
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
|
||||
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
@@ -407,7 +407,7 @@ struct GemmKernel
|
||||
c_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
@@ -671,7 +671,7 @@ struct GemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
|
||||
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
|
||||
@@ -694,7 +694,7 @@ struct GemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
|
||||
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm<memory_operation_enum::atomic_add>(
|
||||
|
||||
@@ -33,8 +33,21 @@ struct BaseGemmPipelineAgBgCrCompV3
|
||||
|
||||
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
ignore = num_loop;
|
||||
return TailNumber::Full;
|
||||
if(BlockHasHotloop(num_loop))
|
||||
{
|
||||
return TailNumber::Full;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(num_loop == 1)
|
||||
{
|
||||
return TailNumber::Odd;
|
||||
}
|
||||
else
|
||||
{
|
||||
return TailNumber::Even;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -470,6 +483,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -478,12 +492,43 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
|
||||
{
|
||||
// Leak last MFMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) /
|
||||
@@ -442,6 +442,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
Base::LocalPrefill(
|
||||
b_copy_lds_window1, b_global_load_tile, b_element_func);
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
Base::GlobalPrefetch(
|
||||
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
@@ -26,6 +26,10 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
@@ -81,11 +85,21 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
std::is_same_v<BDataType, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
constexpr bool is_a_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
|
||||
@@ -344,6 +344,30 @@ def main():
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_mem_pipeline_fp16_tflops"
|
||||
if 'gemm_basic_bf16' in filename:
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_basic_bf16_tflops"
|
||||
if 'gemm_mem_pipeline_bf16' in filename:
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_mem_pipeline_bf16_tflops"
|
||||
if 'gemm_basic_fp8' in filename:
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_basic_fp8_tflops"
|
||||
if 'gemm_mem_pipeline_fp8' in filename:
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_mem_pipeline_fp8_tflops"
|
||||
if 'gemm_basic_bf8' in filename:
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_basic_bf8_tflops"
|
||||
if 'gemm_mem_pipeline_bf8' in filename:
|
||||
for i in range(1, len(results)+1):
|
||||
testlist.append("Test%i"%i)
|
||||
table_name="ck_tile_gemm_mem_pipeline_bf8_tflops"
|
||||
|
||||
tflops_base = get_baseline(table_name,conn)
|
||||
store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, sqlEngine)
|
||||
|
||||
@@ -43,19 +43,12 @@ file=./perf_fmha_bwd_gfx90a.log
|
||||
if [ -e "$file" ]; then
|
||||
python3 process_perf_data.py perf_fmha_bwd_gfx90a.log
|
||||
fi
|
||||
file=./perf_tile_gemm_basic_fp16_gfx942.log
|
||||
if [ -e "$file" ]; then
|
||||
python3 process_perf_data.py perf_tile_gemm_basic_fp16_gfx942.log
|
||||
fi
|
||||
file=./perf_tile_gemm_basic_fp16_gfx90a.log
|
||||
if [ -e "$file" ]; then
|
||||
python3 process_perf_data.py perf_tile_gemm_basic_fp16_gfx90a.log
|
||||
fi
|
||||
file=./perf_tile_gemm_mem_pipeline_fp16_gfx942.log
|
||||
if [ -e "$file" ]; then
|
||||
python3 process_perf_data.py perf_tile_gemm_mem_pipeline_fp16_gfx942.log
|
||||
fi
|
||||
file=./perf_tile_gemm_mem_pipeline_fp16_gfx90a.log
|
||||
if [ -e "$file" ]; then
|
||||
python3 process_perf_data.py perf_tile_gemm_mem_pipeline_fp16_gfx90a.log
|
||||
fi
|
||||
|
||||
for gpu in "gfx90a" "gfx942"; do
|
||||
for dtype in "fp16" "bf16" "fp8" "bf8"; do
|
||||
file=./perf_tile_gemm_mem_pipeline_${dtype}_${gpu}.log
|
||||
if [ -e "$file" ]; then
|
||||
python3 process_perf_data.py perf_tile_gemm_mem_pipeline_${dtype}_${gpu}.log
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
@@ -52,19 +52,12 @@ file=./perf_fmha_bwd_gfx90a.log
|
||||
if [ -e "$file" ]; then
|
||||
python3 process_perf_data.py perf_fmha_bwd_gfx90a.log
|
||||
fi
|
||||
file=./perf_gemm_basic_gfx942.log
|
||||
if [ -e "$file" ]; then
|
||||
python3 process_perf_data.py perf_gemm_basic_gfx942.log
|
||||
fi
|
||||
file=./perf_gemm_basic_gfx90a.log
|
||||
if [ -e "$file" ]; then
|
||||
python3 process_perf_data.py perf_gemm_basic_gfx90a.log
|
||||
fi
|
||||
file=./perf_gemm_mem_pipeline_gfx942.log
|
||||
if [ -e "$file" ]; then
|
||||
python3 process_perf_data.py perf_gemm_mem_pipeline_gfx942.log
|
||||
fi
|
||||
file=./perf_gemm_mem_pipeline_gfx90a.log
|
||||
if [ -e "$file" ]; then
|
||||
python3 process_perf_data.py perf_gemm_mem_pipeline_gfx90a.log
|
||||
fi
|
||||
|
||||
for gpu in "gfx90a" "gfx942"; do
|
||||
for dtype in "fp16" "bf16" "fp8" "bf8"; do
|
||||
file=./perf_tile_gemm_mem_pipeline_${dtype}_${gpu}.log
|
||||
if [ -e "$file" ]; then
|
||||
python3 process_perf_data.py perf_tile_gemm_mem_pipeline_${dtype}_${gpu}.log
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
@@ -67,7 +67,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
CodegenGemmPipeline::BlockSize,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline test_gemm_pipeline.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp)
|
||||
endif()
|
||||
|
||||
16
test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp
Normal file
16
test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompV3 : public TestCkTileGemmPipeline<T>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV3
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3, KernelTypesMem);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
16
test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp
Normal file
16
test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineCompV4 : public TestCkTileGemmPipeline<T>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV4
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4, KernelTypesMem);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
|
||||
@@ -21,27 +22,41 @@ using CompV3 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType:
|
||||
using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
|
||||
using KernelTypesMem = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, Mem>,
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Interwave, Mem>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV3 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Row, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Row, Row, F8, F8, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV3>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
|
||||
>;
|
||||
std::tuple< Col, Col, Row, F8, F8, F32, F16, Intrawave, CompV3>
|
||||
>;
|
||||
|
||||
using KernelTypesCompV4 = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, CompV4>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, CompV4>
|
||||
>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipeline, KernelTypes);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
16
test/ck_tile/gemm/test_gemm_pipeline_mem.cpp
Normal file
16
test/ck_tile/gemm/test_gemm_pipeline_mem.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
#include "test_gemm_pipeline_kernel_types.hpp"
|
||||
#include "test_gemm_pipeline_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileGemmPipelineMem : public TestCkTileGemmPipeline<T>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileGemmPipelineMem
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGemmPipelineMem, KernelTypesMem);
|
||||
|
||||
#include "test_gemm_pipeline_ut_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -3,7 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileGemmPipeline, SmallM)
|
||||
#ifndef TEST_GEMM_PIPELINE_UT_CASES_INC
|
||||
#define TEST_GEMM_PIPELINE_UT_CASES_INC
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, SmallM)
|
||||
{
|
||||
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
|
||||
constexpr int N = 1024;
|
||||
@@ -13,18 +16,25 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM)
|
||||
{
|
||||
if constexpr(std::is_same_v<typename TestFixture::ALayout,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
|
||||
}
|
||||
else
|
||||
{
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
|
||||
TYPED_TEST(TEST_SUITE_NAME, MidLargeM)
|
||||
{
|
||||
std::vector<int> Ms{127, 255, 312, 799, 1573};
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 320;
|
||||
constexpr int VecLoadSize = 8;
|
||||
constexpr int VecLoadSize = (std::is_same_v<typename TestFixture::ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TestFixture::ADataType, ck_tile::bf8_t>)
|
||||
? 16
|
||||
: 8;
|
||||
|
||||
for(int M : Ms)
|
||||
{
|
||||
@@ -33,9 +43,13 @@ TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
|
||||
{
|
||||
// TODO: Can we anyhow deduce used vector load size?
|
||||
if(M % VecLoadSize == 0)
|
||||
{
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
else
|
||||
{
|
||||
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -44,7 +58,7 @@ TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmPipeline, PaddK)
|
||||
TYPED_TEST(TEST_SUITE_NAME, PaddK)
|
||||
{
|
||||
std::vector<int> Ms{128};
|
||||
constexpr int N = 1024;
|
||||
@@ -54,7 +68,7 @@ TYPED_TEST(TestCkTileGemmPipeline, PaddK)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmPipeline, Regular)
|
||||
TYPED_TEST(TEST_SUITE_NAME, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 1024;
|
||||
@@ -64,7 +78,16 @@ TYPED_TEST(TestCkTileGemmPipeline, Regular)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGemmPipeline, NotSupportedArgument)
|
||||
TYPED_TEST(TEST_SUITE_NAME, LargeMatrix)
|
||||
{
|
||||
constexpr int M = 2048;
|
||||
constexpr int N = 2048;
|
||||
constexpr int K = 2048;
|
||||
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, NotSupportedArgument)
|
||||
{
|
||||
constexpr int M = 512;
|
||||
constexpr int N = 1025;
|
||||
@@ -76,3 +99,5 @@ TYPED_TEST(TestCkTileGemmPipeline, NotSupportedArgument)
|
||||
|
||||
EXPECT_THROW((this->template Run<PadM, PadN, PadK>(M, N, K)), std::runtime_error);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -11,6 +11,27 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
enum struct GemmPipelineType
|
||||
{
|
||||
Mem,
|
||||
@@ -63,7 +84,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
// TODO: This should be parameterized in tests
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Tile = (PipelineType == GemmPipelineType::CompV4) ? 32 : 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
@@ -71,8 +92,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
// TODO: Restore to 8. At now after changes in block_universal_gemm_as_bs_cr it return wrong
|
||||
// values.
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
@@ -136,7 +155,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipeline::BlockSize,
|
||||
@@ -420,7 +441,18 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
|
||||
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -79,6 +79,8 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
|
||||
Reference in New Issue
Block a user