Disabled topology detection in LPGEMM

- Disabled thread topology detection because libgomp does not honor
  the standard OpenMP function omp_get_place_proc_ids on virtual machines

- Added B prefetch in bf16 B packing kernels and in the bf16 LPGEMV (m = 1) kernel

AMD-Internal: SWLCSG-3761
This commit is contained in:
Bhaskar, Nallani
2025-08-26 19:20:01 +05:30
committed by GitHub
parent 15d2e5c628
commit b052775644
3 changed files with 18 additions and 1 deletions

View File

@@ -235,7 +235,14 @@ void lpgemm_load_thread_attrs()
lpgemm_thread_attrs.tid_distr_nearly_seq = FALSE;
lpgemm_thread_attrs.tid_core_grp_load_high = FALSE;
lpgemm_detect_thread_topo();
/*
TODO: lpgemm_detect_thread_topo is disabled for now, pending
further investigation.
Reason: libgomp does not honor the standard function
omp_get_place_proc_ids on virtual machines.
*/
// lpgemm_detect_thread_topo();
}
void lpgemm_init_thread_attrs()

View File

@@ -180,6 +180,10 @@ void packb_nr64_bf16bf16f32of32_row_major
{
for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 )
{
//Prefetch B data required for next iteration
_mm_prefetch( ( b + ( ldb * ( kr + 2 ) ) + jc ), _MM_HINT_T0);
_mm_prefetch( ( b + ( ldb * ( kr + 3 ) ) + jc ), _MM_HINT_T0);
// Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row.
a0 = _mm512_loadu_si512( b + ( ldb * ( kr + 0 ) ) + jc );
b0 = _mm512_loadu_si512( b + ( ldb * ( kr + 0 ) ) + jc + 32 );
@@ -203,6 +207,7 @@ void packb_nr64_bf16bf16f32of32_row_major
_mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( kr + 1 ) * NR ), d0 );
_mm512_storeu_si512( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( kr + 1 ) * NR ) + 32, c0 );
}
// Handle k remainder.
if( k_partial_pieces > 0)
{

View File

@@ -155,6 +155,11 @@ LPGEMV_M_EQ1_KERN(bfloat16, bfloat16, float, bf16bf16f32of32)
for (dim_t k = 0; k < k_iter; k++)
{
//Prefetch B data required for next iteration
_mm_prefetch(b_use + 4*rs_b, _MM_HINT_T0);
_mm_prefetch(b_use + 5*rs_b, _MM_HINT_T0);
_mm_prefetch(b_use + 6*rs_b, _MM_HINT_T0);
_mm_prefetch(b_use + 7*rs_b, _MM_HINT_T0);
// load first 4x32 tile from row 0-3
zmm0 = (__m512bh)_mm512_maskz_loadu_epi16( k5, b_use );