Merge commit 'ec006bb8e008caaf5cb95c1ca5037dc6ac026bde' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-10 03:20:14 +00:00
parent b2b0389f76
commit 535686e647
27 changed files with 2429 additions and 865 deletions

View File

@@ -399,9 +399,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
}
mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa
}
// The value is smaller than min f8 denormal and results in zero (the early exit also prevents
// The value is <= than min f8 denormal/2 and results in zero (the early exit also prevents
// an undefined behavior of bit shifts >= type width).
if(exponent_diff > DstT_mant)
if(exponent_diff > DstT_mant + 1)
{
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
}

View File

@@ -18,6 +18,7 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
#include "ck_tile/host/reference/reference_batched_dropout_randval.hpp"
#include "ck_tile/host/reference/reference_batched_elementwise.hpp"
#include "ck_tile/host/reference/reference_batched_gemm.hpp"
#include "ck_tile/host/reference/reference_batched_masking.hpp"

View File

@@ -0,0 +1,70 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename RandValOutputDataType>
CK_TILE_HOST void
reference_batched_dropout_randval(HostTensor<RandValOutputDataType>& randval_b_m_n,
index_t batch,
uint64_t drop_seed,
uint64_t drop_offset)
{
const index_t nhead = randval_b_m_n.mDesc.get_lengths()[0];
const index_t real_seqlen_q = randval_b_m_n.mDesc.get_lengths()[1];
const index_t real_seqlen_k = randval_b_m_n.mDesc.get_lengths()[2];
static_assert(std::is_same_v<RandValOutputDataType, uint8_t>);
// BlockDropout generates random numbers by 32x32 tiles. Even when warp gemm 16x16 is used, the
// order of values in the bigger 32x32 tile must be the same because fwd and bwd may use
// different warp gemms (16x16 or 32x32).
// To compute 32x32 tiles, WarpGemmMfmaF16F16F32M32N32K16SwizzleA is used. It is
// WarpGemmAttributeMfmaImplF16F16F32M32N32K8 with SFactor = 2 (swizzling factor).
// Matrix element to register mapping for WarpGemmAttributeMfmaImplF16F16F32M32N32K8:
// C i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4)
// C j: (lane % 32)
// With SFactor = 2 it becomes:
// C i: (16 * floor(GPR_num / 8) % 32) + 8 * floor(lane / 32) + (GPR_num % 8)
// C j: (lane % 32)
constexpr index_t max_warp_size = 64;
constexpr index_t warp_gemm_mn = 32;
const index_t rows = integer_divide_ceil(real_seqlen_q, warp_gemm_mn);
const index_t cols = integer_divide_ceil(real_seqlen_k, warp_gemm_mn);
auto f = [&](index_t i_h, index_t row, index_t col) {
uint2 rowcol = make_uint2(row, col);
for(index_t lane = 0; lane < max_warp_size; lane++)
{
philox ph(drop_seed, drop_offset + (batch * nhead + i_h) * max_warp_size + lane);
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
for(auto r = 0; r < 16; r++)
{
index_t i = (16 * (r / 8) % 32) + 8 * (lane / 32) + (r % 8);
index_t j = (lane % 32);
index_t m = row * warp_gemm_mn + i;
index_t n = col * warp_gemm_mn + j;
if(m < real_seqlen_q && n < real_seqlen_k)
{
randval_b_m_n(i_h, m, n) = random_uint8_t[r];
}
}
}
};
make_ParallelTensorFunctor(f, nhead, rows, cols)(std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -611,7 +611,7 @@ void dump_fmha_fwd_json_results(const std::string& json_filename,
float p_drop,
bool lse,
bool squant,
const std::string& bais,
const std::string& bias,
const std::string& vlayout,
bool pass,
float ave_time,
@@ -636,7 +636,7 @@ void dump_fmha_fwd_json_results(const std::string& json_filename,
ADD_KEY_VALUE("p_drop", p_drop);
ADD_KEY_VALUE("lse", lse);
ADD_KEY_VALUE("squant", squant);
ADD_KEY_VALUE("bias", bais);
ADD_KEY_VALUE("bias", bias);
ADD_KEY_VALUE("vlayout", vlayout);
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)