mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add json dump support to output details from CK/CKTile Examples. (#2551)
* Adding RapidJson Library * Adding Json Dumps in all CK_Tile Examples Not verified yet * Adding json to cktile Batched Transpose * adding json dumps to layernorm2d_fwd * Adding json dump to flatmm_basic * Adding RapidJson Library * Adding Json Dumps in all CK_Tile Examples Not verified yet * Adding json to cktile Batched Transpose * adding json dumps to layernorm2d_fwd * Adding json dump to flatmm_basic * Adding json in 03_gemm * Add json dump to 16_batched_gemm * Add json dump to gemm_multi_d_fp16 * Add json dump to grouped_gemm * fix fmha_bwd/fwd * Fix clang-format errors exclude include/rapidjson in jenkins as its a third-party library * Saparating function and defination. * Update Documentation of 03_gemm * Refactoring as per code review * Disable fp8 instances on unsupported targets (#2592) * Restrict building of gemm_universal_preshuffle_f8 instances to specific targets in CMakeLists.txt * Add condition to skip gemm_xdl_universal_preshuffle_f8 instances for unsupported targets in CMakeLists.txt * Add conditions to skip unsupported targets for gemm_universal_preshuffle_f8 and gemm_xdl_universal_preshuffle_f8 instances in CMakeLists.txt * Refine conditions to exclude gemm_universal_preshuffle_f8 instances for unsupported targets in CMakeLists.txt --------- Co-authored-by: AviralGoelAMD <aviralgoel@amd.com> * fix clang format * remove duplicate lines of code from library/src/tensor_operation_instance/gpu/CMakeLists.txt * Fixing Readme and unifying jsondumps * adding moe_smoothquant * adding fused_moe * Fixing Readme for batched_gemm * Fixing Readme for grouped_gemm * adding flatmm * adding gemm_multi_d_fp16 * adding elementwise * adding File name when json is dumped * Fixing Reduce after merge * adding batched_transpose * Adding Warptile in Gemm * Fixing Clang Format --------- Co-authored-by: Aviral Goel <aviral.goel@amd.com> Co-authored-by: AviralGoelAMD <aviralgoel@amd.com> Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/library/include
|
||||
${PROJECT_SOURCE_DIR}/example/include
|
||||
)
|
||||
|
||||
add_custom_target(examples)
|
||||
|
||||
@@ -74,6 +74,8 @@ args:
|
||||
-num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:fmha_fwd.json)
|
||||
```
|
||||
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "utils.hpp"
|
||||
#include "json_dump.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
@@ -94,7 +95,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("deterministic",
|
||||
"0",
|
||||
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
|
||||
"will not be used");
|
||||
"will not be used")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "fmha_bwd.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -584,53 +587,54 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
|
||||
<< " GB/s" << std::flush;
|
||||
|
||||
bool pass = true;
|
||||
if(!do_validation)
|
||||
{
|
||||
std::cout << std::flush << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
|
||||
std::vector<ck_tile::HostTensor<QDataType>> q_host_refs;
|
||||
std::vector<ck_tile::HostTensor<KDataType>> k_host_refs;
|
||||
std::vector<ck_tile::HostTensor<VDataType>> v_host_refs;
|
||||
std::vector<ck_tile::HostTensor<ODataType>> o_host_refs;
|
||||
std::vector<ck_tile::HostTensor<RandValOutputDataType>> randval_host_refs;
|
||||
std::vector<ck_tile::HostTensor<AccDataType>> p_hp_host_refs;
|
||||
std::vector<ck_tile::HostTensor<GemmDataType>> p_lp_host_refs;
|
||||
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
else
|
||||
{
|
||||
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
std::vector<ck_tile::HostTensor<QDataType>> q_host_refs;
|
||||
std::vector<ck_tile::HostTensor<KDataType>> k_host_refs;
|
||||
std::vector<ck_tile::HostTensor<VDataType>> v_host_refs;
|
||||
std::vector<ck_tile::HostTensor<ODataType>> o_host_refs;
|
||||
std::vector<ck_tile::HostTensor<RandValOutputDataType>> randval_host_refs;
|
||||
std::vector<ck_tile::HostTensor<AccDataType>> p_hp_host_refs;
|
||||
std::vector<ck_tile::HostTensor<GemmDataType>> p_lp_host_refs;
|
||||
|
||||
// adjust matrix index according to the mode
|
||||
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k
|
||||
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k
|
||||
ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n
|
||||
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o
|
||||
ck_tile::HostTensor<LSEDataType> lse_host_ref({nhead, real_seqlen_q}); // lse_g_m
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n
|
||||
ck_tile::HostTensor<AccDataType> s_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n
|
||||
ck_tile::HostTensor<AccDataType> p_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision
|
||||
ck_tile::HostTensor<AccDataType> p_dropped_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision
|
||||
ck_tile::HostTensor<GemmDataType> p_lp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
|
||||
ck_tile::index_t nr = nhead / nhead_k;
|
||||
// adjust matrix index according to the mode
|
||||
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t query_offset =
|
||||
(mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck_tile::index_t key_offset =
|
||||
(mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
|
||||
|
||||
// clang-format off
|
||||
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k
|
||||
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k
|
||||
ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n
|
||||
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o
|
||||
ck_tile::HostTensor<LSEDataType> lse_host_ref({nhead, real_seqlen_q}); // lse_g_m
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n
|
||||
ck_tile::HostTensor<AccDataType> s_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n
|
||||
ck_tile::HostTensor<AccDataType> p_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision
|
||||
ck_tile::HostTensor<AccDataType> p_dropped_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision
|
||||
ck_tile::HostTensor<GemmDataType> p_lp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision
|
||||
|
||||
ck_tile::index_t nr = nhead / nhead_k;
|
||||
|
||||
// clang-format off
|
||||
// permute
|
||||
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); });
|
||||
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
|
||||
@@ -642,281 +646,294 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); });
|
||||
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); });
|
||||
// clang-format on
|
||||
|
||||
// reference
|
||||
// S = scale * Q * K^T
|
||||
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType, AccDataType>(
|
||||
q_host_ref,
|
||||
k_host_ref,
|
||||
s_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k
|
||||
|
||||
if(bias.type == bias_enum::elementwise_bias)
|
||||
{
|
||||
// elementwise bias
|
||||
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
// clang-format off
|
||||
if(i_perm)
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); });
|
||||
else
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); });
|
||||
// clang-format on
|
||||
|
||||
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
|
||||
// real_seqlen_k]
|
||||
ck_tile::
|
||||
reference_batched_elementwise<AccDataType, BiasDataType, AccDataType, AccDataType>(
|
||||
s_host_ref, bias_host_ref, s_host_ref);
|
||||
}
|
||||
else if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
// alibi construct elementwise bias to verify
|
||||
auto alibi_host = [&]() {
|
||||
if(mask.type != mask_enum::no_mask)
|
||||
{
|
||||
return ck_tile::make_alibi_from_lr_mask<AccDataType, false>(
|
||||
0,
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
static_cast<ck_tile::GenericAttentionMaskEnum>(mask.type));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::Alibi<AccDataType, false>{
|
||||
0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT};
|
||||
}
|
||||
}();
|
||||
// reference
|
||||
// S = scale * Q * K^T
|
||||
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType, AccDataType>(
|
||||
q_host_ref,
|
||||
k_host_ref,
|
||||
s_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k
|
||||
|
||||
ck_tile::HostTensor<AccDataType> alibi_bias_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k});
|
||||
auto i_b_slope = bias.rank_info == 0 ? 0 : wb;
|
||||
for(auto i_h = 0; i_h < nhead; i_h++)
|
||||
if(bias.type == bias_enum::elementwise_bias)
|
||||
{
|
||||
AccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
|
||||
alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? current_slope
|
||||
: -current_slope;
|
||||
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
|
||||
{
|
||||
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
|
||||
// elementwise bias
|
||||
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
// clang-format off
|
||||
if(i_perm)
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); });
|
||||
else
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); });
|
||||
// clang-format on
|
||||
|
||||
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
|
||||
// real_seqlen_k]
|
||||
ck_tile::reference_batched_elementwise<AccDataType,
|
||||
BiasDataType,
|
||||
AccDataType,
|
||||
AccDataType>(
|
||||
s_host_ref, bias_host_ref, s_host_ref);
|
||||
}
|
||||
else if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
// alibi construct elementwise bias to verify
|
||||
auto alibi_host = [&]() {
|
||||
if(mask.type != mask_enum::no_mask)
|
||||
{
|
||||
AccDataType pixel = 0;
|
||||
alibi_host.update(pixel, i_r, i_c);
|
||||
alibi_bias_host_ref(i_h, i_r, i_c) = pixel;
|
||||
return ck_tile::make_alibi_from_lr_mask<AccDataType, false>(
|
||||
0,
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
static_cast<ck_tile::GenericAttentionMaskEnum>(mask.type));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::Alibi<AccDataType, false>{
|
||||
0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT};
|
||||
}
|
||||
}();
|
||||
|
||||
ck_tile::HostTensor<AccDataType> alibi_bias_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k});
|
||||
auto i_b_slope = bias.rank_info == 0 ? 0 : wb;
|
||||
for(auto i_h = 0; i_h < nhead; i_h++)
|
||||
{
|
||||
AccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
|
||||
alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL
|
||||
? current_slope
|
||||
: -current_slope;
|
||||
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
|
||||
{
|
||||
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
|
||||
{
|
||||
AccDataType pixel = 0;
|
||||
alibi_host.update(pixel, i_r, i_c);
|
||||
alibi_bias_host_ref(i_h, i_r, i_c) = pixel;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// [nhead, real_seqlen_q, real_seqlen_k]
|
||||
ck_tile::
|
||||
reference_batched_elementwise<AccDataType, AccDataType, AccDataType, AccDataType>(
|
||||
// [nhead, real_seqlen_q, real_seqlen_k]
|
||||
ck_tile::reference_batched_elementwise<AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType>(
|
||||
s_host_ref, alibi_bias_host_ref, s_host_ref);
|
||||
}
|
||||
}
|
||||
|
||||
if(mask.type == mask_enum::no_mask)
|
||||
{
|
||||
ck_tile::reference_batched_masking<AccDataType>(
|
||||
s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
|
||||
}
|
||||
else if(mask.type == mask_enum::window_generic)
|
||||
{
|
||||
ck_tile::reference_batched_masking<AccDataType>(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
|
||||
}
|
||||
else
|
||||
{
|
||||
// if left window size is negative, means causal
|
||||
// else means generic (for current batch)
|
||||
if(mask.left < 0)
|
||||
if(mask.type == mask_enum::no_mask)
|
||||
{
|
||||
ck_tile::reference_batched_masking<AccDataType>(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
else
|
||||
s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
|
||||
}
|
||||
else if(mask.type == mask_enum::window_generic)
|
||||
{
|
||||
ck_tile::reference_batched_masking<AccDataType>(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
ck_tile::reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
|
||||
s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref);
|
||||
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
|
||||
}
|
||||
else
|
||||
{
|
||||
// if left window size is negative, means causal
|
||||
// else means generic (for current batch)
|
||||
if(mask.left < 0)
|
||||
ck_tile::reference_batched_masking<AccDataType>(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
else
|
||||
ck_tile::reference_batched_masking<AccDataType>(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
ck_tile::reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
|
||||
s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref);
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
p_dropped_hp_host_ref = p_hp_host_ref;
|
||||
randval_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
|
||||
});
|
||||
ck_tile::reference_batched_dropout(
|
||||
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
}
|
||||
else
|
||||
{
|
||||
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
}
|
||||
if(p_drop > 0)
|
||||
{
|
||||
p_dropped_hp_host_ref = p_hp_host_ref;
|
||||
randval_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
|
||||
});
|
||||
ck_tile::reference_batched_dropout(
|
||||
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
}
|
||||
else
|
||||
{
|
||||
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
}
|
||||
|
||||
// O = P * V
|
||||
ck_tile::reference_batched_gemm<GemmDataType, VDataType, AccDataType, ODataType>(
|
||||
p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n
|
||||
// O = P * V
|
||||
ck_tile::reference_batched_gemm<GemmDataType, VDataType, AccDataType, ODataType>(
|
||||
p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n
|
||||
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
// permute
|
||||
if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
|
||||
else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });
|
||||
|
||||
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); });
|
||||
// clang-format on
|
||||
// clang-format on
|
||||
|
||||
q_host_refs.push_back(q_host_ref);
|
||||
k_host_refs.push_back(k_host_ref);
|
||||
v_host_refs.push_back(v_host_ref);
|
||||
o_host_refs.push_back(o_host_ref);
|
||||
p_hp_host_refs.push_back(p_hp_host_ref);
|
||||
p_lp_host_refs.push_back(p_lp_host_ref);
|
||||
if(p_drop > 0)
|
||||
{
|
||||
randval_host_refs.push_back(randval_host_ref);
|
||||
q_host_refs.push_back(q_host_ref);
|
||||
k_host_refs.push_back(k_host_ref);
|
||||
v_host_refs.push_back(v_host_ref);
|
||||
o_host_refs.push_back(o_host_ref);
|
||||
p_hp_host_refs.push_back(p_hp_host_ref);
|
||||
p_lp_host_refs.push_back(p_lp_host_ref);
|
||||
if(p_drop > 0)
|
||||
{
|
||||
randval_host_refs.push_back(randval_host_ref);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// set to bad values to check if the kernel writes to these buffers
|
||||
ck_tile::FillConstant<QGradDataType>{ck_tile::numeric<QGradDataType>::infinity()}(dq_host);
|
||||
ck_tile::FillConstant<KGradDataType>{ck_tile::numeric<KGradDataType>::infinity()}(dk_host);
|
||||
ck_tile::FillConstant<VGradDataType>{ck_tile::numeric<VGradDataType>::infinity()}(dv_host);
|
||||
dq_buf.ToDevice(dq_host.data());
|
||||
dk_buf.ToDevice(dk_host.data());
|
||||
dv_buf.ToDevice(dv_host.data());
|
||||
// set to bad values to check if the kernel writes to these buffers
|
||||
ck_tile::FillConstant<QGradDataType>{ck_tile::numeric<QGradDataType>::infinity()}(dq_host);
|
||||
ck_tile::FillConstant<KGradDataType>{ck_tile::numeric<KGradDataType>::infinity()}(dk_host);
|
||||
ck_tile::FillConstant<VGradDataType>{ck_tile::numeric<VGradDataType>::infinity()}(dv_host);
|
||||
dq_buf.ToDevice(dq_host.data());
|
||||
dk_buf.ToDevice(dk_host.data());
|
||||
dv_buf.ToDevice(dv_host.data());
|
||||
|
||||
o_buf.ToDevice(o_host.data());
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
dq_buf.SetZero();
|
||||
dbias_buf.SetZero();
|
||||
dq_acc_buf.SetZero();
|
||||
o_buf.ToDevice(o_host.data());
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
dq_buf.SetZero();
|
||||
dbias_buf.SetZero();
|
||||
dq_acc_buf.SetZero();
|
||||
|
||||
ck_tile::stream_config stream_config_v{
|
||||
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
|
||||
fmha_bwd(fmha_traits, fmha_args, stream_config_v);
|
||||
ck_tile::stream_config stream_config_v{
|
||||
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
|
||||
fmha_bwd(fmha_traits, fmha_args, stream_config_v);
|
||||
|
||||
dq_buf.FromDevice(dq_host.data());
|
||||
dk_buf.FromDevice(dk_host.data());
|
||||
dv_buf.FromDevice(dv_host.data());
|
||||
dbias_buf.FromDevice(dbias_host.data());
|
||||
dq_buf.FromDevice(dq_host.data());
|
||||
dk_buf.FromDevice(dk_host.data());
|
||||
dv_buf.FromDevice(dv_host.data());
|
||||
dbias_buf.FromDevice(dbias_host.data());
|
||||
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
|
||||
// adjust matrix index according to the mode
|
||||
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
|
||||
// adjust matrix index according to the mode
|
||||
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t query_offset =
|
||||
(mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck_tile::index_t key_offset =
|
||||
(mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
|
||||
|
||||
ck_tile::HostTensor<OGradDataType> do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o
|
||||
ck_tile::HostTensor<AccDataType> ds_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n high precision
|
||||
ck_tile::HostTensor<GemmDataType> ds_lp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision
|
||||
ck_tile::HostTensor<AccDataType> dp_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision
|
||||
ck_tile::HostTensor<BiasGradDataType> dbias_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
|
||||
ck_tile::HostTensor<QGradDataType> dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
|
||||
ck_tile::HostTensor<KGradDataType> dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
|
||||
ck_tile::HostTensor<VGradDataType> dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
|
||||
ck_tile::HostTensor<OGradDataType> do_host_ref(
|
||||
{nhead, real_seqlen_q, hdim_v}); // do_g_m_o
|
||||
ck_tile::HostTensor<AccDataType> ds_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n high precision
|
||||
ck_tile::HostTensor<GemmDataType> ds_lp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision
|
||||
ck_tile::HostTensor<AccDataType> dp_hp_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision
|
||||
ck_tile::HostTensor<BiasGradDataType> dbias_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
|
||||
ck_tile::HostTensor<QGradDataType> dq_host_ref(
|
||||
{nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
|
||||
ck_tile::HostTensor<KGradDataType> dk_host_ref(
|
||||
{nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
|
||||
ck_tile::HostTensor<VGradDataType> dv_host_ref(
|
||||
{nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
|
||||
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); });
|
||||
else do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[1] + query_offset, i[0], i[2]); });
|
||||
// clang-format on
|
||||
// clang-format on
|
||||
|
||||
// dP = dO@V x Z w/ dropout
|
||||
// dP = dO@V w/o dropout
|
||||
auto v_t_host_ref = v_host_refs[wb].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o
|
||||
ck_tile::reference_batched_gemm<OGradDataType, VDataType, AccDataType, AccDataType>(
|
||||
do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o
|
||||
// dP = dO@V x Z w/ dropout
|
||||
// dP = dO@V w/o dropout
|
||||
auto v_t_host_ref = v_host_refs[wb].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o
|
||||
ck_tile::reference_batched_gemm<OGradDataType, VDataType, AccDataType, AccDataType>(
|
||||
do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
ck_tile::reference_batched_dropout(
|
||||
dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop);
|
||||
}
|
||||
if(p_drop > 0)
|
||||
{
|
||||
ck_tile::reference_batched_dropout(
|
||||
dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop);
|
||||
}
|
||||
|
||||
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
|
||||
ck_tile::make_ParallelTensorFunctor(
|
||||
[&](auto i0, auto i1, auto i2) {
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[wb](i0, i1, o));
|
||||
}
|
||||
ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert<AccDataType>(
|
||||
p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o));
|
||||
},
|
||||
ds_hp_host_ref.mDesc.get_lengths()[0],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[1],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());
|
||||
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
|
||||
ck_tile::make_ParallelTensorFunctor(
|
||||
[&](auto i0, auto i1, auto i2) {
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[wb](i0, i1, o));
|
||||
}
|
||||
ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert<AccDataType>(
|
||||
p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o));
|
||||
},
|
||||
ds_hp_host_ref.mDesc.get_lengths()[0],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[1],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());
|
||||
|
||||
if(use_dbias)
|
||||
{
|
||||
dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
|
||||
}
|
||||
if(use_dbias)
|
||||
{
|
||||
dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
|
||||
}
|
||||
|
||||
ds_lp_host_ref = ds_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
ds_lp_host_ref = ds_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
|
||||
// dV = P_drop^T@dO^T
|
||||
// dV = P^T@dO^T w/o dropout
|
||||
auto p_t_lp_host_ref = p_lp_host_refs[wb].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m
|
||||
auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m
|
||||
ck_tile::reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
|
||||
p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
|
||||
// dV = P_drop^T@dO^T
|
||||
// dV = P^T@dO^T w/o dropout
|
||||
auto p_t_lp_host_ref =
|
||||
p_lp_host_refs[wb].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m
|
||||
auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m
|
||||
ck_tile::
|
||||
reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
|
||||
p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
|
||||
|
||||
// dQ = scale * dS@K^T
|
||||
auto k_t_host_ref = k_host_refs[wb].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
|
||||
ck_tile::reference_batched_gemm<GemmDataType, KDataType, AccDataType, QGradDataType>(
|
||||
ds_lp_host_ref,
|
||||
k_t_host_ref,
|
||||
dq_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n
|
||||
// dQ = scale * dS@K^T
|
||||
auto k_t_host_ref = k_host_refs[wb].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
|
||||
ck_tile::reference_batched_gemm<GemmDataType, KDataType, AccDataType, QGradDataType>(
|
||||
ds_lp_host_ref,
|
||||
k_t_host_ref,
|
||||
dq_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n
|
||||
|
||||
// dK = scale * dS^T@Q^T
|
||||
auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
|
||||
auto q_t_host_ref = q_host_refs[wb].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m
|
||||
ck_tile::reference_batched_gemm<GemmDataType, QDataType, AccDataType, KGradDataType>(
|
||||
ds_t_lp_host_ref,
|
||||
q_t_host_ref,
|
||||
dk_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m
|
||||
// dK = scale * dS^T@Q^T
|
||||
auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
|
||||
auto q_t_host_ref = q_host_refs[wb].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m
|
||||
ck_tile::reference_batched_gemm<GemmDataType, QDataType, AccDataType, KGradDataType>(
|
||||
ds_t_lp_host_ref,
|
||||
q_t_host_ref,
|
||||
dk_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m
|
||||
|
||||
ck_tile::HostTensor<QGradDataType> dq_host_result(
|
||||
{nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
|
||||
ck_tile::HostTensor<KGradDataType> dk_host_result(
|
||||
{nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
|
||||
ck_tile::HostTensor<VGradDataType> dv_host_result(
|
||||
{nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
|
||||
ck_tile::HostTensor<BiasGradDataType> dbias_host_result(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
|
||||
ck_tile::HostTensor<QGradDataType> dq_host_result(
|
||||
{nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
|
||||
ck_tile::HostTensor<KGradDataType> dk_host_result(
|
||||
{nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
|
||||
ck_tile::HostTensor<VGradDataType> dv_host_result(
|
||||
{nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
|
||||
ck_tile::HostTensor<BiasGradDataType> dbias_host_result(
|
||||
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
|
||||
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
// permute
|
||||
if(i_perm) dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[0], idx[1] + query_offset, idx[2]); });
|
||||
else dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
@@ -932,49 +949,90 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); });
|
||||
else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
}
|
||||
// clang-format on
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
|
||||
bool dq_cur_pass = ck_tile::check_err(dq_host_result,
|
||||
dq_host_ref,
|
||||
std::string("Error: QGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
bool dk_cur_pass = ck_tile::check_err(dk_host_result,
|
||||
dk_host_ref,
|
||||
std::string("Error: KGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
bool dv_cur_pass = ck_tile::check_err(dv_host_result,
|
||||
dv_host_ref,
|
||||
std::string("Error: VGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
|
||||
bool dq_cur_pass = ck_tile::check_err(dq_host_result,
|
||||
dq_host_ref,
|
||||
std::string("Error: QGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
bool dk_cur_pass = ck_tile::check_err(dk_host_result,
|
||||
dk_host_ref,
|
||||
std::string("Error: KGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
bool dv_cur_pass = ck_tile::check_err(dv_host_result,
|
||||
dv_host_ref,
|
||||
std::string("Error: VGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
|
||||
bool dbias_cur_pass = true;
|
||||
if(use_dbias)
|
||||
{
|
||||
dbias_cur_pass = ck_tile::check_err(dbias_host_result,
|
||||
dbias_host_ref,
|
||||
std::string("Error: BiasGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
bool dbias_cur_pass = true;
|
||||
if(use_dbias)
|
||||
{
|
||||
dbias_cur_pass =
|
||||
ck_tile::check_err(dbias_host_result,
|
||||
dbias_host_ref,
|
||||
std::string("Error: BiasGrad Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass);
|
||||
if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass))
|
||||
{
|
||||
std::cerr << "mismatch found at batch: " << wb << std::endl
|
||||
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
||||
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
||||
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
||||
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass);
|
||||
if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass))
|
||||
{
|
||||
std::cerr << "mismatch found at batch: " << wb << std::endl
|
||||
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
||||
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
||||
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
||||
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
|
||||
|
||||
break;
|
||||
}
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_fmha_bwd_json_results(
|
||||
arg_parser.get_str("jsonfile"),
|
||||
data_type,
|
||||
mode == mode_enum::batch ? "batch" : "group",
|
||||
i_perm ? "true" : "false",
|
||||
o_perm ? "true" : "false",
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
scale,
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? "elementwise_bias"
|
||||
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
|
||||
use_dbias ? "true" : "false",
|
||||
p_drop,
|
||||
s_randval,
|
||||
deterministic,
|
||||
mask.type == mask_enum::no_mask
|
||||
? "no_mask"
|
||||
: (mask.type == mask_enum::window_generic
|
||||
? "window_generic"
|
||||
: (mask.type == mask_enum::mask_top_left
|
||||
? "mask_top_left"
|
||||
: (mask.type == mask_enum::mask_bottom_right ? "mask_bottom_right"
|
||||
: "mask_generic"))),
|
||||
mask.left,
|
||||
mask.right,
|
||||
workspace_size,
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "mask.hpp"
|
||||
#include "rotary.hpp"
|
||||
#include "utils.hpp"
|
||||
#include "json_dump.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
@@ -138,7 +139,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe")
|
||||
.insert("cache_batch_idx", "0", "whether to use index map to the kvcache")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "fmha_fwd.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -1137,12 +1140,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
|
||||
<< " GB/s" << std::flush << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
if(do_validation == 0)
|
||||
{
|
||||
std::cout << std::flush << std::endl;
|
||||
return true;
|
||||
}
|
||||
if(do_validation == 2)
|
||||
else if(do_validation == 2)
|
||||
{
|
||||
// NOTE: use gpu to do validation
|
||||
ck_tile::naive_attention_fwd_traits naive_t;
|
||||
@@ -1188,64 +1191,67 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
o_buf.FromDevice(o_host.data()); // TODO: ugly
|
||||
|
||||
auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool pass_ = ck_tile::check_err(
|
||||
pass = ck_tile::check_err(
|
||||
o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
|
||||
std::cout << ", valid:" << (pass_ ? "y" : "n") << std::flush << std::endl;
|
||||
return pass_;
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
o_buf.FromDevice(o_host.data());
|
||||
lse_buf.FromDevice(lse_host.data());
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
auto p_compute_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
|
||||
return ck_tile::scales{scale_p};
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o});
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
uint8_t p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
float rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
bool pass = true;
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
else
|
||||
{
|
||||
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
|
||||
// adjust matrix index according to the mode
|
||||
const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t cache_b_idx =
|
||||
(use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
|
||||
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck_tile::index_t key_offset =
|
||||
(mode == mode_enum::batch
|
||||
? 0
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb]));
|
||||
o_buf.FromDevice(o_host.data());
|
||||
lse_buf.FromDevice(lse_host.data());
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
|
||||
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
|
||||
ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k});
|
||||
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});
|
||||
auto p_compute_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
|
||||
return ck_tile::scales{scale_p};
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
ck_tile::HostTensor<SMPLComputeDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
||||
ck_tile::HostTensor<PDataType> p_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_ref({nhead, real_seqlen_q});
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o});
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
ck_tile::index_t nr = nhead / nhead_k;
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
uint8_t p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
float rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
// clang-format off
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
|
||||
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
|
||||
// adjust matrix index according to the mode
|
||||
const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t cache_b_idx =
|
||||
(use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
|
||||
const ck_tile::index_t query_offset =
|
||||
(mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck_tile::index_t key_offset =
|
||||
(mode == mode_enum::batch
|
||||
? 0
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host[wb]
|
||||
: seqstart_k_with_padding_host[wb]));
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
|
||||
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
|
||||
ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k});
|
||||
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});
|
||||
|
||||
ck_tile::HostTensor<SMPLComputeDataType> s_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k});
|
||||
ck_tile::HostTensor<PDataType> p_host_ref({nhead, real_seqlen_q, real_seqlen_k});
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_ref({nhead, real_seqlen_q});
|
||||
|
||||
ck_tile::index_t nr = nhead / nhead_k;
|
||||
|
||||
// clang-format off
|
||||
// permute
|
||||
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); });
|
||||
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); });
|
||||
@@ -1379,198 +1385,179 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
});
|
||||
}
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
// reference
|
||||
ck_tile::reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
|
||||
q_host_ref,
|
||||
k_host_ref,
|
||||
s_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale_s));
|
||||
|
||||
if(0.f < logits_soft_cap)
|
||||
{
|
||||
ck_tile::reference_unary_elementwise<SaccDataType, SaccDataType, SaccDataType>(
|
||||
s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) {
|
||||
return ck_tile::type_convert<SaccDataType>(
|
||||
logits_soft_cap *
|
||||
std::tanhf(ck_tile::type_convert<float>(logits / logits_soft_cap)));
|
||||
});
|
||||
}
|
||||
|
||||
if(bias.type == bias_enum::elementwise_bias)
|
||||
{
|
||||
// elementwise bias
|
||||
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
// clang-format off
|
||||
if(i_perm)
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); });
|
||||
else
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); });
|
||||
// clang-format on
|
||||
|
||||
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
|
||||
// real_seqlen_k]
|
||||
ck_tile::reference_batched_elementwise<SMPLComputeDataType,
|
||||
BiasDataType,
|
||||
SMPLComputeDataType,
|
||||
SMPLComputeDataType>(
|
||||
s_host_ref, bias_host_ref, s_host_ref);
|
||||
}
|
||||
else if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
// alibi construct elementwise bias to verify
|
||||
auto alibi_host = [&]() {
|
||||
if(mask.type != mask_enum::no_mask)
|
||||
{
|
||||
return ck_tile::make_alibi_from_lr_mask<SaccDataType, true>(
|
||||
0,
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
static_cast<ck_tile::GenericAttentionMaskEnum>(mask.type));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::Alibi<SaccDataType, true>{
|
||||
0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT};
|
||||
}
|
||||
}();
|
||||
// reference
|
||||
ck_tile::
|
||||
reference_batched_gemm<QDataType, KDataType, SaccDataType, SMPLComputeDataType>(
|
||||
q_host_ref,
|
||||
k_host_ref,
|
||||
s_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale_s));
|
||||
|
||||
ck_tile::HostTensor<SaccDataType> alibi_bias_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k});
|
||||
auto i_b_slope = bias.rank_info == 0 ? 0 : wb;
|
||||
for(auto i_h = 0; i_h < nhead; i_h++)
|
||||
if(0.f < logits_soft_cap)
|
||||
{
|
||||
SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
|
||||
alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? current_slope
|
||||
: -current_slope;
|
||||
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
|
||||
{
|
||||
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
|
||||
ck_tile::reference_unary_elementwise<SaccDataType, SaccDataType, SaccDataType>(
|
||||
s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) {
|
||||
return ck_tile::type_convert<SaccDataType>(
|
||||
logits_soft_cap *
|
||||
std::tanhf(ck_tile::type_convert<float>(logits / logits_soft_cap)));
|
||||
});
|
||||
}
|
||||
|
||||
if(bias.type == bias_enum::elementwise_bias)
|
||||
{
|
||||
// elementwise bias
|
||||
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
// clang-format off
|
||||
if(i_perm)
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); });
|
||||
else
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); });
|
||||
// clang-format on
|
||||
|
||||
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
|
||||
// real_seqlen_k]
|
||||
ck_tile::reference_batched_elementwise<SMPLComputeDataType,
|
||||
BiasDataType,
|
||||
SMPLComputeDataType,
|
||||
SMPLComputeDataType>(
|
||||
s_host_ref, bias_host_ref, s_host_ref);
|
||||
}
|
||||
else if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
// alibi construct elementwise bias to verify
|
||||
auto alibi_host = [&]() {
|
||||
if(mask.type != mask_enum::no_mask)
|
||||
{
|
||||
SaccDataType pixel = 0;
|
||||
alibi_host.update(pixel, i_r, i_c);
|
||||
alibi_bias_host_ref(i_h, i_r, i_c) = pixel;
|
||||
return ck_tile::make_alibi_from_lr_mask<SaccDataType, true>(
|
||||
0,
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
static_cast<ck_tile::GenericAttentionMaskEnum>(mask.type));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::Alibi<SaccDataType, true>{
|
||||
0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT};
|
||||
}
|
||||
}();
|
||||
|
||||
ck_tile::HostTensor<SaccDataType> alibi_bias_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k});
|
||||
auto i_b_slope = bias.rank_info == 0 ? 0 : wb;
|
||||
for(auto i_h = 0; i_h < nhead; i_h++)
|
||||
{
|
||||
SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
|
||||
alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL
|
||||
? current_slope
|
||||
: -current_slope;
|
||||
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
|
||||
{
|
||||
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
|
||||
{
|
||||
SaccDataType pixel = 0;
|
||||
alibi_host.update(pixel, i_r, i_c);
|
||||
alibi_bias_host_ref(i_h, i_r, i_c) = pixel;
|
||||
}
|
||||
}
|
||||
}
|
||||
// [nhead, real_seqlen_q, real_seqlen_k]
|
||||
ck_tile::reference_batched_elementwise<SMPLComputeDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
SMPLComputeDataType>(
|
||||
s_host_ref, alibi_bias_host_ref, s_host_ref);
|
||||
}
|
||||
// [nhead, real_seqlen_q, real_seqlen_k]
|
||||
ck_tile::reference_batched_elementwise<SMPLComputeDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
SMPLComputeDataType>(
|
||||
s_host_ref, alibi_bias_host_ref, s_host_ref);
|
||||
}
|
||||
|
||||
if(mask.type == mask_enum::no_mask)
|
||||
{
|
||||
ck_tile::reference_batched_masking<SaccDataType>(
|
||||
s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
|
||||
}
|
||||
else if(mask.type == mask_enum::window_generic)
|
||||
{
|
||||
ck_tile::reference_batched_masking<SaccDataType>(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
|
||||
}
|
||||
else
|
||||
{
|
||||
// if left window size is negative, means causal
|
||||
// else means generic (for current batch)
|
||||
if(mask.left < 0)
|
||||
if(mask.type == mask_enum::no_mask)
|
||||
{
|
||||
ck_tile::reference_batched_masking<SaccDataType>(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
else
|
||||
s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
|
||||
}
|
||||
else if(mask.type == mask_enum::window_generic)
|
||||
{
|
||||
ck_tile::reference_batched_masking<SaccDataType>(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func);
|
||||
}
|
||||
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
|
||||
}
|
||||
else
|
||||
{
|
||||
// if left window size is negative, means causal
|
||||
// else means generic (for current batch)
|
||||
if(mask.left < 0)
|
||||
ck_tile::reference_batched_masking<SaccDataType>(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
else
|
||||
ck_tile::reference_batched_masking<SaccDataType>(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func);
|
||||
}
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k});
|
||||
randval_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]);
|
||||
});
|
||||
ck_tile::reference_batched_dropout(
|
||||
p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
}
|
||||
if(p_drop > 0)
|
||||
{
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k});
|
||||
randval_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]);
|
||||
});
|
||||
ck_tile::reference_batched_dropout(
|
||||
p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
}
|
||||
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
v_host_ref,
|
||||
o_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
oacc_element_func);
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
v_host_ref,
|
||||
o_host_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
oacc_element_func);
|
||||
|
||||
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
|
||||
// clang-format off
|
||||
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
|
||||
// clang-format off
|
||||
// permute
|
||||
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); });
|
||||
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool cur_pass = ck_tile::check_err(
|
||||
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
pass &= cur_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
std::cerr << "OUT mismatch found at batch: " << wb << std::endl
|
||||
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
||||
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
||||
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
||||
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
lse_host_result.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
|
||||
});
|
||||
|
||||
cur_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
rtol,
|
||||
atol,
|
||||
/* allow_infinity_ref = */ true);
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool cur_pass = ck_tile::check_err(o_host_result,
|
||||
o_host_ref,
|
||||
std::string("OUT Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
pass &= cur_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
|
||||
std::cerr << "OUT mismatch found at batch: " << wb << std::endl
|
||||
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
||||
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
||||
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
||||
@@ -1578,10 +1565,65 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
|
||||
lse_host_result.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
|
||||
});
|
||||
|
||||
cur_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
rtol,
|
||||
atol,
|
||||
/* allow_infinity_ref = */ true);
|
||||
|
||||
pass &= cur_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
|
||||
<< "\tseqlen_q: " << real_seqlen_q << std::endl
|
||||
<< "\tseqlen_k: " << real_seqlen_k << std::endl
|
||||
<< "\tseqstart_q: " << seqstart_q_host << std::endl
|
||||
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_fmha_fwd_json_results(arg_parser.get_str("jsonfile"),
|
||||
prec,
|
||||
mode == mode_enum::batch ? "batch" : "group",
|
||||
io_layout(i_perm, o_perm),
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_qs[0],
|
||||
seqlen_ks[0],
|
||||
seqlen_kpads[0],
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
scale_s,
|
||||
p_drop,
|
||||
lse,
|
||||
squant,
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? "elementwise_bias"
|
||||
: (bias.type == bias_enum::alibi ? "alibi" : "no_bias"),
|
||||
vlayout,
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
@@ -65,6 +65,8 @@ args:
|
||||
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
|
||||
-warmup cold iter (default:5)
|
||||
-repeat hot iter (default:20)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:layernorm2d_fwd.json)
|
||||
|
||||
```
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
@@ -53,7 +54,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "layernorm2d_fwd.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -405,6 +408,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_layernorm2d_fwd_json_results(arg_parser.get_str("jsonfile"),
|
||||
prec_i,
|
||||
prec_o,
|
||||
prec_sm,
|
||||
prec_sy,
|
||||
m,
|
||||
n,
|
||||
x_stride,
|
||||
xr_stride,
|
||||
y_stride,
|
||||
yr_stride,
|
||||
pass,
|
||||
ave_time,
|
||||
0,
|
||||
gb_per_sec);
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -9,11 +9,11 @@ mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# The basic pipeline method on the gemm calculation
|
||||
make tile_example_gemm_basic -j
|
||||
make tile_example_gemm_basic -j`nproc`
|
||||
# The memory bound pipeline on the gemm calculation
|
||||
make tile_example_gemm_universal -j
|
||||
make tile_example_gemm_universal -j`nproc`
|
||||
# The weight preshuffle pipeline on the gemm calculation
|
||||
make tile_example_gemm_weight_preshuffle -j
|
||||
make tile_example_gemm_weight_preshuffle -j`nproc`
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_basic` & `build/bin/tile_example_gemm_universal`
|
||||
|
||||
@@ -30,11 +30,13 @@ args:
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
|
||||
-prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16)
|
||||
-warmup number of iterations before benchmark the kernel (default:10)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-warmup number of iterations before benchmark the kernel (default:50)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-split_k splitK value (default:1)
|
||||
-init 0:random, 1:linear, 2:constant (default:1)
|
||||
-init 0:random, 1:linear, 2:constant(1) (default:0)
|
||||
-persistent 0:non-persistent, 1:persistent (default:0)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:gemm.json)
|
||||
```
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
@@ -493,6 +494,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("persistent", "0", "0:non-persistent, 1:persistent")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "gemm.json", "json file name to dump results")
|
||||
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
|
||||
.insert("rotating_count", "1000", "rotating count, defaults to 1000");
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
@@ -236,23 +235,6 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm kernel with \n M=" << M << " N=" << N << " K=" << K
|
||||
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
|
||||
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
|
||||
<< " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
|
||||
<< " B_Type=" << DataTypeTraits<BDataType>::name
|
||||
<< " C_Type=" << DataTypeTraits<CDataType>::name
|
||||
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
|
||||
<< " Persistent=" << (persistent ? "on" : "off") << " : \n"
|
||||
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -416,32 +398,49 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
persistent,
|
||||
flush_cache,
|
||||
rotating_count);
|
||||
float ave_time = invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
persistent,
|
||||
flush_cache,
|
||||
rotating_count);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
|
||||
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
|
||||
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
|
||||
<< " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
|
||||
<< " B_Type=" << DataTypeTraits<BDataType>::name
|
||||
<< " C_Type=" << DataTypeTraits<CDataType>::name
|
||||
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
|
||||
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
// memory on host to store gpu reference result
|
||||
@@ -496,5 +495,28 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU");
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_gemm_json_results<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
GemmConfig,
|
||||
DataTypeTraits>(arg_parser.get_str("jsonfile"),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
persistent,
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
@@ -3,8 +3,24 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include <cstring>
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
@@ -14,8 +30,10 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("c", "512", "c dimension")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "0", "cold iter")
|
||||
.insert("repeat", "1", "hot iter");
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "reduce.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -126,6 +144,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_reduce_json_results<DataType, DataTypeTraits>(
|
||||
arg_parser.get_str("jsonfile"), N, C, H, W, pass, ave_time, 0, gb_per_sec);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "permute.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
@@ -127,7 +127,8 @@ auto create_args(int argc, char* argv[])
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
"non-deterministic seed")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
.insert("jsonfile", "permute.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -382,6 +383,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_permute_json_results(arg_parser.get_str("jsonfile"), data_type, pass, ave_time, 0, 0);
|
||||
}
|
||||
|
||||
std::cout << std::endl;
|
||||
|
||||
return pass;
|
||||
|
||||
@@ -24,5 +24,7 @@ args:
|
||||
-st_o row stride of output/indices, -1 means same as topk (default:-1)
|
||||
-seed seed to be used, -1 means random every time (default:-1)
|
||||
-kname when set to 1 it will print kernel name (default:0)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:topk_softmax.json)
|
||||
|
||||
```
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "topk_softmax_api.hpp"
|
||||
#include "json_dump.hpp"
|
||||
|
||||
#if 0
|
||||
template <typename T>
|
||||
@@ -130,7 +131,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "when set to 1 it will print kernel name")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "topk_softmax.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -273,6 +276,23 @@ bool test_topk_softmax(ck_tile::ArgParser args)
|
||||
}
|
||||
|
||||
printf("valid:%s\n", rtn ? "y" : "n");
|
||||
|
||||
if(args.get_int("json") == 1)
|
||||
{
|
||||
dump_topk_softmax_json(args.get_str("jsonfile"),
|
||||
input_prec,
|
||||
weight_prec,
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride_input,
|
||||
stride_output,
|
||||
ms,
|
||||
0,
|
||||
0,
|
||||
rtn);
|
||||
}
|
||||
|
||||
fflush(stdout);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "topk_softmax_api.hpp"
|
||||
|
||||
|
||||
@@ -6,17 +6,34 @@ This folder contains example for Rmsnorm2D forward using ck_tile tile-programmin
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_rmsnorm2d_fwd -j
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_rmsnorm2d_fwd -j`nproc`
|
||||
```
|
||||
This will result in an executable `build/bin/tile_rmsnorm2d_fwd`
|
||||
|
||||
## cmdline
|
||||
```
|
||||
args:
|
||||
-m m dimension (default:3328)
|
||||
-n m dimension (default:4096)
|
||||
-e epsilon (default:1e-5)
|
||||
-v cpu validation or not (default:1)
|
||||
-prec precision (default:fp16)
|
||||
-m m dimension (default:3328)
|
||||
-n n dimension (default:4096)
|
||||
-x_stride x row_stride, if -1 then equal to n (default:-1)
|
||||
-xr_stride x residual row_stride, if -1 then equal to n (default:-1)
|
||||
-y_stride y row_stride, if -1 then equal to n (default:-1)
|
||||
-yr_stride y residual row_stride, if -1 then equal to n (default:-1)
|
||||
-e epsilon (default:1e-5)
|
||||
-save_rms save rms(invrms) or not. set to 1 in training case (default:0)
|
||||
-save_unquant save result before quant (default:0)
|
||||
-v cpu validation or not (default:1)
|
||||
-kname print kernel name or not (default:1)
|
||||
-prec_i input precision (default:fp16)
|
||||
-prec_o output precision, set auto will be the same as input (default:auto)
|
||||
-prec_sm output quant scale type, set auto will use fp32. used when fquant=1 (default:auto)
|
||||
-prec_sy output quant scale type, set auto will use fp32. used when fquant=1 or 2 (default:auto)
|
||||
-fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0)
|
||||
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
|
||||
-warmup cold iter (default:5)
|
||||
-repeat hot iter (default:20)
|
||||
-s sensitive model mode, 0: for no specific model, 1: for T5-like model (default:0)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:rmsnorm2d_fwd.json)
|
||||
```
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "rmsnorm2d_fwd.hpp"
|
||||
#include <cstring>
|
||||
#include "json_dump.hpp"
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
@@ -53,7 +54,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model");
|
||||
.insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "rmsnorm2d_fwd.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -437,6 +440,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_rmsnorm2d_fwd_json(arg_parser.get_str("jsonfile"),
|
||||
prec_str,
|
||||
m,
|
||||
n,
|
||||
x_stride,
|
||||
xr_stride,
|
||||
y_stride,
|
||||
yr_stride,
|
||||
use_model_sensitive_rmsnorm,
|
||||
ave_time,
|
||||
0,
|
||||
gb_per_sec,
|
||||
pass);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ This folder contains example for add + Rmsnorm2D + rowwise dynamic quantization
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_add_rmsnorm2d_rdquant_fwd -j
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_add_rmsnorm2d_rdquant_fwd -j`nproc`
|
||||
```
|
||||
This will result in an executable `build/bin/tile_add_rmsnorm2d_rdquant_fwd`
|
||||
|
||||
@@ -15,8 +15,16 @@ This will result in an executable `build/bin/tile_add_rmsnorm2d_rdquant_fwd`
|
||||
```
|
||||
args:
|
||||
-m m dimension (default:3328)
|
||||
-n m dimension (default:4096)
|
||||
-n n dimension (default:4096)
|
||||
-stride stride per row, if -1 then equal to n (default:-1)
|
||||
-e epsilon (default:1e-5)
|
||||
-save_x save rms(invrms) or not. set to 1 in training case (default:1)
|
||||
-v cpu validation or not (default:1)
|
||||
-kname print kernel name or not (default:1)
|
||||
-prec precision (default:fp16)
|
||||
-quant precision (default:int8)
|
||||
-warmup cold iter (default:5)
|
||||
-repeat hot iter (default:20)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:add_rmsnorm2d_rdquant_fwd.json)
|
||||
```
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "add_rmsnorm2d_rdquant_fwd.hpp"
|
||||
#include <cstring>
|
||||
#include "json_dump.hpp"
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename InputDataType>
|
||||
@@ -41,7 +42,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("quant", "int8", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "add_rmsnorm2d_rdquant_fwd.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -260,6 +263,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_add_rmsnorm2d_rdquant_fwd_json(arg_parser.get_str("jsonfile"),
|
||||
input_data_type,
|
||||
quantized_data_type,
|
||||
m,
|
||||
n,
|
||||
stride,
|
||||
epsilon,
|
||||
ave_time,
|
||||
0,
|
||||
gb_per_sec,
|
||||
pass);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ This folder contains example for smoothquant using ck_tile tile-programming impl
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_smoothquant -j
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_smoothquant -j`nproc`
|
||||
```
|
||||
This will result in an executable `build/bin/tile_smoothquant`
|
||||
|
||||
@@ -15,7 +15,14 @@ This will result in an executable `build/bin/tile_smoothquant`
|
||||
```
|
||||
args:
|
||||
-m m dimension (default:3328)
|
||||
-n m dimension (default:4096)
|
||||
-n n dimension (default:4096)
|
||||
-x_stride input stride per row, if -1 then equal to n (default:-1)
|
||||
-y_stride output stride per row, if -1 then equal to n (default:-1)
|
||||
-v cpu validation or not (default:1)
|
||||
-kname print kernel name or not (default:1)
|
||||
-prec precision (default:fp16)
|
||||
-warmup cold iter (default:5)
|
||||
-repeat hot iter (default:20)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:smoothquant.json)
|
||||
```
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "smoothquant.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include <cstring>
|
||||
|
||||
// different threshold for different dtype
|
||||
@@ -39,7 +40,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "smoothquant.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -202,6 +205,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_smoothquant_json(arg_parser.get_str("jsonfile"),
|
||||
data_type,
|
||||
m,
|
||||
n,
|
||||
x_stride,
|
||||
y_stride,
|
||||
ave_time,
|
||||
0,
|
||||
gb_per_sec,
|
||||
pass);
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -6,32 +6,36 @@ This folder contains example for moe-sorting kernel using ck_tile tile-programmi
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_moe_sorting -j
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_moe_sorting -j`nproc`
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_moe_sorting`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-v turn CPU validation on (1) or off (0). (default:1)
|
||||
-pr_i index data type. Only int32 is currently supported. (default:int32)
|
||||
-pr_w output weight data type. Only fp32 is currently supported. (default:fp32)
|
||||
-t number of input tokens. (default:128)
|
||||
If "local_t" presents, this value indicates global concurrency of all ranks.
|
||||
-local_t Number of local input tokens for curent rank. (default:-1)
|
||||
This value must be within range "[0, t)", or "-1"(no such feature)
|
||||
This feature is to simulate EP case where where each rank has different tokens.
|
||||
Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.
|
||||
-e number of num_experts (default:8)
|
||||
-k topk (default:4)
|
||||
-unit unit_size (default:32)
|
||||
-moe_buf_size moe_buf_size (default:0)
|
||||
-local_eid a list of experts enabled as local expert. e.g. "0,1,4,5" (default:-1)
|
||||
please make sure eid is in ascending order!
|
||||
-seed seed to be used. When set to -1, a random seed will be generated each time invoking this example (default:-1)
|
||||
-kname prints the kernel name when set to 1 (default:0)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
|
||||
-v turn CPU validation on (1) or off (0). (default:1)
|
||||
-pr_i index data type. Only int32 is currently supported. (default:int32)
|
||||
-pr_w output weight data type. Only fp32 is currently supported. (default:fp32)
|
||||
-t number of input tokens. (default:128)
|
||||
If "local_t" presents, this value indicates global concurrency of all ranks.
|
||||
-local_t Number of local input tokens for current rank. (default:-1)
|
||||
This value must be within range "[0, t)", or "-1"(no such feature)
|
||||
This feature is to simulate EP case where each rank has different tokens.
|
||||
Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.
|
||||
-e number of num_experts (default:8)
|
||||
-k topk (default:4)
|
||||
-unit unit_size (default:32)
|
||||
-moe_buf_interm_dim interm_dim(col) of the following fmoe buf (default:0)
|
||||
-moe_buf_elem_bytes fmoe buf element byte size, 1:8bit, 2:16bit, 4:32bit... (default:2)
|
||||
-ci clear workspace inside API or not(if "0", require manually clear outside) (default:1)
|
||||
-dispatch dispatch policy. 0:automatically pick up kernel, 1:use single kernel, 2:use mp kernel (default:0)
|
||||
-local_eid a list of experts enabled as local expert. e.g. "0,1,4,5" (default:-1)
|
||||
please make sure eid is in ascending order!
|
||||
-seed seed to be used. When set to -1, a random seed will be generated each time invoking this example (default:-1)
|
||||
-kname prints the kernel name when set to 1 (default:0)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:moe_sorting.json)
|
||||
```
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "moe_sorting_api.hpp"
|
||||
#include "json_dump.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
@@ -59,7 +60,9 @@ auto create_args(int argc, char* argv[])
|
||||
"invoking this example")
|
||||
.insert("kname", "0", "prints the kernel name when set to 1")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "moe_sorting.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -437,6 +440,23 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
printf(", (%d)", seed);
|
||||
printf("\n");
|
||||
fflush(stdout);
|
||||
|
||||
if(args.get_int("json") == 1)
|
||||
{
|
||||
dump_moe_sorting_json(args.get_str("jsonfile"),
|
||||
index_prec,
|
||||
weight_prec,
|
||||
workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"),
|
||||
dispatch_policy,
|
||||
tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
ms,
|
||||
0,
|
||||
0,
|
||||
rtn);
|
||||
}
|
||||
|
||||
return rtn;
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,25 @@ Unlike standard smoothquant op, the input scale is from different expert `[exper
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_moe_smoothquant -j
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_moe_smoothquant -j`nproc`
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_moe_smoothquant`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-t tokens dimension (default:3328)
|
||||
-h hidden_size dimension (default:4096)
|
||||
-e experts (default:32)
|
||||
-k topk (default:5)
|
||||
-stride stride per row, if -1 then equal to hidden_size (default:-1)
|
||||
-v cpu validation or not (default:1)
|
||||
-kname print kernel name or not (default:1)
|
||||
-prec_i input precision, fp16/bf16 (default:fp16)
|
||||
-prec_o precision, int8/fp8 (default:int8)
|
||||
-warmup cold iter (default:5)
|
||||
-repeat hot iter (default:20)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:moe_smoothquant.json)
|
||||
```
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "moe_smoothquant.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include <cstring>
|
||||
#include <set>
|
||||
|
||||
@@ -66,7 +67,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("prec_i", "fp16", "input precision, fp16/bf16")
|
||||
.insert("prec_o", "int8", "precision, int8/fp8")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "moe_smoothquant.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -244,6 +247,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json"))
|
||||
{
|
||||
dump_moe_smoothquant_json(arg_parser.get_str("jsonfile"),
|
||||
prec_i,
|
||||
prec_o,
|
||||
tokens,
|
||||
hidden_size,
|
||||
stride,
|
||||
experts,
|
||||
topk,
|
||||
pass,
|
||||
ave_time,
|
||||
0,
|
||||
gb_per_sec);
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -69,4 +69,42 @@ summary of the key design of this fused-moe operator:
|
||||
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
|
||||
//
|
||||
// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
|
||||
```
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-t number of input tokens. (default:128)
|
||||
If "local_t" presents, this value indicates global concurrency of all ranks.
|
||||
-local_t Number of local input tokens for current rank. (default:-1)
|
||||
This value must be within range "[0, t)", or "-1"(no such feature)
|
||||
This feature is to simulate EP case where each rank has different tokens.
|
||||
Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.
|
||||
-e num of experts (default:32)
|
||||
-k topk (default:5)
|
||||
-h hidden_size of this model (default:8192)
|
||||
-i intermediate_size between 2 gemms of FFN (default:8192)
|
||||
-stride stride per row, if -1 then equal to hidden_size (default:-1)
|
||||
-bm blocking factor for sorted tokens (default:32)
|
||||
-tp tensor parallel size (default:8)
|
||||
-v cpu validation or not (default:1)
|
||||
-kname print kernel name or not (default:1)
|
||||
-prec_i input precision (default:bf16)
|
||||
-prec_w weight precision (default:bf16)
|
||||
-prec_o output precision (default:bf16)
|
||||
-prec_st token scale data type. auto will set to fp32 (default:auto)
|
||||
-prec_sw weight scale data type. auto will set to fp32 (default:auto)
|
||||
-prec_sq (dynamic) smooth quant data type. auto will set to fp32 (default:auto)
|
||||
-prec_kw topk-weight data type. auto will set to fp32 (default:auto)
|
||||
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
|
||||
-gate_only w0(gate/up) style, 0:gate+up will double interm size, 1:only gate (default:1)
|
||||
-api benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm (default:0)
|
||||
-act activation after first gemm. 0:gelu, 1:silu (default:0)
|
||||
-balance if set to 1, will try balance the expert in topk-ids(convenient for testing) (default:0)
|
||||
-init init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand normalized[0, 1]normalized(slow) (default:1)
|
||||
-seed seed used to do random (default:11939)
|
||||
-warmup cold iter (default:5)
|
||||
-repeat hot iter (default:20)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:fused_moe.json)
|
||||
```
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <set>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "fused_moe.hpp"
|
||||
|
||||
// different threshold for different dtype
|
||||
@@ -130,7 +131,9 @@ auto create_args(int argc, char* argv[])
|
||||
"normalized(slow)")
|
||||
.insert("seed", "11939", "seed used to do random")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "fused_moe.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -513,6 +516,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
}
|
||||
std::cout << std::flush << std::endl;
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_fused_moe_json(arg_parser.get_str("jsonfile"),
|
||||
api_str,
|
||||
prec_str,
|
||||
tokens,
|
||||
is_local_token,
|
||||
local_tokens,
|
||||
experts,
|
||||
topk,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
stride,
|
||||
block_m,
|
||||
activation,
|
||||
gate_only,
|
||||
fused_quant,
|
||||
pass,
|
||||
ave_time,
|
||||
cal_tflops(ave_time),
|
||||
cal_tbps(ave_time));
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
else if(api == 1)
|
||||
@@ -619,6 +645,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
std::cout << std::flush << std::endl;
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_fused_moe_json(arg_parser.get_str("jsonfile"),
|
||||
api_str,
|
||||
prec_str,
|
||||
tokens,
|
||||
is_local_token,
|
||||
local_tokens,
|
||||
experts,
|
||||
topk,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
stride,
|
||||
block_m,
|
||||
activation,
|
||||
gate_only,
|
||||
fused_quant,
|
||||
pass,
|
||||
ave_time,
|
||||
cal_tflops(ave_time),
|
||||
cal_tbps(ave_time));
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
return false;
|
||||
|
||||
@@ -15,23 +15,25 @@ This will result in an executable `build/bin/tile_example_batched_gemm`
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m m dimension (default:256)
|
||||
-n n dimension (default:128)
|
||||
-k k dimension (default:128)
|
||||
-a_layout A tensor data layout (default:R) (R for Row, C for Col)
|
||||
-b_layout B tensor data layout (default:R) (R for Row, C for Col)
|
||||
-c_layout C tensor data layout (default:R) (R for Row, C for Col)
|
||||
-stride_a Tensor A stride (default:128)
|
||||
-stride_b Tensor B stride (default:128)
|
||||
-stride_c Tensor C stride (default:128)
|
||||
-batch_stride_a Batch A stride (default:32768)
|
||||
-batch_stride_b Batch B stride (default:16384)
|
||||
-batch_stride_c Batch C stride (default:32768)
|
||||
-batch_count Batch count (default:16)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
|
||||
-e Absolute error tolerance (default:1e-5)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-warmup number of iterations before benchmark the kernel (default:10)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-m m dimension (default:512)
|
||||
-n n dimension (default:1024)
|
||||
-k k dimension (default:2048)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-a_layout A tensor data layout - Row by default (default:R)
|
||||
-b_layout B tensor data layout - Row by default (default:C)
|
||||
-c_layout C tensor data layout - Row by default (default:R)
|
||||
-batch_stride_a Batch A stride (default:1048576)
|
||||
-batch_stride_b Batch B stride (default:2097152)
|
||||
-batch_stride_c Batch C stride (default:524288)
|
||||
-batch_count Batch count (default:8)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-warmup number of iterations before benchmark the kernel (default:50)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-split_k splitK value (default:1)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:cktile_batched_gemm.json)
|
||||
```
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include <json_dump.hpp>
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
@@ -75,7 +76,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value");
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "cktile_batched_gemm.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
@@ -77,21 +76,6 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Batched Gemm"};
|
||||
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
|
||||
std::size_t num_byte = sizeof(ADataType) * batch_count * M * K +
|
||||
sizeof(BDataType) * batch_count * N * K +
|
||||
sizeof(CDataType) * batch_count * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
|
||||
<< " batch_stride_A =" << batch_stride_A << " batch_stride_B =" << batch_stride_B
|
||||
<< " batch_stride_C =" << batch_stride_C << " batch_count =" << batch_count << " : "
|
||||
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -186,31 +170,47 @@ int run_batched_gemm_example_with_layouts(int argc,
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_batched_gemm<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_count,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
float ave_time = invoke_batched_gemm<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_count,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
std::string op_name{"Batched Gemm"};
|
||||
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
|
||||
std::size_t num_byte = sizeof(ADataType) * batch_count * M * K +
|
||||
sizeof(BDataType) * batch_count * N * K +
|
||||
sizeof(CDataType) * batch_count * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
|
||||
<< " batch_stride_A =" << batch_stride_A << " batch_stride_B =" << batch_stride_B
|
||||
<< " batch_stride_C =" << batch_stride_C << " batch_count =" << batch_count << " : "
|
||||
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
@@ -310,6 +310,27 @@ int run_batched_gemm_example_with_layouts(int argc,
|
||||
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_batched_gemm_json_results(arg_parser.get_str("jsonfile"),
|
||||
op_name,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_count,
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec,
|
||||
"batched_gemm");
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -157,17 +157,20 @@ This will result in an executable `build/bin/tile_example_grouped_gemm`
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-Ms M dimensions - (Default: empty).
|
||||
-Ns N dimensions - (Default: empty).
|
||||
-Ks K dimensions - (Default: empty).
|
||||
-stride_As Tensor A strides - (Default: empty).
|
||||
-stride_Bs Tensor B strides - (Default: empty).
|
||||
-stride_Cs Tensor C strides - (Default: empty).
|
||||
-a_layout A tensor data layout - (Default: Row).
|
||||
-b_layout B tensor data layout - (Default: Col).
|
||||
-c_layout C tensor data layout - (Default: Row).
|
||||
-validate 0. No validation, 1. Validation on CPU. (Default: 1).
|
||||
-warmup Number of iterations before benchmark the kernel. (Default: 10).
|
||||
-repeat Number of iterations to benchmark the kernel. (Default: 100).
|
||||
-group_count Group count. (Default: 16).
|
||||
-Ms M dimensions - empty by default. (default:)
|
||||
-Ns N dimensions - empty by default. (default:)
|
||||
-Ks K dimensions - empty by default. (default:)
|
||||
-stride_As Tensor A strides - it is empty by default. (default:)
|
||||
-stride_Bs Tensor B strides - it is empty by default. (default:)
|
||||
-stride_Cs Tensor C strides - it is empty by default. (default:)
|
||||
-a_layout A tensor data layout - Row by default. (default:R)
|
||||
-b_layout B tensor data layout - Row by default. (default:C)
|
||||
-c_layout C tensor data layout - Row by default. (default:R)
|
||||
-validate 0. No validation, 1. Validation on CPU. (default:1)
|
||||
-warmup number of iterations before benchmark the kernel. (default:10)
|
||||
-repeat number of iterations to benchmark the kernel. (default:100)
|
||||
-group_count group count. (default:8)
|
||||
-kbatch kbatch for SplitK (default:1)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:grouped_gemm.json)
|
||||
```
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
@@ -171,7 +172,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK");
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "grouped_gemm.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
@@ -114,24 +113,6 @@ float invoke_gemm(int n_warmup,
|
||||
CDataType>(stream, group_count, kargs_ptr, splitk);
|
||||
}
|
||||
|
||||
std::string op_name{"Grouped Gemm"};
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(int j = 0; j < group_count; ++j)
|
||||
{
|
||||
flop += std::size_t(2) * args[j].M * args[j].N * args[j].K;
|
||||
|
||||
num_btype += sizeof(ADataType) * args[j].M * args[j].K +
|
||||
sizeof(BDataType) * args[j].K * args[j].N +
|
||||
sizeof(CDataType) * args[j].M * args[j].N;
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -259,17 +240,34 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
}
|
||||
|
||||
invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
Persistent>(warmup, repeat, group_count, gemm_descs);
|
||||
float ave_time = invoke_gemm<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
Persistent>(warmup, repeat, group_count, gemm_descs);
|
||||
|
||||
std::string op_name{"Grouped Gemm"};
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(int j = 0; j < group_count; ++j)
|
||||
{
|
||||
flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K;
|
||||
|
||||
num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K +
|
||||
sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N +
|
||||
sizeof(CDataType) * gemm_descs[j].M * gemm_descs[j].N;
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
@@ -304,6 +302,17 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_grouped_gemm_json_results<ALayout, BLayout, CLayout>(arg_parser.get_str("jsonfile"),
|
||||
op_name,
|
||||
group_count,
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -16,20 +16,23 @@ This will result in an executable `build/bin/tile_example_flatmm_basic`
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-b batch size (default:1)
|
||||
-m m dimension (default:1024)
|
||||
-n n dimension (default:2048)
|
||||
-k k dimension (default:64)
|
||||
-a_layout Tensor A data layout (default: R)
|
||||
-b_layout Tensor B data layout (default: R)
|
||||
-c_layout Tensor C data layout (default: R)
|
||||
-m m dimension (default:256)
|
||||
-n n dimension (default:256)
|
||||
-k k dimension (default:128)
|
||||
-a_layout A tensor data layout - Row by default (default:R)
|
||||
-b_layout B tensor data layout - Row by default (default:C)
|
||||
-c_layout C tensor data layout - Row by default (default:R)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
|
||||
-e Absolute error tolerance (default:1e-5)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-warmup number of iterations before benchmark the kernel (default:10)
|
||||
-warmup number of iterations before benchmark the kernel (default:50)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-split_k splitK value (default:1)
|
||||
-init 0:random, 1:linear, 2:constant(1) (default:0)
|
||||
-warp_tile 0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only) (default:0)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:flatmm_basic.json)
|
||||
```
|
||||
|
||||
@@ -183,9 +183,10 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
.insert(
|
||||
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "flatmm_basic.json", "json file name to dump results");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
#include <type_traits>
|
||||
|
||||
#include "json_dump.hpp"
|
||||
template <typename T>
|
||||
constexpr const char* DataTypeToString()
|
||||
{
|
||||
@@ -140,17 +140,6 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
|
||||
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
|
||||
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -242,27 +231,38 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
|
||||
invoke_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
float ave_time = invoke_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
|
||||
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
|
||||
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
c_dev_buf.FromDevice(c_rslt_host.data());
|
||||
bool pass = true;
|
||||
@@ -350,5 +350,22 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_flatmm_json_results(arg_parser.get_str("jsonfile"),
|
||||
DataTypeToString<ADataType>(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
@@ -17,19 +17,21 @@ This will result in an executable `build/bin/tile_example_gemm_multi_d_fp16`
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m M dimensions - (Default: 3840)
|
||||
-n N dimensions - (Default: 4096)
|
||||
-k K dimensions - (Default: 4096)
|
||||
-a_layout Tensor A layout (default:R)
|
||||
-b_layout Tensor B layout (default:C)
|
||||
-ds_layout Tensor D layout (default:R)
|
||||
-e_layout Tensor E layout (default:R)
|
||||
-stride_a Tensor A strides - (Default: 0)
|
||||
-stride_b Tensor B strides - (Default: 0)
|
||||
-stride_e Tensor C strides - (Default: 0)
|
||||
-stride_ds Tensor D strides - (Default: 0)
|
||||
-validate 0. No validation, 1. Validation on GPU. (Default: 1)
|
||||
-warmup Number of iterations before benchmark the kernel. (Default: 10)
|
||||
-repeat Number of iterations to benchmark the kernel. (Default: 100)
|
||||
-kbatch kbatch for SplitK. (Default 1)
|
||||
-m m dimension (default:3840)
|
||||
-n n dimension (default:4096)
|
||||
-k k dimension (default:4096)
|
||||
-a_layout A tensor data layout - Row by default (default:R)
|
||||
-b_layout B tensor data layout - Col by default (default:C)
|
||||
-ds_layout Ds tensor data layout - Row by default (default:R)
|
||||
-e_layout E tensor data layout - Row by default (default:R)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_ds Tensor Ds stride (default:0)
|
||||
-stride_e Tensor E stride (default:0)
|
||||
-v 0. No validation, 1. Validation on GPU (default:1)
|
||||
-warmup number of iterations before benchmark the kernel (default:50)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-kbatch kbatch for SplitK (default:1)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:cktile_gemm_multi_d_fp16.json)
|
||||
```
|
||||
|
||||
@@ -58,7 +58,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("v", "1", "0. No validation, 1. Validation on GPU")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("kbatch", "1", "kbatch for SplitK");
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "cktile_gemm_multi_d_fp16.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
#include "json_dump.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -54,30 +55,6 @@ float invoke_gemm_multi_d(const void* a_m_k_dev_buf,
|
||||
CDEElementWise>(
|
||||
gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Gemm Multiple-D"};
|
||||
static constexpr ck_tile::index_t NumDTensor = DsDataType::size();
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
flop += std::size_t(2) * M * N * K;
|
||||
|
||||
ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
num_btype += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) * M * N;
|
||||
flop += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) * M * N;
|
||||
});
|
||||
|
||||
num_btype += sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm Multiple-D kernel with:\n";
|
||||
std::cout << "M =" << M << " N =" << N << " K =" << K << "\n";
|
||||
std::cout << "StrideA = " << StrideA << " StrideB = " << StrideB << " StrideE = " << StrideE
|
||||
<< "\n";
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< "\n";
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -159,29 +136,53 @@ int run_multiple_d_gemm_example_with_layouts(int argc,
|
||||
|
||||
std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {StrideD0, StrideD1};
|
||||
|
||||
invoke_gemm_multi_d<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDElementWiseFn>(a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
ds_ptr_buf,
|
||||
e_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
stridesDs,
|
||||
StrideE,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
k_batch);
|
||||
float ave_time = invoke_gemm_multi_d<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDElementWiseFn>(a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
ds_ptr_buf,
|
||||
e_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
stridesDs,
|
||||
StrideE,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
k_batch);
|
||||
|
||||
std::string op_name{"Gemm Multiple-D"};
|
||||
static constexpr ck_tile::index_t NumDTensor = DsDataType::size();
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
flop += std::size_t(2) * M * N * K;
|
||||
|
||||
ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
num_btype += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) * M * N;
|
||||
flop += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) * M * N;
|
||||
});
|
||||
|
||||
num_btype += sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm Multiple-D kernel with:\n";
|
||||
std::cout << "M =" << M << " N =" << N << " K =" << K << "\n";
|
||||
std::cout << "StrideA = " << StrideA << " StrideB = " << StrideB << " StrideE = " << StrideE
|
||||
<< "\n";
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< "\n";
|
||||
|
||||
e_m_n_dev_buf.FromDevice(e_m_n_device_result.data());
|
||||
|
||||
@@ -217,6 +218,24 @@ int run_multiple_d_gemm_example_with_layouts(int argc,
|
||||
<< std::endl;
|
||||
std::cout << "The CPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_gemm_multi_d_fp16_json_results(arg_parser.get_str("jsonfile"),
|
||||
op_name,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideD0,
|
||||
StrideD1,
|
||||
StrideE,
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -120,7 +120,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)");
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
#include "json_dump.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
@@ -15,7 +16,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "10", "cold iter")
|
||||
.insert("repeat", "50", "hot iter");
|
||||
.insert("repeat", "50", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "elementwise.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -195,6 +198,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01);
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_elementwise_json_results(arg_parser.get_str("jsonfile"),
|
||||
arg_parser.get_str("prec"),
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
ave_time,
|
||||
0,
|
||||
0,
|
||||
"elementwise_add");
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
#include "json_dump.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
@@ -16,7 +17,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "10", "cold iter")
|
||||
.insert("repeat", "50", "hot iter");
|
||||
.insert("repeat", "50", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "elementwise_add_4d.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -140,6 +143,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01);
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_elementwise_json_results(arg_parser.get_str("jsonfile"),
|
||||
arg_parser.get_str("prec"),
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
ave_time,
|
||||
0,
|
||||
0,
|
||||
"elementwise_add_4d");
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_transpose.hpp"
|
||||
#include "json_dump.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
@@ -14,7 +15,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "10", "cold iter")
|
||||
.insert("repeat", "50", "hot iter");
|
||||
.insert("repeat", "50", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "elementwise_transpose.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -137,6 +140,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
y_validation, y_host, "Transpose Error: Incorrect results!", 0.01, 0.01);
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_elementwise_json_results(arg_parser.get_str("jsonfile"),
|
||||
arg_parser.get_str("prec"),
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
ave_time,
|
||||
0,
|
||||
0,
|
||||
"elementwise_transpose");
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
#include "json_dump.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
@@ -15,7 +16,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "10", "cold iter")
|
||||
.insert("repeat", "50", "hot iter");
|
||||
.insert("repeat", "50", "hot iter")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "elementwise_unary.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -127,6 +130,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01);
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_elementwise_json_results(arg_parser.get_str("jsonfile"),
|
||||
arg_parser.get_str("prec"),
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
ave_time,
|
||||
0,
|
||||
0,
|
||||
"elementwise_unary");
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
|
||||
#include "batched_transpose_example.hpp"
|
||||
|
||||
#include "json_dump.hpp"
|
||||
#if 0
|
||||
template <typename T>
|
||||
void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)
|
||||
@@ -103,6 +104,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "t to 1 will print kernel name")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "batched_transpose.json", "json file name to dump results")
|
||||
.insert("pipeline", "0", "0: no LDS usage, 1: LDS-accelerated (gfx950)");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
@@ -236,6 +239,23 @@ bool run_batched_transpose(ck_tile::ArgParser args)
|
||||
"--------------------------------------------------------------------\n",
|
||||
rtn ? "y" : "n");
|
||||
fflush(stdout);
|
||||
|
||||
if(args.get_int("json") == 1)
|
||||
{
|
||||
dump_batched_transpose_json(args.get_str("jsonfile"),
|
||||
N,
|
||||
C,
|
||||
H,
|
||||
W,
|
||||
layout_in,
|
||||
layout_out,
|
||||
prec,
|
||||
ms,
|
||||
0,
|
||||
gb_per_sec,
|
||||
rtn);
|
||||
}
|
||||
|
||||
return rtn;
|
||||
}
|
||||
|
||||
|
||||
700
example/include/json_dump.hpp
Normal file
700
example/include/json_dump.hpp
Normal file
@@ -0,0 +1,700 @@
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant"
|
||||
#include "rapidjson/writer.h"
|
||||
#include "rapidjson/stringbuffer.h"
|
||||
#include "rapidjson/document.h"
|
||||
#include "rapidjson/rapidjson.h"
|
||||
// #include <fstream>
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
// Opens <file_name> for writing and begins a JSON object with a rapidjson
// SAX writer. NOTE: this macro declares `file_str`, `file`, `s` and `writer`
// in the enclosing scope, so it may be used at most once per function body,
// and END_JSON_DUMP_FILE() must follow in the same scope.
// Throws std::runtime_error when the file cannot be opened.
#define START_JSON_DUMP_FILE(file_name) \
    std::string file_str(file_name); \
    std::ofstream file(file_str); \
    if(!file.is_open()) \
    { \
        throw std::runtime_error("Could not open file: " + std::string(file_name)); \
    } \
    rapidjson::StringBuffer s; \
    rapidjson::Writer<rapidjson::StringBuffer> writer(s); \
    writer.StartObject();

// Closes the JSON object started by START_JSON_DUMP_FILE, flushes the
// serialized buffer to disk and reports the output path on stdout.
#define END_JSON_DUMP_FILE() \
    writer.EndObject(); \
    file << s.GetString(); \
    file.close(); \
    std::cout << "Results written to " << file_str << " successfully" << std::endl;

// Adds one key/value pair to the object currently open on `writer`
// (relies on the `writer` variable declared by START_JSON_DUMP_FILE).
#define ADD_KEY_VALUE(key, value) add_key_value_pair(writer, key, value);
// Adds the "perf" section (time / tflops / gbytes) to the open object.
#define ADD_PERF_TO_JSON(_time, tflops, gbytes) add_perf_to_json(writer, _time, tflops, gbytes);
|
||||
|
||||
template <typename T>
|
||||
void add_key_value_pair(rapidjson::Writer<rapidjson::StringBuffer>& writer,
|
||||
const char* key,
|
||||
T value)
|
||||
{
|
||||
writer.Key(key);
|
||||
if constexpr(std::is_same<T, const char*>::value)
|
||||
{
|
||||
writer.String(value, static_cast<rapidjson::SizeType>(std::strlen(value)));
|
||||
}
|
||||
else if constexpr(std::is_same<T, std::string>::value)
|
||||
{
|
||||
writer.String(value.c_str(), static_cast<rapidjson::SizeType>(value.length()));
|
||||
}
|
||||
else if constexpr(std::is_floating_point<T>::value)
|
||||
{
|
||||
writer.Double(static_cast<double>(value));
|
||||
}
|
||||
else if constexpr(std::is_integral<T>::value)
|
||||
{
|
||||
writer.Int64(static_cast<int64_t>(value));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(std::is_same<T, const char*>::value || std::is_floating_point<T>::value ||
|
||||
std::is_integral<T>::value,
|
||||
"Unsupported type for JSON serialization");
|
||||
}
|
||||
}
|
||||
|
||||
/// Emits the "perf" key followed by a one-element array holding the measured
/// time (ms), TFLOPS and GB/s of a run.
/// `inline` (was `static`): `static` gave every including TU its own private
/// copy of this function; `inline` lets the linker merge the definitions.
inline void add_perf_to_json(rapidjson::Writer<rapidjson::StringBuffer>& writer,
                             float time,
                             float tflops,
                             float gbytes)
{
    // Key() avoids the needless std::string temporary the original built
    // just to name this section.
    writer.Key("perf");

    writer.StartArray();
    writer.StartObject();

    add_key_value_pair(writer, "time", time);
    add_key_value_pair(writer, "tflops", tflops);
    add_key_value_pair(writer, "gbytes", gbytes);

    writer.EndObject();
    writer.EndArray();
}
|
||||
|
||||
// Detection trait: true_type when a GemmConfig type exposes the static
// members M_Warp_Tile / N_Warp_Tile / K_Warp_Tile. Used by
// dump_gemm_json_results to optionally report the warp-tile shape.
template <typename T, typename = void>
struct has_warp_tile_members : std::false_type
{
};

// Specialization selected via std::void_t when all three members exist.
template <typename T>
struct has_warp_tile_members<
    T,
    std::void_t<decltype(T::M_Warp_Tile), decltype(T::N_Warp_Tile), decltype(T::K_Warp_Tile)>>
    : std::true_type
{
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename GemmConfig,
|
||||
template <typename>
|
||||
typename DTypeTraits>
|
||||
void dump_gemm_json_results(const std::string& json_filename,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
bool persistent,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "gemm_basic")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("M", M);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("K", K);
|
||||
ADD_KEY_VALUE("stride_A", stride_A);
|
||||
ADD_KEY_VALUE("stride_B", stride_B);
|
||||
ADD_KEY_VALUE("stride_C", stride_C);
|
||||
ADD_KEY_VALUE("A_layout", ALayout::name);
|
||||
ADD_KEY_VALUE("B_layout", BLayout::name);
|
||||
ADD_KEY_VALUE("C_layout", CLayout::name);
|
||||
using TraitsADataType = DTypeTraits<ADataType>;
|
||||
using TraitsBDataType = DTypeTraits<BDataType>;
|
||||
using TraitsCDataType = DTypeTraits<CDataType>;
|
||||
ADD_KEY_VALUE("A_type", TraitsADataType::name);
|
||||
ADD_KEY_VALUE("B_type", TraitsBDataType::name);
|
||||
ADD_KEY_VALUE("C_type", TraitsCDataType::name);
|
||||
ADD_KEY_VALUE("structured_sparsity", GemmConfig::UseStructuredSparsity ? "on" : "off");
|
||||
|
||||
if constexpr(has_warp_tile_members<GemmConfig>::value)
|
||||
{
|
||||
ADD_KEY_VALUE("warp_tile",
|
||||
std::to_string(GemmConfig::M_Warp_Tile) + "x" +
|
||||
std::to_string(GemmConfig::N_Warp_Tile) + "x" +
|
||||
std::to_string(GemmConfig::K_Warp_Tile));
|
||||
}
|
||||
ADD_KEY_VALUE("persistent", persistent ? "on" : "off");
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_batched_gemm_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int batch_stride_A,
|
||||
int batch_stride_B,
|
||||
int batch_stride_C,
|
||||
int batch_count,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "batched_gemm_basic")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("op_name", op_name);
|
||||
ADD_KEY_VALUE("M", M);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("K", K);
|
||||
ADD_KEY_VALUE("stride_A", stride_A);
|
||||
ADD_KEY_VALUE("stride_B", stride_B);
|
||||
ADD_KEY_VALUE("stride_C", stride_C);
|
||||
ADD_KEY_VALUE("batch_stride_A", batch_stride_A);
|
||||
ADD_KEY_VALUE("batch_stride_B", batch_stride_B);
|
||||
ADD_KEY_VALUE("batch_stride_C", batch_stride_C);
|
||||
ADD_KEY_VALUE("batch_count", batch_count);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void dump_grouped_gemm_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int group_count,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "grouped_gemm")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("op_name", op_name);
|
||||
ADD_KEY_VALUE("group_count", group_count);
|
||||
ADD_KEY_VALUE("A_layout", ALayout::name);
|
||||
ADD_KEY_VALUE("B_layout", BLayout::name);
|
||||
ADD_KEY_VALUE("C_layout", CLayout::name);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_flatmm_json_results(const std::string& json_filename,
|
||||
const std::string& datatype,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int stride_A,
|
||||
int stride_B,
|
||||
int stride_C,
|
||||
int kbatch,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "flatmm_basic")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("DataType", datatype);
|
||||
ADD_KEY_VALUE("M", M);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("K", K);
|
||||
ADD_KEY_VALUE("StrideA", stride_A);
|
||||
ADD_KEY_VALUE("StrideB", stride_B);
|
||||
ADD_KEY_VALUE("StrideC", stride_C);
|
||||
ADD_KEY_VALUE("kbatch", kbatch);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_gemm_multi_d_fp16_json_results(const std::string& json_filename,
|
||||
const std::string& op_name,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideD0,
|
||||
int StrideD1,
|
||||
int StrideE,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "gemm_multi_d_fp16")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("op_name", op_name);
|
||||
ADD_KEY_VALUE("M", M);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("K", K);
|
||||
ADD_KEY_VALUE("StrideA", StrideA);
|
||||
ADD_KEY_VALUE("StrideB", StrideB);
|
||||
ADD_KEY_VALUE("StrideD0", StrideD0);
|
||||
ADD_KEY_VALUE("StrideD1", StrideD1);
|
||||
ADD_KEY_VALUE("StrideE", StrideE);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_elementwise_json_results(const std::string& json_filename,
|
||||
const std::string& prec,
|
||||
int grid_size,
|
||||
int block_size,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "elementwise")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", prec);
|
||||
ADD_KEY_VALUE("grid_size", grid_size);
|
||||
ADD_KEY_VALUE("block_size", block_size);
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_layernorm2d_fwd_json_results(const std::string& json_filename,
|
||||
const std::string& prec_i,
|
||||
const std::string& prec_o,
|
||||
const std::string& prec_sm,
|
||||
const std::string& prec_sy,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int xr_stride,
|
||||
int y_stride,
|
||||
int yr_stride,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "layernorm2d_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec_i", prec_i);
|
||||
ADD_KEY_VALUE("prec_o", prec_o);
|
||||
ADD_KEY_VALUE("prec_sm", prec_sm);
|
||||
ADD_KEY_VALUE("prec_sy", prec_sy);
|
||||
ADD_KEY_VALUE("m", m);
|
||||
ADD_KEY_VALUE("n", n);
|
||||
ADD_KEY_VALUE("x_stride", x_stride);
|
||||
ADD_KEY_VALUE("xr_stride", xr_stride);
|
||||
ADD_KEY_VALUE("y_stride", y_stride);
|
||||
ADD_KEY_VALUE("yr_stride", yr_stride);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
template <typename DataType, template <typename> typename DTypeTraits>
|
||||
void dump_reduce_json_results(const std::string& json_filename,
|
||||
int N,
|
||||
int C,
|
||||
int H,
|
||||
int W,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "reduce")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
using Traits = DTypeTraits<DataType>;
|
||||
ADD_KEY_VALUE("data_type", Traits::name);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("C", C);
|
||||
ADD_KEY_VALUE("H", H);
|
||||
ADD_KEY_VALUE("W", W);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_permute_json_results(const std::string& json_filename,
|
||||
const std::string& data_type,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflop,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "permute")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("data_type", data_type);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflop, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_topk_softmax_json(const std::string& json_filename,
|
||||
const std::string& input_prec,
|
||||
const std::string& weight_prec,
|
||||
int tokens,
|
||||
int experts,
|
||||
int topk,
|
||||
int stride_input,
|
||||
int stride_output,
|
||||
float ave_time,
|
||||
float tflop,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "topk_softmax")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("input_prec", input_prec);
|
||||
ADD_KEY_VALUE("weight_prec", weight_prec);
|
||||
ADD_KEY_VALUE("tokens", tokens);
|
||||
ADD_KEY_VALUE("experts", experts);
|
||||
ADD_KEY_VALUE("topk", topk);
|
||||
ADD_KEY_VALUE("stride_input", stride_input);
|
||||
ADD_KEY_VALUE("stride_output", stride_output);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflop, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_rmsnorm2d_fwd_json(const std::string& json_filename,
|
||||
const std::string& prec_str,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int xr_stride,
|
||||
int y_stride,
|
||||
int yr_stride,
|
||||
int use_model_sensitive_rmsnorm,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "rmsnorm2d_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", prec_str);
|
||||
ADD_KEY_VALUE("m", m);
|
||||
ADD_KEY_VALUE("n", n);
|
||||
ADD_KEY_VALUE("x_stride", x_stride);
|
||||
ADD_KEY_VALUE("xr_stride", xr_stride);
|
||||
ADD_KEY_VALUE("y_stride", y_stride);
|
||||
ADD_KEY_VALUE("yr_stride", yr_stride);
|
||||
ADD_KEY_VALUE("use_model_sensitive_rmsnorm", use_model_sensitive_rmsnorm);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_add_rmsnorm2d_rdquant_fwd_json(
|
||||
const std::string& json_filename,
|
||||
const std::string& input_data_type,
|
||||
const std::string& quantized_data_type,
|
||||
int m,
|
||||
int n,
|
||||
int stride,
|
||||
float epsilon,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "add_rmsnorm2d_rdquant_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("input_data_type", input_data_type);
|
||||
ADD_KEY_VALUE("quantized_data_type", quantized_data_type);
|
||||
ADD_KEY_VALUE("m", m);
|
||||
ADD_KEY_VALUE("n", n);
|
||||
ADD_KEY_VALUE("stride", stride);
|
||||
ADD_KEY_VALUE("epsilon", epsilon);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_smoothquant_json(const std::string& json_filename,
|
||||
const std::string& prec_str,
|
||||
int m,
|
||||
int n,
|
||||
int x_stride,
|
||||
int y_stride,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "smoothquant")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", prec_str);
|
||||
ADD_KEY_VALUE("m", m);
|
||||
ADD_KEY_VALUE("n", n);
|
||||
ADD_KEY_VALUE("x_stride", x_stride);
|
||||
ADD_KEY_VALUE("y_stride", y_stride);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec);
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_moe_sorting_json(const std::string& json_filename,
|
||||
const std::string& index_prec,
|
||||
const std::string& weight_prec,
|
||||
const std::string& workspace_size,
|
||||
int dispatch_policy,
|
||||
int tokens,
|
||||
int num_experts,
|
||||
int topk,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "moe_sorting")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("index_prec", index_prec);
|
||||
ADD_KEY_VALUE("weight_prec", weight_prec);
|
||||
ADD_KEY_VALUE("workspace_size", workspace_size);
|
||||
ADD_KEY_VALUE("dispatch_policy", dispatch_policy);
|
||||
ADD_KEY_VALUE("tokens", tokens);
|
||||
ADD_KEY_VALUE("num_experts", num_experts);
|
||||
ADD_KEY_VALUE("topk", topk);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_batched_transpose_json(const std::string& json_filename,
|
||||
int N,
|
||||
int C,
|
||||
int H,
|
||||
int W,
|
||||
const std::string& layout_in,
|
||||
const std::string& layout_out,
|
||||
const std::string& prec,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
bool pass,
|
||||
const std::string& kernel_name = "batched_transpose")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("N", N);
|
||||
ADD_KEY_VALUE("C", C);
|
||||
ADD_KEY_VALUE("H", H);
|
||||
ADD_KEY_VALUE("W", W);
|
||||
ADD_KEY_VALUE("LayoutIn", layout_in);
|
||||
ADD_KEY_VALUE("LayoutOut", layout_out);
|
||||
ADD_KEY_VALUE("Precision", prec);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_moe_smoothquant_json(const std::string& json_filename,
|
||||
const std::string& prec_i,
|
||||
const std::string& prec_o,
|
||||
int tokens,
|
||||
int hidden_size,
|
||||
int stride,
|
||||
int experts,
|
||||
int topk,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "moe_smoothquant")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec_i", prec_i);
|
||||
ADD_KEY_VALUE("prec_o", prec_o);
|
||||
ADD_KEY_VALUE("tokens", tokens);
|
||||
ADD_KEY_VALUE("hidden_size", hidden_size);
|
||||
ADD_KEY_VALUE("stride", stride);
|
||||
ADD_KEY_VALUE("experts", experts);
|
||||
ADD_KEY_VALUE("topk", topk);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fused_moe_json(const std::string& json_filename,
|
||||
const std::string& api_str,
|
||||
const std::string& prec_str,
|
||||
int tokens,
|
||||
bool is_local_token,
|
||||
int local_tokens,
|
||||
int experts,
|
||||
int topk,
|
||||
int hidden_size,
|
||||
int intermediate_size,
|
||||
int stride,
|
||||
int block_m,
|
||||
int activation,
|
||||
bool gate_only,
|
||||
bool fused_quant,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float tb_per_sec,
|
||||
const std::string& kernel_name = "fused_moe")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("api", api_str);
|
||||
ADD_KEY_VALUE("prec", prec_str);
|
||||
ADD_KEY_VALUE("tokens", tokens);
|
||||
if(is_local_token)
|
||||
{
|
||||
ADD_KEY_VALUE("local_tokens", local_tokens);
|
||||
}
|
||||
ADD_KEY_VALUE("experts", experts);
|
||||
ADD_KEY_VALUE("topk", topk);
|
||||
ADD_KEY_VALUE("hidden_size", hidden_size);
|
||||
ADD_KEY_VALUE("intermediate_size", intermediate_size);
|
||||
ADD_KEY_VALUE("stride", stride);
|
||||
ADD_KEY_VALUE("block_m", block_m);
|
||||
ADD_KEY_VALUE("activation", activation);
|
||||
ADD_KEY_VALUE("gate_only", gate_only);
|
||||
ADD_KEY_VALUE("fused_quant", fused_quant);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, (tb_per_sec * 1024.0f))
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
const std::string& prec,
|
||||
const std::string& mode,
|
||||
const std::string& io_layout,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_qs,
|
||||
int seqlen_ks,
|
||||
int seqlen_kpads,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
float scale_s,
|
||||
float p_drop,
|
||||
bool lse,
|
||||
bool squant,
|
||||
const std::string& bais,
|
||||
const std::string& vlayout,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "fmha_fwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", prec);
|
||||
ADD_KEY_VALUE("mode", mode);
|
||||
ADD_KEY_VALUE("io_layout", io_layout);
|
||||
ADD_KEY_VALUE("batch", batch);
|
||||
ADD_KEY_VALUE("nhead", nhead);
|
||||
ADD_KEY_VALUE("nhead_k", nhead_k);
|
||||
ADD_KEY_VALUE("seqlen_q", seqlen_qs);
|
||||
ADD_KEY_VALUE("seqlen_k", seqlen_ks);
|
||||
ADD_KEY_VALUE("seqlen_kpads", seqlen_kpads);
|
||||
ADD_KEY_VALUE("hdim_q", hdim_q);
|
||||
ADD_KEY_VALUE("hdim_v", hdim_v);
|
||||
ADD_KEY_VALUE("scale_s", scale_s);
|
||||
ADD_KEY_VALUE("p_drop", p_drop);
|
||||
ADD_KEY_VALUE("lse", lse);
|
||||
ADD_KEY_VALUE("squant", squant);
|
||||
ADD_KEY_VALUE("bias", bais);
|
||||
ADD_KEY_VALUE("vlayout", vlayout);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
|
||||
void dump_fmha_bwd_json_results(const std::string& json_filename,
|
||||
const std::string& data_type,
|
||||
const std::string& mode,
|
||||
const std::string& i_perm,
|
||||
const std::string& o_perm,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
float scale,
|
||||
const std::string& bias,
|
||||
bool use_dbias,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
bool deterministic,
|
||||
const std::string& mask,
|
||||
int mask_left,
|
||||
int mask_right,
|
||||
int workspace_size,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
float tflops,
|
||||
float gb_per_sec,
|
||||
const std::string& kernel_name = "fmha_bwd")
|
||||
{
|
||||
START_JSON_DUMP_FILE(json_filename);
|
||||
ADD_KEY_VALUE("name", kernel_name);
|
||||
ADD_KEY_VALUE("prec", data_type);
|
||||
ADD_KEY_VALUE("mode", mode);
|
||||
ADD_KEY_VALUE("i_perm", i_perm);
|
||||
ADD_KEY_VALUE("o_perm", o_perm);
|
||||
ADD_KEY_VALUE("batch", batch);
|
||||
ADD_KEY_VALUE("nhead", nhead);
|
||||
ADD_KEY_VALUE("nhead_k", nhead_k);
|
||||
ADD_KEY_VALUE("seqlen_q", seqlen_q);
|
||||
ADD_KEY_VALUE("seqlen_k", seqlen_k);
|
||||
ADD_KEY_VALUE("hdim_q", hdim_q);
|
||||
ADD_KEY_VALUE("hdim_v", hdim_v);
|
||||
ADD_KEY_VALUE("scale", scale);
|
||||
ADD_KEY_VALUE("bias", bias);
|
||||
ADD_KEY_VALUE("use_dbias", use_dbias);
|
||||
ADD_KEY_VALUE("p_drop", p_drop);
|
||||
ADD_KEY_VALUE("s_randval", s_randval);
|
||||
ADD_KEY_VALUE("deterministic", deterministic ? "true" : "false");
|
||||
ADD_KEY_VALUE("mask", mask);
|
||||
ADD_KEY_VALUE("mask_left", mask_left);
|
||||
ADD_KEY_VALUE("mask_right", mask_right);
|
||||
ADD_KEY_VALUE("workspace_size", workspace_size);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
END_JSON_DUMP_FILE();
|
||||
}
|
||||
Reference in New Issue
Block a user