mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Gemm layernorm welford (#413)
* Add device op of gemm layernorm * [What] Rename F to H [Why] F and G prepare for welford tensor * Add gridwise gemm + welford * Extract template parameter * Rename kernel. Prepare to add second half kernel * Extract var * Add second kernel for gemm+layernorm * Move to the gemm_layernorm folder * Rename F and G to mean and var * Do not use snakeCurved, it makes determination of padding for welford difficult * Rewrite the device interface and rename some var * Add welford count * Update interface * Sync code, prepare to test on MI200 * Clean the code * Implement layernorm * Add comment to mension hipFree * Wrtie out the e for debug. This could be remove and use h for instead * 1. Allocate mean, var and count into by SetWorkSpacePointer. 2. Add GetWorkSpaceSize to calculate the space size * Add gemm layernorm host code * use reference layernorm * Fix bug of blockwise welford for first kernel * Fix bug of mean var padding for layernorm * Use sgpr for shuffleM_index * padding for GemmMeanVarCountGridDescriptor_M_NBlock * Add layout parameter * Check argument for gemm * calculate max count for tail block * Share E and H memory in device op * Hard code the vector dim * Refine the MakeDescriptor * 1. Remove E parameter, because E is inside of device op 2. Check vector size * [What] Rename MakeMeanVarDescriptor_M_N [Why] Prepare to add count version of make descriptor * Use 1D global memory for count * Prevent redundant IO * Update parameter * Add pipeline v1/v2 selector * Rename the example name * Add base class for gemm layernorm * Refine naming to distinguish naive and welford * Add comment to explan in detail * We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1 * Rewrite the 2st kernel, use multiple block along N dimension in layernorm kernel * Share the vector size * Refine var name * [What] Force LayernormThreadSliceSize_N = vector size. [Why] Memory coalesce * Add comment * Extract divisor out of the loop in reference layernorm * Pad different size for E and H in layernorm kernel according to different block tile * Refine naming * Refine naming * Prevent implicit cast * [What] use ck::math::sqrt instead of __builtin_amdgcn_sqrtf [Why] __builtin_amdgcn_sqrtf is only support float, double will cause casting * Cast only constant * Change of post shuffle thread descriptor * Add EMeanVarDataType parameter. * Merge the mean and var threadwise copy * Add missing index * Fix Typo * Sync the variable with previous if * 1. Declare e inside the host_gemm_layernorm() 2. Prevent implicit cast in reference code Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// GEMM:
|
||||
// input : A[M, K]
|
||||
// input : B[N, K]
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// output : H[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// H = layernorm(E)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
// Calculate mean & variance along N dimension in layernorm(E)
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename HLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename HDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename HElementwiseOperation>
|
||||
struct DeviceGemmMultipleDLayernorm : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
const void* p_gamma,
|
||||
const void* p_beta,
|
||||
void* p_h,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideH,
|
||||
double epsilon,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
HElementwiseOperation h_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
}; // namespace device
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user