From 222e00e840ea85dff34a66b392fa952c72ff1baf Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Mon, 6 Feb 2023 11:49:45 -0600 Subject: [PATCH] Obliterated usage of rbp register in SUP gemm kernel - Mx4 edge kernels were overwriting rbp registers for prefetches. - Since rbp along with rsp defines stack frame, it resulted in stack overflow issue. - Replaced rbp with rdx register for prefetches. AMD-Internal: [CPUPL-2987] Change-Id: I4e52cf691b70be5ab63f562d7630d640b29e1cfd --- .../d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c | 228 +++++++++--------- 1 file changed, 114 insertions(+), 114 deletions(-) diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c index bdcf833e3..607ed2b43 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rv_haswell_asm_dMx4.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019-23, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -115,9 +115,9 @@ void bli_dgemmsup_rv_haswell_asm_6x4 // ------------------------------------------------------------------------- begin_asm() - + vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -132,7 +132,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4 //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -163,38 +163,38 @@ void bli_dgemmsup_rv_haswell_asm_6x4 mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) - lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; prefetch(0, mem(rcx, 5*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rsi, 1, 5*8)) // prefetch c + 1*cs_c prefetch(0, mem(rcx, rsi, 2, 5*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rcx, rbp, 1, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 1, 5*8)) // prefetch c + 3*cs_c label(.DPOSTPFETCH) // done prefetching c - - + + #if 1 lea(mem(rax, r9, 8), rdx) // lea(mem(rdx, r9, 8), rdx) // rdx = a + 16*cs_a; #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 1 prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -202,19 +202,19 @@ void bli_dgemmsup_rv_haswell_asm_6x4 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - + // ---------------------------------- iteration 1 #if 0 @@ -228,25 +228,25 @@ void bli_dgemmsup_rv_haswell_asm_6x4 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - + // ---------------------------------- iteration 2 #if 1 prefetch(0, mem(rdx, r9, 2, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -254,18 +254,18 @@ void bli_dgemmsup_rv_haswell_asm_6x4 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - + // ---------------------------------- iteration 3 @@ -280,43 +280,43 @@ void bli_dgemmsup_rv_haswell_asm_6x4 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 0 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -324,57 +324,57 @@ void bli_dgemmsup_rv_haswell_asm_6x4 vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm6, ymm6) vmulpd(ymm0, ymm8, ymm8) vmulpd(ymm0, ymm10, ymm10) vmulpd(ymm0, ymm12, ymm12) vmulpd(ymm0, ymm14, ymm14) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -383,42 +383,42 @@ void bli_dgemmsup_rv_haswell_asm_6x4 cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) vmovupd(ymm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) vmovupd(ymm8, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) vmovupd(ymm10, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) vmovupd(ymm12, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) vmovupd(ymm14, mem(rcx, 0*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -466,45 +466,45 @@ void bli_dgemmsup_rv_haswell_asm_6x4 jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(ymm8, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(ymm10, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(ymm12, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(ymm14, mem(rcx, 0*32)) //add(rdi, rcx) - + jmp(.DDONE) // jump to end. @@ -539,13 +539,13 @@ void bli_dgemmsup_rv_haswell_asm_6x4 vmovupd(xmm4, mem(rdx, rax, 1)) //lea(mem(rdx, rsi, 4), rdx) - - - - + + + + label(.DDONE) - - + + end_asm( : // output operands (none) @@ -566,7 +566,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4 [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -657,11 +657,11 @@ void bli_dgemmsup_rv_haswell_asm_5x4 mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) - lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; prefetch(0, mem(rcx, 4*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rsi, 1, 4*8)) // prefetch c + 1*cs_c prefetch(0, mem(rcx, rsi, 2, 4*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rcx, rbp, 1, 4*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 1, 4*8)) // prefetch c + 3*cs_c label(.DPOSTPFETCH) // done prefetching c @@ -1037,7 +1037,7 @@ void bli_dgemmsup_rv_haswell_asm_5x4 [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1127,11 +1127,11 @@ void bli_dgemmsup_rv_haswell_asm_4x4 mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) - lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rsi, 1, 3*8)) // prefetch c + 1*cs_c prefetch(0, mem(rcx, rsi, 2, 3*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rcx, rbp, 1, 3*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 1, 3*8)) // prefetch c + 3*cs_c label(.DPOSTPFETCH) // done prefetching c @@ -1457,7 +1457,7 @@ void bli_dgemmsup_rv_haswell_asm_4x4 [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1546,11 +1546,11 @@ void bli_dgemmsup_rv_haswell_asm_3x4 mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) - lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; prefetch(0, mem(rcx, 2*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rsi, 1, 2*8)) // prefetch c + 1*cs_c prefetch(0, mem(rcx, rsi, 2, 2*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rcx, rbp, 1, 2*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 1, 2*8)) // prefetch c + 3*cs_c label(.DPOSTPFETCH) // done prefetching c @@ -1884,7 +1884,7 @@ void bli_dgemmsup_rv_haswell_asm_3x4 [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -1972,11 +1972,11 @@ void bli_dgemmsup_rv_haswell_asm_2x4 mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) - lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rsi, 1, 1*8)) // prefetch c + 1*cs_c prefetch(0, mem(rcx, rsi, 2, 1*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rcx, rbp, 1, 1*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 1, 1*8)) // prefetch c + 3*cs_c label(.DPOSTPFETCH) // done prefetching c @@ -2247,7 +2247,7 @@ void bli_dgemmsup_rv_haswell_asm_2x4 [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", @@ -2334,11 +2334,11 @@ void bli_dgemmsup_rv_haswell_asm_1x4 mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) - lea(mem(rsi, rsi, 2), rbp) // rbp = 3*cs_c; + lea(mem(rsi, rsi, 2), rdx) // rdx = 3*cs_c; prefetch(0, mem(rcx, 0*8)) // prefetch c + 0*cs_c prefetch(0, mem(rcx, rsi, 1, 0*8)) // prefetch c + 1*cs_c prefetch(0, mem(rcx, rsi, 2, 0*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rcx, rbp, 1, 0*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rcx, rdx, 1, 0*8)) // prefetch c + 3*cs_c label(.DPOSTPFETCH) // done prefetching c @@ -2588,7 +2588,7 @@ void bli_dgemmsup_rv_haswell_asm_1x4 [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7",