mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[CK_TILE] fused-moe first version (#1634)
* moe pipeline * update code * compile OK * update * update cpu reference * update pipeline_gemm0 * compiler ok * update pipeline * rename to ex pipeline * block-asm * update * update * update first gemm ok * compute correct * update file structure * update README * update * update * update code * update API * return unsupport case * add comment * update readme * update * uncomment * update * fix build err --------- Co-authored-by: valarLip <340077269@qq.com>
This commit is contained in:
@@ -572,6 +572,105 @@ struct FastGelu
|
||||
}
|
||||
};
|
||||
|
||||
struct FastGeluAsm
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_DEVICE void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
|
||||
const float c1 = -2.0 * 0.035677f;
|
||||
const float c2 = -2.0 * 0.797885f;
|
||||
const float u = x * (c1 * x * x + c2);
|
||||
const float emu = exp(u);
|
||||
y = x / (1.f + emu);
|
||||
}
|
||||
|
||||
// device code, use lower precision "__ocml_exp_f32" and "rcp"
|
||||
template <>
|
||||
CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f;
|
||||
const float c2 = -2.0 * 0.797885f;
|
||||
const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
|
||||
float tmp;
|
||||
|
||||
asm volatile("v_mul_f32 %[v_tmp], %[v_x], %[v_x] ; x*x\n"
|
||||
"v_fma_f32 %[v_tmp], %[v_tmp], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
|
||||
"v_mul_f32 %[v_tmp], %[v_tmp], %[v_x] ; x*(c1*x*x+c2)\n"
|
||||
"v_mul_f32 %[v_tmp], %[v_tmp], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
|
||||
"v_exp_f32 %[v_tmp], %[v_tmp] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
|
||||
"s_nop 0 ; hazard for exp\n"
|
||||
"v_add_f32 %[v_tmp], %[v_tmp], 1.0 ; emu+1.0f\n"
|
||||
"v_rcp_f32 %[v_tmp], %[v_tmp] ; 1/(emu+1.0f)\n"
|
||||
"s_nop 0 ; hazard for rcp \n"
|
||||
"v_mul_f32 %[v_y], %[v_tmp], %[v_x] ; x * 1/(emu+1f)\n"
|
||||
: [v_y] "=v"(y), [v_tmp] "+v"(tmp)
|
||||
: [v_x] "v"(x), [s_c1] "s"(c1), [v_c2] "v"(c2), [s_log2e] "s"(log2e_)
|
||||
:);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
|
||||
{
|
||||
const float c1 = -2.0 * 0.035677f;
|
||||
const float c2 = -2.0 * 0.797885f;
|
||||
const float u0 = x.x * (c1 * x.x * x.x + c2);
|
||||
const float emu0 = exp(u0);
|
||||
y.x = x.x / (1.f + emu0);
|
||||
const float u1 = x.y * (c1 * x.y * x.y + c2);
|
||||
const float emu1 = exp(u1);
|
||||
y.y = x.y / (1.f + emu1);
|
||||
}
|
||||
|
||||
// this is packed verion to remove data hazard for trans
|
||||
template <>
|
||||
CK_TILE_DEVICE void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
|
||||
{
|
||||
const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f;
|
||||
float c2 = -2.0 * 0.797885f;
|
||||
const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
|
||||
float tmp0, tmp1;
|
||||
float y0 = x.x, y1 = x.y;
|
||||
|
||||
asm volatile(
|
||||
"v_mul_f32 %[v_tmp0], %[v_y0], %[v_y0] ; x*x\n"
|
||||
"v_mul_f32 %[v_tmp1], %[v_y1], %[v_y1] ; x*x\n"
|
||||
"v_fma_f32 %[v_tmp0], %[v_tmp0], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
|
||||
"v_fma_f32 %[v_tmp1], %[v_tmp1], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
|
||||
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[v_y0] ; x*(c1*x*x+c2)\n"
|
||||
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[v_y1] ; x*(c1*x*x+c2)\n"
|
||||
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
|
||||
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
|
||||
"v_exp_f32 %[v_tmp0], %[v_tmp0] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
|
||||
"v_exp_f32 %[v_tmp1], %[v_tmp1] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
|
||||
"v_add_f32 %[v_tmp0], %[v_tmp0], 1.0 ; emu+1.0f\n"
|
||||
"v_add_f32 %[v_tmp1], %[v_tmp1], 1.0 ; emu+1.0f\n"
|
||||
"v_rcp_f32 %[v_tmp0], %[v_tmp0] ; 1/(emu+1.0f)\n"
|
||||
"v_rcp_f32 %[v_tmp1], %[v_tmp1] ; 1/(emu+1.0f)\n"
|
||||
"v_mul_f32 %[v_y0], %[v_tmp0], %[v_y0] ; x * 1/(emu+1f)\n"
|
||||
"v_mul_f32 %[v_y1], %[v_tmp1], %[v_y1] ; x * 1/(emu+1f)\n"
|
||||
: [v_y0] "+v"(y0),
|
||||
[v_y1] "+v"(y1),
|
||||
[v_c2] "+v"(c2),
|
||||
// NOTE! it is totally possible that c2/y0/y1 share same register, they are all local
|
||||
// tmp variables we need to expicitly hint compiler they may read+write, to allow
|
||||
// allocate different register , the side effect is c2=** may issue for every such
|
||||
// inline asm block
|
||||
[v_tmp0] "+v"(tmp0),
|
||||
[v_tmp1] "+v"(tmp1)
|
||||
: [s_c1] "s"(c1), [s_log2e] "s"(log2e_)
|
||||
:);
|
||||
y.x = y0;
|
||||
y.y = y1;
|
||||
}
|
||||
};
|
||||
|
||||
// https://paperswithcode.com/method/gelu
|
||||
// y = 0.5*x*(1+erf(x/sqrt(2)))
|
||||
struct Gelu
|
||||
|
||||
Reference in New Issue
Block a user