Packed A matrix stride update to account for fringe cases.

- When the A matrix is packed, it is packed in blocks of MRxKC to form a
whole packed MCxKC block. If m is not a multiple of MR, the trailing
m % MR block is packed with a different schema than the full MR blocks.
The strides of the packed MR blocks and of the m % MR fringe block
therefore differ, and the strides must be updated accordingly when
calling the GEMV kernels with a packed A matrix (see the sketch after
this list).
- Fixes to address compiler warnings.
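
For illustration, here is a minimal standalone sketch of that stride
fixup (it mirrors the get_packa_strides_mfringe_u8s8s32os32 helper
added in the diff below; the simplified name, the dim_t typedef, and
MR = 6 are stand-ins for illustration only):

#include <stdio.h>

typedef long long dim_t; /* stand-in; BLIS defines its own dim_t */

/* A full MRxKC int8 micro-panel is packed with row stride 4
 * (presumably the 4-element k grouping of the int8 packing schema)
 * and column stride 4 * MR. A fringe panel with m_fringe (< MR) rows
 * occupies only m_fringe rows, so its column stride shrinks
 * proportionally, exactly as in the helper added by this commit. */
static void get_packa_strides_mfringe
     (
       dim_t* rs,
       dim_t* cs,
       dim_t  MR,
       dim_t  m_fringe
     )
{
    ( *rs ) = 4;
    ( *cs ) = ( ( *cs ) / MR ) * m_fringe;
}

int main( void )
{
    dim_t MR = 6;               /* example value; the real MR comes from lcntx->blksz.MR */
    dim_t rs = 4, cs = 4 * MR;  /* strides of a full MRxKC packed block */

    /* GEMV packs a single row of A (m_fringe = 1), as in the LPGEMV
     * paths patched below. */
    get_packa_strides_mfringe( &rs, &cs, MR, 1 );
    printf( "fringe strides: rs = %lld, cs = %lld\n", rs, cs ); /* prints rs = 4, cs = 4 */
    return 0;
}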

AMD-Internal: [SWLCSG-3359]
Change-Id: I7f47afbc9cd92536cb375431d74d9b8bca7bab44
Mithun Mohan
2025-01-21 11:45:28 +00:00
committed by Nallani Bhaskar
parent 66461b8df3
commit 39289858b7
8 changed files with 55 additions and 18 deletions

View File

@@ -652,7 +652,7 @@ AOCL_GEMM_REORDER(int8_t, bf16s4f32of32)
b.cs = cs_b;
b.width = n;
b.length = k;
-b.mtag = input_mat_type;
+b.mat_type = input_mat_type;
reorderb_nr64_bf16s4f32of32(&b, &b_reorder, &rntm_g, lcntx_g);
}

View File

@@ -403,7 +403,7 @@ void reorderb_nr64_bf16s4f32of32
dim_t cs_b = b->cs;
dim_t n = b->width;
dim_t k = b->length;
-AOCL_MATRIX_TYPE mtag = b->mtag;
+AOCL_MATRIX_TYPE mat_type = b->mat_type;
dim_t rs_b_reorder;
dim_t cs_b_reorder;
@@ -495,13 +495,16 @@ void reorderb_nr64_bf16s4f32of32
// st = ( jc_cur_loop * k ) <traverse blocks 1,2,3,4>
// + ( n_sub_updated * pc ) <traverse block 5>
// + ( NC' * kc0_updated) <traverse block 6>
-((pack_s4)lcntx->packb_fun_ptr)(
-((int8_t *)b_reorder->storage.aligned_buffer) +
-( (jc_cur_loop * k_updated) + (n_sub_updated * pc) +
-(jc_cur_loop_rem * kc0_updated) ) / 2,
-(((int8_t *)b->storage.aligned_buffer) +
-( (rs_b * pc) + (jc * cs_b) ) / 2),
-rs_b, cs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder, NULL, mtag);
+((pack_s4)lcntx->packb_fun_ptr)
+(
+( ( int8_t* )b_reorder->storage.aligned_buffer ) +
+( ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) +
+( jc_cur_loop_rem * kc0_updated ) ) / 2,
+( ( ( int8_t* )b->storage.aligned_buffer ) +
+( ( rs_b * pc ) + ( jc * cs_b ) ) / 2 ),
+rs_b, cs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder,
+NULL, mat_type
+);
}
adjust_B_panel_reordered_jc(&jc, jc_cur_loop);
@@ -637,4 +640,4 @@ void reorderb_mxp_nr64_f32obf16
b_reorder->rs = rs_b_reorder;
b_reorder->cs = cs_b_reorder;
b_reorder->mtag = REORDERED;
-}
+}

View File

@@ -136,6 +136,7 @@ typedef struct
dim_t cs;
AOCL_MEMORY_TAG mtag;
+AOCL_MATRIX_TYPE mat_type;
lpgemm_mem_t storage;
} lpgemm_obj_t;

View File

@@ -211,6 +211,8 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
}
else
{
+dim_t gemm_MR = lcntx->blksz.MR;
dim_t jc_start, jc_end;
thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
( thread->n_threads ) : ( thread_jc.n_way );
@@ -246,6 +248,10 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
1, k,
&rs_a_use, &cs_a_use
);
+get_packa_strides_mfringe_u8s8s32os32
+(
+&rs_a_use, &cs_a_use, gemm_MR, 1
+);
a_use = pack_a_buffer_s8s8s32os32;
}

View File

@@ -53,12 +53,18 @@ BLIS_INLINE void calculate_n_threads_per_gemm
rntm_t* rntm_g
)
{
-*n_threads = bli_rntm_num_threads( rntm_g ); \
-*n_gemms_in_parallel = -1; \
-if( *n_threads == 1 ) *n_gemms_in_parallel = 1; \
-else if( *n_gemms_in_parallel < 1 ) *n_gemms_in_parallel = bli_min(*n_threads, batch_size); \
-/* ToDo: All the leftover threads might go under-utilized. Could be optimized further. */ \
-*n_threads_per_gemm = ( *n_threads ) / *n_gemms_in_parallel;
+*n_threads = bli_rntm_num_threads( rntm_g );
+*n_gemms_in_parallel = -1;
+if( *n_threads == 1 )
+{
+( *n_gemms_in_parallel ) = 1;
+}
+else if( *n_gemms_in_parallel < 1 )
+{
+( *n_gemms_in_parallel ) = bli_min( ( *n_threads ), batch_size );
+}
+/* ToDo: All the leftover threads might go under-utilized. Could be optimized further. */
+( *n_threads_per_gemm ) = ( *n_threads ) / ( *n_gemms_in_parallel );
}
BLIS_INLINE dim_t next_factor
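
As a rough worked example of the rewritten logic above (the numbers are
hypothetical): with *n_threads = 8 and batch_size = 3,
*n_gemms_in_parallel becomes bli_min( 8, 3 ) = 3 and
*n_threads_per_gemm becomes 8 / 3 = 2, leaving two threads idle; this
is the under-utilization flagged in the ToDo comment.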

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
-Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved.
+Copyright (C) 2022 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -194,6 +194,8 @@ LPGEMV(uint8_t,int8_t,int32_t,u8s8s32os32)
}
else
{
+dim_t gemm_MR = lcntx->blksz.MR;
// Compute the JC loop thread range for the current thread.
dim_t jc_start, jc_end;
thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
@@ -227,6 +229,10 @@ LPGEMV(uint8_t,int8_t,int32_t,u8s8s32os32)
1, k,
&rs_a_use, &cs_a_use
);
+get_packa_strides_mfringe_u8s8s32os32
+(
+&rs_a_use, &cs_a_use, gemm_MR, 1
+);
a_use = pack_a_buffer;
}

View File

@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
-Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved.
+Copyright (C) 2022 - 2025, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -35,6 +35,20 @@
#ifndef BLIS_GEMM_INT8_PACKA
#define BLIS_GEMM_INT8_PACKA
+// The strides need to be updated based on the m_fringe value to account
+// for the different schemas used to pack A fringe cases.
+BLIS_INLINE void get_packa_strides_mfringe_u8s8s32os32
+(
+dim_t* rs,
+dim_t* cs,
+dim_t MR,
+dim_t m_fringe
+)
+{
+( *rs ) = 4;
+( *cs ) = ( ( *cs ) / MR ) * m_fringe;
+}
typedef void (*packa_s32)
(
uint8_t*,

View File

@@ -1,3 +1,4 @@
r t n n r 1 128 64 1 128 128 *:none
c n t n n 32 128 2 32 128 32 bf16bf16f32of32:bias=na,swish
r n n n r 6 1 4 4 16 16 bf16s4f32of32:pre_op_scale=scalar,pre_op_scale_type=bf16,group_size=2
+r n n n r 6 1 4 4 16 16 bf16s4f32of32:pre_op_zp=vector,pre_op_scale=scalar,pre_op_scale_type=bf16,group_size=2