mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 06:44:36 +00:00
Merge commit 'ec006bb8e008caaf5cb95c1ca5037dc6ac026bde' into develop
This commit is contained in:
@@ -399,9 +399,9 @@ CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
|
||||
}
|
||||
mantissa += (1u << SrcT_mant); // Add the implicit 1 into mantissa
|
||||
}
|
||||
// The value is smaller than min f8 denormal and results in zero (the early exit also prevents
|
||||
// The value is <= than min f8 denormal/2 and results in zero (the early exit also prevents
|
||||
// an undefined behavior of bit shifts >= type width).
|
||||
if(exponent_diff > DstT_mant)
|
||||
if(exponent_diff > DstT_mant + 1)
|
||||
{
|
||||
return is_fnuz ? 0 : (sign << (DstT_exp + DstT_mant));
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/ranges.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_dropout_randval.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_masking.hpp"
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename RandValOutputDataType>
|
||||
CK_TILE_HOST void
|
||||
reference_batched_dropout_randval(HostTensor<RandValOutputDataType>& randval_b_m_n,
|
||||
index_t batch,
|
||||
uint64_t drop_seed,
|
||||
uint64_t drop_offset)
|
||||
{
|
||||
const index_t nhead = randval_b_m_n.mDesc.get_lengths()[0];
|
||||
const index_t real_seqlen_q = randval_b_m_n.mDesc.get_lengths()[1];
|
||||
const index_t real_seqlen_k = randval_b_m_n.mDesc.get_lengths()[2];
|
||||
|
||||
static_assert(std::is_same_v<RandValOutputDataType, uint8_t>);
|
||||
|
||||
// BlockDropout generates random numbers by 32x32 tiles. Even when warp gemm 16x16 is used, the
|
||||
// order of values in the bigger 32x32 tile must be the same because fwd and bwd may use
|
||||
// different warp gemms (16x16 or 32x32).
|
||||
// To compute 32x32 tiles, WarpGemmMfmaF16F16F32M32N32K16SwizzleA is used. It is
|
||||
// WarpGemmAttributeMfmaImplF16F16F32M32N32K8 with SFactor = 2 (swizzling factor).
|
||||
// Matrix element to register mapping for WarpGemmAttributeMfmaImplF16F16F32M32N32K8:
|
||||
// C i: (8 * floor(GPR_num / 4) % 32) + 4 * floor(lane / 32) + (GPR_num % 4)
|
||||
// C j: (lane % 32)
|
||||
// With SFactor = 2 it becomes:
|
||||
// C i: (16 * floor(GPR_num / 8) % 32) + 8 * floor(lane / 32) + (GPR_num % 8)
|
||||
// C j: (lane % 32)
|
||||
|
||||
constexpr index_t max_warp_size = 64;
|
||||
constexpr index_t warp_gemm_mn = 32;
|
||||
|
||||
const index_t rows = integer_divide_ceil(real_seqlen_q, warp_gemm_mn);
|
||||
const index_t cols = integer_divide_ceil(real_seqlen_k, warp_gemm_mn);
|
||||
|
||||
auto f = [&](index_t i_h, index_t row, index_t col) {
|
||||
uint2 rowcol = make_uint2(row, col);
|
||||
for(index_t lane = 0; lane < max_warp_size; lane++)
|
||||
{
|
||||
philox ph(drop_seed, drop_offset + (batch * nhead + i_h) * max_warp_size + lane);
|
||||
|
||||
uint8_t random_uint8_t[16];
|
||||
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
|
||||
|
||||
for(auto r = 0; r < 16; r++)
|
||||
{
|
||||
index_t i = (16 * (r / 8) % 32) + 8 * (lane / 32) + (r % 8);
|
||||
index_t j = (lane % 32);
|
||||
index_t m = row * warp_gemm_mn + i;
|
||||
index_t n = col * warp_gemm_mn + j;
|
||||
|
||||
if(m < real_seqlen_q && n < real_seqlen_k)
|
||||
{
|
||||
randval_b_m_n(i_h, m, n) = random_uint8_t[r];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, nhead, rows, cols)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -611,7 +611,7 @@ void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
float p_drop,
|
||||
bool lse,
|
||||
bool squant,
|
||||
const std::string& bais,
|
||||
const std::string& bias,
|
||||
const std::string& vlayout,
|
||||
bool pass,
|
||||
float ave_time,
|
||||
@@ -636,7 +636,7 @@ void dump_fmha_fwd_json_results(const std::string& json_filename,
|
||||
ADD_KEY_VALUE("p_drop", p_drop);
|
||||
ADD_KEY_VALUE("lse", lse);
|
||||
ADD_KEY_VALUE("squant", squant);
|
||||
ADD_KEY_VALUE("bias", bais);
|
||||
ADD_KEY_VALUE("bias", bias);
|
||||
ADD_KEY_VALUE("vlayout", vlayout);
|
||||
ADD_KEY_VALUE("verification", pass ? "pass" : "fail");
|
||||
ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec)
|
||||
|
||||
Reference in New Issue
Block a user