mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Add V4: remove gemm pipeline, combine gemm/normalization
This commit is contained in:
@@ -15,3 +15,6 @@ add_executable(${TARGET_NAME} mhc_v3_two_block_test.cpp)
|
||||
|
||||
set(TARGET_NAME example_mhc_v3_bf16_benchmark)
|
||||
add_executable(${TARGET_NAME} mhc_v3_bf16_benchmark.cpp)
|
||||
|
||||
set(TARGET_NAME example_mhc_v4_bf16_benchmark)
|
||||
add_executable(${TARGET_NAME} mhc_v4_bf16_benchmark.cpp)
|
||||
|
||||
216
example/ck_tile/42_mhc/mhc_v4_bf16_benchmark.cpp
Normal file
216
example/ck_tile/42_mhc/mhc_v4_bf16_benchmark.cpp
Normal file
@@ -0,0 +1,216 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <tuple>
|
||||
#include <iostream>
|
||||
#include <cstring>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/mhc.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/reference/reference_mhc.hpp"
|
||||
#include "ck_tile/host/check_err.hpp"
|
||||
|
||||
// Parse command-line arguments for MHC benchmark
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("B", "1024", "Batch size")
|
||||
.insert("n", "4", "Expansion factor (number of streams)")
|
||||
.insert("C", "4096", "Channels per stream")
|
||||
.insert("v", "1", "CPU validation (0=no, 1=yes)")
|
||||
.insert("warmup", "5", "Number of warmup iterations")
|
||||
.insert("repeat", "20", "Number of benchmark iterations")
|
||||
.insert("r", "2.0", "Norm scaling factor")
|
||||
.insert("alpha_pre", "1.5", "Alpha for pre-activation")
|
||||
.insert("alpha_post", "2.5", "Alpha for post-activation")
|
||||
.insert("alpha_res", "3.5", "Alpha for residual")
|
||||
.insert("bias", "1.5", "Bias value");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// Run the MHC (Manifold Constrained Hyper Connection) V4 benchmark:
//  1. read the problem shape and scalar parameters from the parsed arguments,
//  2. allocate and randomly initialize host/device tensors,
//  3. time the fused GEMM+normalization kernel (warmup + repeat iterations),
//  4. report bandwidth/throughput,
//  5. optionally compare the device output against the CPU reference.
//
// Template parameters:
//   XDataType / PhiDataType / YDataType - element types of the input matrix X,
//       the weight matrix Phi, and the output Y.
//   ComputeDataType - accumulation type used on-device and by the reference.
//   ActivationFunc  - elementwise activation functor (default: sigmoid).
// Returns true when validation passes (or when validation is disabled).
template <typename XDataType,
          typename PhiDataType,
          typename YDataType,
          typename ComputeDataType,
          typename ActivationFunc = ck_tile::element_wise::Sigmoid>
bool run_mhc_benchmark(const ck_tile::ArgParser& arg_parser)
{
    // Problem shape from the command line.
    const int B = arg_parser.get_int("B");
    const int n = arg_parser.get_int("n");
    const int C = arg_parser.get_int("C");

    // Derived dimensions: flattened input width (n streams of C channels) and
    // output width 2n + n^2 (pre/post/residual sections, see the kernel).
    const int nC = n * C;
    const int output_dim = 2 * n + n * n;

    const int do_validation = arg_parser.get_int("v");
    const int warmup = arg_parser.get_int("warmup");
    const int repeat = arg_parser.get_int("repeat");

    // Scalar parameters forwarded verbatim to the kernel and to the reference.
    const float r = arg_parser.get_float("r");
    const float alpha_pre = arg_parser.get_float("alpha_pre");
    const float alpha_post = arg_parser.get_float("alpha_post");
    const float alpha_res = arg_parser.get_float("alpha_res");
    const float bias = arg_parser.get_float("bias");

    std::cout << "\n========================================" << std::endl;
    std::cout << "MHC Kernel V4 Benchmark (BF16)" << std::endl;
    std::cout << "========================================" << std::endl;
    std::cout << "Configuration:" << std::endl;
    std::cout << "  Batch size (B): " << B << std::endl;
    std::cout << "  Expansion factor (n): " << n << std::endl;
    std::cout << "  Channels per stream (C): " << C << std::endl;
    std::cout << "  Input dimension (nC): " << nC << std::endl;
    std::cout << "  Output dimension (2n+n^2): " << output_dim << std::endl;
    std::cout << "  Data types: X=" << typeid(XDataType).name()
              << ", Phi=" << typeid(PhiDataType).name() << ", Y=" << typeid(YDataType).name()
              << ", Compute=" << typeid(ComputeDataType).name() << std::endl;
    std::cout << "  Warmup iterations: " << warmup << std::endl;
    std::cout << "  Benchmark iterations: " << repeat << std::endl;
    std::cout << "========================================" << std::endl;

    // Allocate host tensors: X is [B, nC], Phi is [nC, output_dim] (row-major),
    // Y is [B, output_dim].
    ck_tile::HostTensor<XDataType> h_x({B, nC});
    ck_tile::HostTensor<PhiDataType> h_phi({nC, output_dim});
    ck_tile::HostTensor<YDataType> h_output({B, output_dim});

    // Initialize with random data.
    ck_tile::FillUniformDistribution<XDataType>{-1.0f, 1.0f}(h_x);
    ck_tile::FillUniformDistribution<PhiDataType>{-0.5f, 0.5f}(h_phi);
    h_output.SetZero();

    // Allocate device memory.
    ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes());
    ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes());
    ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes());

    // Copy data to device (output is copied too so it starts zeroed).
    d_x_mem.ToDevice(h_x.data());
    d_phi_mem.ToDevice(h_phi.data());
    d_output_mem.ToDevice(h_output.data());

    // Define block shape - 64 threads (1 warp) to match BlockGemmShape configuration.
    // This matches a 16x16 block tile with 1 warp (1x1 warp layout).
    using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 64>,
                                                    ck_tile::sequence<1, 64>,
                                                    ck_tile::sequence<1, 1>>;

    using Problem = ck_tile::MHCProblem<XDataType, ComputeDataType, YDataType, BlockShape>;

    // V4 kernel - optimized with single-pass data loading.
    using KernelV4 = ck_tile::MHCKernelV4<Problem, ck_tile::MHCDefaultPolicy, ActivationFunc>;

    const ck_tile::index_t kBlockSize = KernelV4::BlockSize();

    // 2D grid: (batch / kMTile) × (output_dim / kNTile), launched flattened as
    // a 1D grid; the kernel decodes block_m/block_n from the linear block id.
    auto grid_size = KernelV4::GetGridSize(B, output_dim);
    const ck_tile::index_t kGridSize =
        grid_size.at(ck_tile::number<0>{}) * grid_size.at(ck_tile::number<1>{});

    std::cout << "\nKernel Configuration:" << std::endl;
    std::cout << "  Grid: " << grid_size.at(ck_tile::number<0>{}) << " × "
              << grid_size.at(ck_tile::number<1>{}) << " = " << kGridSize << " blocks" << std::endl;
    std::cout << "  Block size: " << kBlockSize << " threads" << std::endl;
    std::cout << "  Shared memory: " << KernelV4::GetSmemSize() << " bytes" << std::endl;

    constexpr ck_tile::index_t kBlockPerCu = 1;

    // Launch kernel with timing; stream_config{nullptr, true, 0, warmup, repeat}
    // enables timing with the given warmup/measure iteration counts.
    float ave_time = ck_tile::launch_kernel(
        ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
        ck_tile::make_kernel<kBlockPerCu>(KernelV4{},
                                          kGridSize,
                                          kBlockSize,
                                          KernelV4::GetSmemSize(),
                                          static_cast<XDataType*>(d_x_mem.GetDeviceBuffer()),
                                          static_cast<PhiDataType*>(d_phi_mem.GetDeviceBuffer()),
                                          static_cast<YDataType*>(d_output_mem.GetDeviceBuffer()),
                                          B,
                                          nC,
                                          output_dim,
                                          n,
                                          r,
                                          alpha_pre,
                                          alpha_post,
                                          alpha_res,
                                          bias));

    // Calculate performance metrics: each tensor is touched once.
    std::size_t num_bytes = sizeof(XDataType) * B * nC +                // Input x
                            sizeof(PhiDataType) * nC * output_dim +     // Weights phi
                            sizeof(YDataType) * B * output_dim;         // Output

    // ave_time is in ms: bytes / 1e6 / ms == GB (1e9 bytes) per second.
    float gb_per_sec = num_bytes / 1.E6 / ave_time;

    // FLOPs: 2*nC (one mul + one add per K element) per output element;
    // normalization/activation ops are ignored as negligible.
    std::size_t num_flops = static_cast<std::size_t>(B) * output_dim * (2 * nC);
    // flops / 1e9 / ms == TFLOP (1e12) per second.
    float tflops = num_flops / 1.E9 / ave_time;

    std::cout << "\n========================================" << std::endl;
    std::cout << "Performance Results:" << std::endl;
    std::cout << "  Average time: " << ave_time << " ms" << std::endl;
    std::cout << "  Bandwidth: " << gb_per_sec << " GB/s" << std::endl;
    std::cout << "  Throughput: " << tflops << " TFLOPS" << std::endl;
    std::cout << "========================================" << std::endl;

    bool pass = true;

    if(do_validation)
    {
        std::cout << "\nRunning validation..." << std::endl;

        d_output_mem.FromDevice(h_output.data());

        // Compute reference on the CPU with the same scalar parameters.
        ck_tile::HostTensor<YDataType> h_output_ref({B, output_dim});
        h_output_ref.SetZero();

        ck_tile::reference_mhc<XDataType, PhiDataType, YDataType, ComputeDataType, ActivationFunc>(
            h_x,
            h_phi,
            h_output_ref,
            n,
            C,
            r,
            alpha_pre,
            alpha_post,
            alpha_res,
            bias,
            ActivationFunc{});

        // Validate with appropriate tolerance: bf16 has ~8 bits of mantissa, so
        // use a looser 1e-2 tolerance than the 1e-3 used for wider types.
        float rtol = std::is_same_v<XDataType, ck_tile::bf16_t> ? 1e-2f : 1e-3f;
        float atol = std::is_same_v<XDataType, ck_tile::bf16_t> ? 1e-2f : 1e-3f;

        pass = ck_tile::check_err(
            h_output, h_output_ref, "Error: MHC V4 kernel output mismatch!", rtol, atol);

        std::cout << "Validation: " << (pass ? "PASS" : "FAIL") << std::endl;
    }

    return pass;
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
std::cout << "Failed to parse arguments!" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Run with BF16 inputs, float output and compute
|
||||
bool pass = run_mhc_benchmark<ck_tile::bf16_t, // XDataType
|
||||
ck_tile::bf16_t, // PhiDataType
|
||||
float, // YDataType
|
||||
float, // ComputeDataType
|
||||
ck_tile::element_wise::Sigmoid>(arg_parser);
|
||||
|
||||
return pass ? 0 : -2;
|
||||
}
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile.hpp"
|
||||
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v2.hpp"
|
||||
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp"
|
||||
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
|
||||
|
||||
@@ -181,11 +181,11 @@ struct MHCKernelV3
|
||||
auto phi_identity_func = [](auto& e, const PhiDataType& phi_val) { e = phi_val; };
|
||||
|
||||
auto result_tile = gemm_pipeline(make_tuple(x_dram_window),
|
||||
phi_identity_func,
|
||||
make_tuple(phi_dram_window),
|
||||
phi_identity_func,
|
||||
num_k_loops,
|
||||
smem);
|
||||
phi_identity_func,
|
||||
make_tuple(phi_dram_window),
|
||||
phi_identity_func,
|
||||
num_k_loops,
|
||||
smem);
|
||||
|
||||
// Apply normalization and activation in post-processing
|
||||
// Now we divide by norm AFTER the GEMM, which means:
|
||||
@@ -222,7 +222,8 @@ struct MHCKernelV3
|
||||
{
|
||||
ComputeDataType activated_value;
|
||||
Activation{}(activated_value, value);
|
||||
result_tile(i_j_idx) = alpha_post * inv_norm * 2.0f * activated_value + bias;
|
||||
result_tile(i_j_idx) =
|
||||
alpha_post * inv_norm * 2.0f * activated_value + bias;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
255
include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp
Normal file
255
include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp
Normal file
@@ -0,0 +1,255 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
// Manifold Constrained Hyper Connection Kernel V4:
|
||||
// =====================================================================
|
||||
// Optimizations implemented:
|
||||
// - Remove GEMM pipeline to avoid redundant global memory reads
|
||||
// - Use BlockGemm directly with manual LDS management (like v2)
|
||||
// - Compute normalization incrementally during GEMM loop
|
||||
// - Single pass through input data: load once, compute norm and GEMM together
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// MHCKernelV4: fused GEMM + RMS-normalization + activation kernel.
// Each block computes a kMTile x kNTile tile of Y = f(norm(X) * Phi), loading
// each X tile from global memory exactly once and accumulating both the
// per-row sum of squares (for the norm) and the GEMM partial products while
// the tile is resident in LDS.
template <typename Problem_,
          typename Policy_ = MHCDefaultPolicy,
          typename Activation_ = element_wise::Sigmoid>
struct MHCKernelV4
{
    using Activation = ck_tile::remove_cvref_t<Activation_>;
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy = ck_tile::remove_cvref_t<Policy_>;

    using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
    using PhiDataType = ck_tile::remove_cvref_t<typename Problem::PhiDataType>;

    // Automatically derive tile sizes from BlockGemmShape (single source of truth!)
    static constexpr index_t kMTile = Problem::BlockGemmShape::kM; // Batch tile
    static constexpr index_t kNTile = Problem::BlockGemmShape::kN; // Output tile
    static constexpr index_t kKTile = Problem::BlockGemmShape::kK; // K tile for C dimension

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    CK_TILE_HOST static constexpr auto BlockSize() { return kBlockSize; }

    // LDS footprint for the A (X) and B (Phi) staging tiles only. The norm and
    // result accumulators below are separate static __shared__ arrays and are
    // not counted here.
    // NOTE(review): the host also passes GetSmemSize() as *dynamic* shared
    // memory at launch, but operator() declares a static __shared__ buffer of
    // the same size — the dynamic allocation appears redundant; confirm.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        // LDS for BlockGemm: A[kMTile, kKTile] + B[kKTile, kNTile]
        constexpr index_t a_lds_size = kMTile * kKTile * sizeof(XDataType);
        constexpr index_t b_lds_size = kKTile * kNTile * sizeof(PhiDataType);
        return a_lds_size + b_lds_size;
    }

    // Grid configuration: 2D grid over (batch, output_dim), ceil-divided by the
    // tile sizes. The host flattens this to a 1D launch; operator() decodes it.
    CK_TILE_HOST static constexpr auto GetGridSize(index_t batch, index_t output_dim)
    {
        const index_t grid_m = (batch + kMTile - 1) / kMTile;
        const index_t grid_n = (output_dim + kNTile - 1) / kNTile;
        return make_tuple(grid_m, grid_n);
    }

    // Kernel body.
    //   p_x      : [batch, nC] row-major input
    //   p_phi    : [nC, output_dim] row-major weights
    //   p_output : [batch, output_dim] row-major output
    //   n        : stream count; selects the activation section of the output
    //              (columns [0,n) pre, [n,2n) post, [2n,..) residual)
    //   r        : accepted but unused in this version ([[maybe_unused]])
    CK_TILE_DEVICE void operator()(const XDataType* p_x,
                                   const PhiDataType* p_phi,
                                   YDataType* p_output,
                                   index_t batch,
                                   index_t nC,
                                   index_t output_dim,
                                   [[maybe_unused]] index_t n,
                                   [[maybe_unused]] float r = 1.0f,
                                   [[maybe_unused]] float alpha_pre = 1.0f,
                                   [[maybe_unused]] float alpha_post = 1.0f,
                                   [[maybe_unused]] float alpha_res = 1.0f,
                                   [[maybe_unused]] float bias = 0.0f) const
    {
        // Decode 2D block coordinates from the flattened 1D block id.
        const index_t grid_n_size = (output_dim + kNTile - 1) / kNTile;
        const index_t block_id = get_block_id();
        const index_t block_m = block_id / grid_n_size;
        const index_t block_n = block_id % grid_n_size;

        const index_t batch_start = block_m * kMTile;
        const index_t out_start = block_n * kNTile;

        // Whole-block bounds check; all threads of an out-of-range block exit
        // together, so the barriers below are never partially reached.
        if(batch_start >= batch || out_start >= output_dim)
            return;

        const index_t tid = get_thread_id();

        // Allocate shared memory for A and B tiles + norm accumulators + GEMM results.
        __shared__ char smem_ptr[GetSmemSize()];
        XDataType* x_lds = reinterpret_cast<XDataType*>(smem_ptr);
        PhiDataType* phi_lds =
            reinterpret_cast<PhiDataType*>(smem_ptr + kMTile * kKTile * sizeof(XDataType));

        // Shared memory for norm accumulation (one per batch element in tile).
        // NOTE(review): every block with the same block_m recomputes this sum
        // independently — correct, but redundant across the N-grid dimension.
        __shared__ ComputeDataType sum_squares_shared[kMTile];

        // Shared memory for GEMM result accumulation across K tiles.
        __shared__ ComputeDataType result_shared[kMTile * kNTile];

        // Initialize shared norm accumulators and result.
        for(index_t i = tid; i < kMTile; i += get_block_size())
        {
            sum_squares_shared[i] = 0.0f;
        }
        for(index_t i = tid; i < kMTile * kNTile; i += get_block_size())
        {
            result_shared[i] = 0.0f;
        }
        block_sync_lds();

        // Number of K-tile iterations (ceil-divide so the tail tile is padded).
        const index_t num_k_tiles = (nC + kKTile - 1) / kKTile;

        // Main loop: load tiles, compute norms incrementally, and accumulate GEMM.
        for(index_t k_tile_idx = 0; k_tile_idx < num_k_tiles; ++k_tile_idx)
        {
            const index_t k_start = k_tile_idx * kKTile;
            const index_t k_end = min(k_start + kKTile, nC);
            const index_t k_len = k_end - k_start;

            // Load X tile from global to LDS and accumulate norm. Out-of-range
            // slots are zero-filled so the padded K columns add nothing.
            for(index_t i = tid; i < kMTile * kKTile; i += get_block_size())
            {
                const index_t local_m = i / kKTile;
                const index_t local_k = i % kKTile;
                const index_t global_m = batch_start + local_m;
                const index_t global_k = k_start + local_k;

                XDataType x_val = 0;
                if(global_m < batch && local_k < k_len)
                {
                    x_val = p_x[global_m * nC + global_k];

                    // Accumulate norm for this batch element using atomics
                    // (multiple threads load different k's of the same row).
                    ComputeDataType x_compute = type_convert<ComputeDataType>(x_val);
                    ComputeDataType sq = x_compute * x_compute;
                    atomicAdd(&sum_squares_shared[local_m], sq);
                }
                x_lds[i] = x_val;
            }

            // Load Phi tile from global to LDS in column-major format (K x N).
            // phi is stored in global memory as [nC, output_dim] row-major;
            // we transpose it to column-major here so the GEMM inner loop below
            // reads both operands with unit stride in k.
            for(index_t i = tid; i < kKTile * kNTile; i += get_block_size())
            {
                const index_t local_k = i / kNTile;
                const index_t local_n = i % kNTile;
                const index_t global_k = k_start + local_k;
                const index_t global_n = out_start + local_n;

                PhiDataType phi_val = 0;
                if(local_k < k_len && global_n < output_dim)
                {
                    phi_val = p_phi[global_k * output_dim + global_n];
                }
                // Store in column-major: phi_lds[n * kKTile + k]
                phi_lds[local_n * kKTile + local_k] = phi_val;
            }

            // Make the freshly loaded tiles visible before the GEMM reads them.
            block_sync_lds();

            // Perform manual GEMM: result_acc += x_lds * phi_lds^T.
            // Distribute work: each thread computes a subset of output elements.
            // With 64 threads and 16x16 output, each thread handles 4 elements.
            const index_t total_elements = kMTile * kNTile;
            const index_t elements_per_thread =
                (total_elements + get_block_size() - 1) / get_block_size();

            for(index_t elem_idx = 0; elem_idx < elements_per_thread; ++elem_idx)
            {
                const index_t global_elem = tid * elements_per_thread + elem_idx;
                if(global_elem < total_elements)
                {
                    const index_t m_idx = global_elem / kNTile;
                    const index_t n_idx = global_elem % kNTile;

                    ComputeDataType acc = 0.0f;
                    for(index_t k_idx = 0; k_idx < kKTile; ++k_idx)
                    {
                        ComputeDataType x_val =
                            type_convert<ComputeDataType>(x_lds[m_idx * kKTile + k_idx]);
                        ComputeDataType phi_val =
                            type_convert<ComputeDataType>(phi_lds[n_idx * kKTile + k_idx]);
                        acc += x_val * phi_val;
                    }
                    // Accumulate to shared memory using atomics.
                    // NOTE(review): each (m,n) element is owned by exactly one
                    // thread per iteration, so a plain += would likely suffice.
                    atomicAdd(&result_shared[m_idx * kNTile + n_idx], acc);
                }
            }

            // Protect the LDS tiles from being overwritten by the next
            // iteration's loads while this GEMM may still be reading them.
            block_sync_lds();
        }

        // Ensure all norm accumulations are complete.
        block_sync_lds();

        // Compute inverse norms after all K-tiles processed. Every thread
        // redundantly computes all kMTile values into registers.
        ComputeDataType inv_norms[kMTile];
        for(index_t local_m = 0; local_m < kMTile; ++local_m)
        {
            const index_t global_m = batch_start + local_m;
            if(global_m < batch)
            {
                // RMS norm: sqrt(sum(x^2)) / sqrt(nC).
                // NOTE(review): an all-zero input row gives norm == 0 and
                // inv_norm == inf — no epsilon guard here; confirm the
                // reference behaves identically.
                const ComputeDataType norm = ck_tile::sqrt(sum_squares_shared[local_m]) /
                                             ck_tile::sqrt(static_cast<ComputeDataType>(nC));
                inv_norms[local_m] = 1.0f / norm;
            }
            else
            {
                // Padded rows are never written out; 1.0 is a harmless filler.
                inv_norms[local_m] = 1.0f;
            }
        }

        // Apply normalization, activation, and write output.
        for(index_t i = tid; i < kMTile * kNTile; i += get_block_size())
        {
            const index_t local_m = i / kNTile;
            const index_t local_n = i % kNTile;

            const index_t global_m = batch_start + local_m;
            const index_t global_n = out_start + local_n;

            if(global_m >= batch || global_n >= output_dim)
                continue;

            const index_t inv_norm_idx = local_m; // row of this output element
            const ComputeDataType inv_norm = inv_norms[local_m];
            ComputeDataType value = result_shared[i];

            // Apply normalization and activation based on output section:
            // columns [0, n) = pre, [n, 2n) = post, [2n, ...) = residual.
            if(global_n < n)
            {
                ComputeDataType activated_value;
                Activation{}(activated_value, value);
                value = alpha_pre * inv_norm * activated_value + bias;
            }
            else if(global_n < 2 * n)
            {
                ComputeDataType activated_value;
                Activation{}(activated_value, value);
                value = alpha_post * inv_norm * 2.0f * activated_value + bias;
            }
            else
            {
                value = alpha_res * inv_norm * value + bias;
            }

            // Write to global memory.
            p_output[global_m * output_dim + global_n] = type_convert<YDataType>(value);
        }
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user