Files
composable_kernel/include/ck_tile/utility/json_dump.hpp
Anton Gorenko ec006bb8e0 [CK_TILE] Add gtests for FMHA (#2744)
* Improve random number generation

* use different seed for each input (Q, K, V...);
* use deterministic generation of:
  * seqstart_q/k (for group mode);
  * block_table (for paged-kvcahe);
  * cache_batch_idx (for kvcache);

* Extract arg_parser-related code from run functions to use them as tests

* Split examples into main programs and fmha runners, build instances separately

* Add dummy tests that use instances and runners

* Fix a missed corner case of f32->f8 conversion

When value if < min f8 denormal but > min f8 denormal / 2, it must be
rounded to min f8 denormal (i.e. 0b1), not to 0.

* Fix incorrect fp8 scales for P and O in validation code

DataTypeConfig was incorrectly compared with fp8_t.

* Add host generation of dropout random values and use it for validation

Previously host validation (reference_batched_dropout) used random
numbers generated by BlockDropout of the kernel, meaning that incorrect
generation on device (bad distribution, repeated numbers, too many zeros,
etc.) would not trigger any validation errors.

* Implement tests from smoke_test_bwd.sh

* Return result as enum to distinguish failure and missing instance

* Add tests for bwd features: bias, alibi, dropout

* Implement tests from smoke_test_fwd.sh

* Pass seqlen_q/k as vectors to fwd and bwd runners

* Add tests for fwd features: bias, alibi, dropout

* Add tests for pagedkv and splitkv

* Fix conditions when to use splitkv and pagedkv kernels

splitkv was executed only when use_kvcache which == (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size).
In the SplitKV tests: the regular fwd kernel was executed if use_cache_batch_idx was not requested even when num_splitkv > 1.
In the AppendKV tests: the pagedkv kernel was executed but it often failed to find an instance.

* Add tests for appendkv

* Use is_v_rowmajor = true because there are no instances with column layout anymore

* Split public and private compile options for instances

Tests and examples need to know only about CK_TILE_FMHA_FWD_*_API.

* Improve parsing validation in bias and mask

* Pass bias as string for consistency with mask

* Catch parsing and other exceptions

* Add bwd test for deterministic flag

* Initialize fp8 tensors (-init=ufq) similarly to uf

* Fix splitkv/pagedkv invocation: use padded sk when seqlen_k_ptr is not null

seqlen_k cannot be used to determine padding when seqlen_k_ptr is
provided. The actual seqlen_k is taken from seqlen_k_ptr[b].
Even seqlen_k values (% bn0 == 0) use padded seqlen_k while seqlen_k_ptr
may contain arbitrary values.
In the example or tests this produces incorrect results with appendkv
(for example, -d=32 -s=1 -s_k=64 -s_knew=7 -vlayout=c -b=8).

* Fix use_pagedkv value when kvcache = true but page_block_size = 0

In this case block_table_ptr is nullptr which is accessed in the kernel.

* Clean up bwd tests

* Unify fwd tests for f16/bf16 and fp8

* Use better explicit instantiation declaration for fmha_bwd<2>

* Use the same seed for all tests, allow to override it with env variable

* Undo clang-format of one irrelevant file

For some reason my local clang-format-18 and the one in CI work differently.

* Do not build instances and tests on unsupported archs

* Build instance libraries as OBJECT library

* CI: Enable sccache for HIP

There are source files with LANGUAGE HIP, they need
-DCMAKE_HIP_COMPILER_LAUNCHER=sccache

* Add tests to REGRESSION_TESTS

* Fix OOB accesses in deterministic bwd due to incorrectly assumed kN0

The runner assumes kN0 = (hdim_q <= 128) ? 128 : 64 but there are
smaller tiles (for tr_load or fp32). This can create too small dq_acc_buf.

* Pass CK_TILE_FMHA_FWD_*_API as INTERFACE compile options

The instances don't actually depend on them, only examples and tests do.
Passing these definitions as INTERFACE allows to change FMHA_FWD_ENABLE_APIS
without recompiling instances that are already in ccache.

* Fix formatting and names
2025-09-10 08:06:14 +05:00

701 lines
28 KiB
C++

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant"
#include "rapidjson/writer.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/document.h"
#include "rapidjson/rapidjson.h"
// #include <fstream>
#pragma GCC diagnostic pop
#define START_JSON_DUMP_FILE(file_name) \
std::string file_str(file_name); \
std::ofstream file(file_str); \
if(!file.is_open()) \
{ \
throw std::runtime_error("Could not open file: " + std::string(file_name)); \
} \
rapidjson::StringBuffer s; \
rapidjson::Writer<rapidjson::StringBuffer> writer(s); \
writer.StartObject();
#define END_JSON_DUMP_FILE() \
writer.EndObject(); \
file << s.GetString(); \
file.close(); \
std::cout << "Results written to " << file_str << " successfully" << std::endl;
#define ADD_KEY_VALUE(key, value) add_key_value_pair(writer, key, value);
#define ADD_PERF_TO_JSON(_time, tflops, gbytes) add_perf_to_json(writer, _time, tflops, gbytes);
template <typename T>
void add_key_value_pair(rapidjson::Writer<rapidjson::StringBuffer>& writer,
const char* key,
T value)
{
writer.Key(key);
if constexpr(std::is_same<T, const char*>::value)
{
writer.String(value, static_cast<rapidjson::SizeType>(std::strlen(value)));
}
else if constexpr(std::is_same<T, std::string>::value)
{
writer.String(value.c_str(), static_cast<rapidjson::SizeType>(value.length()));
}
else if constexpr(std::is_floating_point<T>::value)
{
writer.Double(static_cast<double>(value));
}
else if constexpr(std::is_integral<T>::value)
{
writer.Int64(static_cast<int64_t>(value));
}
else
{
static_assert(std::is_same<T, const char*>::value || std::is_floating_point<T>::value ||
std::is_integral<T>::value,
"Unsupported type for JSON serialization");
}
}
static void add_perf_to_json(rapidjson::Writer<rapidjson::StringBuffer>& writer,
float time,
float tflops,
float gbytes)
{
std::string roster("perf");
writer.String(roster.c_str(), static_cast<rapidjson::SizeType>(roster.length()));
writer.StartArray();
writer.StartObject();
add_key_value_pair(writer, "time", time);
add_key_value_pair(writer, "tflops", tflops);
add_key_value_pair(writer, "gbytes", gbytes);
writer.EndObject();
writer.EndArray();
}
// Helper traits to check for static member existence
template <typename T, typename = void>
struct has_warp_tile_members : std::false_type
{
};
template <typename T>
struct has_warp_tile_members<
T,
std::void_t<decltype(T::M_Warp_Tile), decltype(T::N_Warp_Tile), decltype(T::K_Warp_Tile)>>
: std::true_type
{
};
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename GemmConfig,
template <typename>
typename DTypeTraits>
void dump_gemm_json_results(const std::string& json_filename,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
bool persistent,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "gemm_basic")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("M", M);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("K", K);
ADD_KEY_VALUE("stride_A", stride_A);
ADD_KEY_VALUE("stride_B", stride_B);
ADD_KEY_VALUE("stride_C", stride_C);
ADD_KEY_VALUE("A_layout", ALayout::name);
ADD_KEY_VALUE("B_layout", BLayout::name);
ADD_KEY_VALUE("C_layout", CLayout::name);
using TraitsADataType = DTypeTraits<ADataType>;
using TraitsBDataType = DTypeTraits<BDataType>;
using TraitsCDataType = DTypeTraits<CDataType>;
ADD_KEY_VALUE("A_type", TraitsADataType::name);
ADD_KEY_VALUE("B_type", TraitsBDataType::name);
ADD_KEY_VALUE("C_type", TraitsCDataType::name);
ADD_KEY_VALUE("structured_sparsity", GemmConfig::UseStructuredSparsity ? "on" : "off");
if constexpr(has_warp_tile_members<GemmConfig>::value)
{
ADD_KEY_VALUE("warp_tile",
std::to_string(GemmConfig::M_Warp_Tile) + "x" +
std::to_string(GemmConfig::N_Warp_Tile) + "x" +
std::to_string(GemmConfig::K_Warp_Tile));
}
ADD_KEY_VALUE("persistent", persistent ? "on" : "off");
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
END_JSON_DUMP_FILE();
}
void dump_batched_gemm_json_results(const std::string& json_filename,
const std::string& op_name,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
int batch_stride_A,
int batch_stride_B,
int batch_stride_C,
int batch_count,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "batched_gemm_basic")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("op_name", op_name);
ADD_KEY_VALUE("M", M);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("K", K);
ADD_KEY_VALUE("stride_A", stride_A);
ADD_KEY_VALUE("stride_B", stride_B);
ADD_KEY_VALUE("stride_C", stride_C);
ADD_KEY_VALUE("batch_stride_A", batch_stride_A);
ADD_KEY_VALUE("batch_stride_B", batch_stride_B);
ADD_KEY_VALUE("batch_stride_C", batch_stride_C);
ADD_KEY_VALUE("batch_count", batch_count);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
template <typename ALayout, typename BLayout, typename CLayout>
void dump_grouped_gemm_json_results(const std::string& json_filename,
const std::string& op_name,
int group_count,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "grouped_gemm")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("op_name", op_name);
ADD_KEY_VALUE("group_count", group_count);
ADD_KEY_VALUE("A_layout", ALayout::name);
ADD_KEY_VALUE("B_layout", BLayout::name);
ADD_KEY_VALUE("C_layout", CLayout::name);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_flatmm_json_results(const std::string& json_filename,
const std::string& datatype,
int M,
int N,
int K,
int stride_A,
int stride_B,
int stride_C,
int kbatch,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "flatmm_basic")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("DataType", datatype);
ADD_KEY_VALUE("M", M);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("K", K);
ADD_KEY_VALUE("StrideA", stride_A);
ADD_KEY_VALUE("StrideB", stride_B);
ADD_KEY_VALUE("StrideC", stride_C);
ADD_KEY_VALUE("kbatch", kbatch);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
const std::string& op_name,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideD0,
int StrideD1,
int StrideE,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "gemm_multi_d_fp16")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("op_name", op_name);
ADD_KEY_VALUE("M", M);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("K", K);
ADD_KEY_VALUE("StrideA", StrideA);
ADD_KEY_VALUE("StrideB", StrideB);
ADD_KEY_VALUE("StrideD0", StrideD0);
ADD_KEY_VALUE("StrideD1", StrideD1);
ADD_KEY_VALUE("StrideE", StrideE);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_elementwise_json_results(const std::string& json_filename,
const std::string& prec,
int grid_size,
int block_size,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "elementwise")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec", prec);
ADD_KEY_VALUE("grid_size", grid_size);
ADD_KEY_VALUE("block_size", block_size);
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_layernorm2d_fwd_json_results(const std::string& json_filename,
const std::string& prec_i,
const std::string& prec_o,
const std::string& prec_sm,
const std::string& prec_sy,
int m,
int n,
int x_stride,
int xr_stride,
int y_stride,
int yr_stride,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "layernorm2d_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec_i", prec_i);
ADD_KEY_VALUE("prec_o", prec_o);
ADD_KEY_VALUE("prec_sm", prec_sm);
ADD_KEY_VALUE("prec_sy", prec_sy);
ADD_KEY_VALUE("m", m);
ADD_KEY_VALUE("n", n);
ADD_KEY_VALUE("x_stride", x_stride);
ADD_KEY_VALUE("xr_stride", xr_stride);
ADD_KEY_VALUE("y_stride", y_stride);
ADD_KEY_VALUE("yr_stride", yr_stride);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
template <typename DataType, template <typename> typename DTypeTraits>
void dump_reduce_json_results(const std::string& json_filename,
int N,
int C,
int H,
int W,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "reduce")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
using Traits = DTypeTraits<DataType>;
ADD_KEY_VALUE("data_type", Traits::name);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("C", C);
ADD_KEY_VALUE("H", H);
ADD_KEY_VALUE("W", W);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_permute_json_results(const std::string& json_filename,
const std::string& data_type,
bool pass,
float ave_time,
float tflop,
float gb_per_sec,
const std::string& kernel_name = "permute")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("data_type", data_type);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflop, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_topk_softmax_json(const std::string& json_filename,
const std::string& input_prec,
const std::string& weight_prec,
int tokens,
int experts,
int topk,
int stride_input,
int stride_output,
float ave_time,
float tflop,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "topk_softmax")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("input_prec", input_prec);
ADD_KEY_VALUE("weight_prec", weight_prec);
ADD_KEY_VALUE("tokens", tokens);
ADD_KEY_VALUE("experts", experts);
ADD_KEY_VALUE("topk", topk);
ADD_KEY_VALUE("stride_input", stride_input);
ADD_KEY_VALUE("stride_output", stride_output);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflop, gb_per_sec);
END_JSON_DUMP_FILE();
}
void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
const std::string& prec_str,
int m,
int n,
int x_stride,
int xr_stride,
int y_stride,
int yr_stride,
int use_model_sensitive_rmsnorm,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "rmsnorm2d_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec", prec_str);
ADD_KEY_VALUE("m", m);
ADD_KEY_VALUE("n", n);
ADD_KEY_VALUE("x_stride", x_stride);
ADD_KEY_VALUE("xr_stride", xr_stride);
ADD_KEY_VALUE("y_stride", y_stride);
ADD_KEY_VALUE("yr_stride", yr_stride);
ADD_KEY_VALUE("use_model_sensitive_rmsnorm", use_model_sensitive_rmsnorm);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
END_JSON_DUMP_FILE();
}
void dump_add_rmsnorm2d_rdquant_fwd_json(
const std::string& json_filename,
const std::string& input_data_type,
const std::string& quantized_data_type,
int m,
int n,
int stride,
float epsilon,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("input_data_type", input_data_type);
ADD_KEY_VALUE("quantized_data_type", quantized_data_type);
ADD_KEY_VALUE("m", m);
ADD_KEY_VALUE("n", n);
ADD_KEY_VALUE("stride", stride);
ADD_KEY_VALUE("epsilon", epsilon);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
END_JSON_DUMP_FILE();
}
void dump_smoothquant_json(const std::string& json_filename,
const std::string& prec_str,
int m,
int n,
int x_stride,
int y_stride,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "smoothquant")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec", prec_str);
ADD_KEY_VALUE("m", m);
ADD_KEY_VALUE("n", n);
ADD_KEY_VALUE("x_stride", x_stride);
ADD_KEY_VALUE("y_stride", y_stride);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
END_JSON_DUMP_FILE();
}
void dump_moe_sorting_json(const std::string& json_filename,
const std::string& index_prec,
const std::string& weight_prec,
const std::string& workspace_size,
int dispatch_policy,
int tokens,
int num_experts,
int topk,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "moe_sorting")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("index_prec", index_prec);
ADD_KEY_VALUE("weight_prec", weight_prec);
ADD_KEY_VALUE("workspace_size", workspace_size);
ADD_KEY_VALUE("dispatch_policy", dispatch_policy);
ADD_KEY_VALUE("tokens", tokens);
ADD_KEY_VALUE("num_experts", num_experts);
ADD_KEY_VALUE("topk", topk);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_batched_transpose_json(const std::string& json_filename,
int N,
int C,
int H,
int W,
const std::string& layout_in,
const std::string& layout_out,
const std::string& prec,
float ave_time,
float tflops,
float gb_per_sec,
bool pass,
const std::string& kernel_name = "batched_transpose")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("N", N);
ADD_KEY_VALUE("C", C);
ADD_KEY_VALUE("H", H);
ADD_KEY_VALUE("W", W);
ADD_KEY_VALUE("LayoutIn", layout_in);
ADD_KEY_VALUE("LayoutOut", layout_out);
ADD_KEY_VALUE("Precision", prec);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_moe_smoothquant_json(const std::string& json_filename,
const std::string& prec_i,
const std::string& prec_o,
int tokens,
int hidden_size,
int stride,
int experts,
int topk,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "moe_smoothquant")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec_i", prec_i);
ADD_KEY_VALUE("prec_o", prec_o);
ADD_KEY_VALUE("tokens", tokens);
ADD_KEY_VALUE("hidden_size", hidden_size);
ADD_KEY_VALUE("stride", stride);
ADD_KEY_VALUE("experts", experts);
ADD_KEY_VALUE("topk", topk);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_fused_moe_json(const std::string& json_filename,
const std::string& api_str,
const std::string& prec_str,
int tokens,
bool is_local_token,
int local_tokens,
int experts,
int topk,
int hidden_size,
int intermediate_size,
int stride,
int block_m,
int activation,
bool gate_only,
bool fused_quant,
bool pass,
float ave_time,
float tflops,
float tb_per_sec,
const std::string& kernel_name = "fused_moe")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("api", api_str);
ADD_KEY_VALUE("prec", prec_str);
ADD_KEY_VALUE("tokens", tokens);
if(is_local_token)
{
ADD_KEY_VALUE("local_tokens", local_tokens);
}
ADD_KEY_VALUE("experts", experts);
ADD_KEY_VALUE("topk", topk);
ADD_KEY_VALUE("hidden_size", hidden_size);
ADD_KEY_VALUE("intermediate_size", intermediate_size);
ADD_KEY_VALUE("stride", stride);
ADD_KEY_VALUE("block_m", block_m);
ADD_KEY_VALUE("activation", activation);
ADD_KEY_VALUE("gate_only", gate_only);
ADD_KEY_VALUE("fused_quant", fused_quant);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, (tb_per_sec * 1024.0f))
END_JSON_DUMP_FILE();
}
void dump_fmha_fwd_json_results(const std::string& json_filename,
const std::string& prec,
const std::string& mode,
const std::string& io_layout,
int batch,
int nhead,
int nhead_k,
int seqlen_qs,
int seqlen_ks,
int seqlen_kpads,
int hdim_q,
int hdim_v,
float scale_s,
float p_drop,
bool lse,
bool squant,
const std::string& bias,
const std::string& vlayout,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "fmha_fwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec", prec);
ADD_KEY_VALUE("mode", mode);
ADD_KEY_VALUE("io_layout", io_layout);
ADD_KEY_VALUE("batch", batch);
ADD_KEY_VALUE("nhead", nhead);
ADD_KEY_VALUE("nhead_k", nhead_k);
ADD_KEY_VALUE("seqlen_q", seqlen_qs);
ADD_KEY_VALUE("seqlen_k", seqlen_ks);
ADD_KEY_VALUE("seqlen_kpads", seqlen_kpads);
ADD_KEY_VALUE("hdim_q", hdim_q);
ADD_KEY_VALUE("hdim_v", hdim_v);
ADD_KEY_VALUE("scale_s", scale_s);
ADD_KEY_VALUE("p_drop", p_drop);
ADD_KEY_VALUE("lse", lse);
ADD_KEY_VALUE("squant", squant);
ADD_KEY_VALUE("bias", bias);
ADD_KEY_VALUE("vlayout", vlayout);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}
void dump_fmha_bwd_json_results(const std::string& json_filename,
const std::string& data_type,
const std::string& mode,
const std::string& i_perm,
const std::string& o_perm,
int batch,
int nhead,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
int hdim_v,
float scale,
const std::string& bias,
bool use_dbias,
float p_drop,
bool s_randval,
bool deterministic,
const std::string& mask,
int mask_left,
int mask_right,
int workspace_size,
bool pass,
float ave_time,
float tflops,
float gb_per_sec,
const std::string& kernel_name = "fmha_bwd")
{
START_JSON_DUMP_FILE(json_filename);
ADD_KEY_VALUE("name", kernel_name);
ADD_KEY_VALUE("prec", data_type);
ADD_KEY_VALUE("mode", mode);
ADD_KEY_VALUE("i_perm", i_perm);
ADD_KEY_VALUE("o_perm", o_perm);
ADD_KEY_VALUE("batch", batch);
ADD_KEY_VALUE("nhead", nhead);
ADD_KEY_VALUE("nhead_k", nhead_k);
ADD_KEY_VALUE("seqlen_q", seqlen_q);
ADD_KEY_VALUE("seqlen_k", seqlen_k);
ADD_KEY_VALUE("hdim_q", hdim_q);
ADD_KEY_VALUE("hdim_v", hdim_v);
ADD_KEY_VALUE("scale", scale);
ADD_KEY_VALUE("bias", bias);
ADD_KEY_VALUE("use_dbias", use_dbias);
ADD_KEY_VALUE("p_drop", p_drop);
ADD_KEY_VALUE("s_randval", s_randval);
ADD_KEY_VALUE("deterministic", deterministic ? "true" : "false");
ADD_KEY_VALUE("mask", mask);
ADD_KEY_VALUE("mask_left", mask_left);
ADD_KEY_VALUE("mask_right", mask_right);
ADD_KEY_VALUE("workspace_size", workspace_size);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
END_JSON_DUMP_FILE();
}