mirror of
https://github.com/amd/blis.git
synced 2026-04-20 07:38:53 +00:00
Packed A matrix stride update to account for fringe cases.
- When the A matrix is packed, it is packed in blocks of MRxKC to form a whole packed MCxKC block. If the m value is not a multiple of MR, the remaining m % MR block is packed differently from the full MR blocks. Consequently, the strides of the packed MR blocks and of the m % MR block differ, and they need to be updated when calling the GEMV kernels with a packed A matrix. - Fixes to address compiler warnings. AMD-Internal: [SWLCSG-3359] Change-Id: I7f47afbc9cd92536cb375431d74d9b8bca7bab44
This commit is contained in:
committed by
Nallani Bhaskar
parent
66461b8df3
commit
39289858b7
@@ -652,7 +652,7 @@ AOCL_GEMM_REORDER(int8_t, bf16s4f32of32)
|
||||
b.cs = cs_b;
|
||||
b.width = n;
|
||||
b.length = k;
|
||||
b.mtag = input_mat_type;
|
||||
b.mat_type = input_mat_type;
|
||||
|
||||
reorderb_nr64_bf16s4f32of32(&b, &b_reorder, &rntm_g, lcntx_g);
|
||||
}
|
||||
|
||||
@@ -403,7 +403,7 @@ void reorderb_nr64_bf16s4f32of32
|
||||
dim_t cs_b = b->cs;
|
||||
dim_t n = b->width;
|
||||
dim_t k = b->length;
|
||||
AOCL_MATRIX_TYPE mtag = b->mtag;
|
||||
AOCL_MATRIX_TYPE mat_type = b->mat_type;
|
||||
|
||||
dim_t rs_b_reorder;
|
||||
dim_t cs_b_reorder;
|
||||
@@ -495,13 +495,16 @@ void reorderb_nr64_bf16s4f32of32
|
||||
// st = ( jc_cur_loop * k ) <traverse blocks 1,2,3,4>
|
||||
// + ( n_sub_updated * pc ) <traverse block 5>
|
||||
// + ( NC' * kc0_updated) <traverse block 6>
|
||||
((pack_s4)lcntx->packb_fun_ptr)(
|
||||
((int8_t *)b_reorder->storage.aligned_buffer) +
|
||||
( (jc_cur_loop * k_updated) + (n_sub_updated * pc) +
|
||||
(jc_cur_loop_rem * kc0_updated) ) / 2,
|
||||
(((int8_t *)b->storage.aligned_buffer) +
|
||||
( (rs_b * pc) + (jc * cs_b) ) / 2),
|
||||
rs_b, cs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder, NULL, mtag);
|
||||
((pack_s4)lcntx->packb_fun_ptr)
|
||||
(
|
||||
( ( int8_t* )b_reorder->storage.aligned_buffer ) +
|
||||
( ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) +
|
||||
( jc_cur_loop_rem * kc0_updated ) ) / 2,
|
||||
( ( ( int8_t* )b->storage.aligned_buffer ) +
|
||||
( ( rs_b * pc ) + ( jc * cs_b ) ) / 2 ),
|
||||
rs_b, cs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder,
|
||||
NULL, mat_type
|
||||
);
|
||||
}
|
||||
|
||||
adjust_B_panel_reordered_jc(&jc, jc_cur_loop);
|
||||
@@ -637,4 +640,4 @@ void reorderb_mxp_nr64_f32obf16
|
||||
b_reorder->rs = rs_b_reorder;
|
||||
b_reorder->cs = cs_b_reorder;
|
||||
b_reorder->mtag = REORDERED;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,6 +136,7 @@ typedef struct
|
||||
dim_t cs;
|
||||
|
||||
AOCL_MEMORY_TAG mtag;
|
||||
AOCL_MATRIX_TYPE mat_type;
|
||||
|
||||
lpgemm_mem_t storage;
|
||||
} lpgemm_obj_t;
|
||||
|
||||
@@ -211,6 +211,8 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
}
|
||||
else
|
||||
{
|
||||
dim_t gemm_MR = lcntx->blksz.MR;
|
||||
|
||||
dim_t jc_start, jc_end;
|
||||
thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
|
||||
( thread->n_threads ) : ( thread_jc.n_way );
|
||||
@@ -246,6 +248,10 @@ LPGEMV(int8_t,int8_t,int32_t,s8s8s32o32)
|
||||
1, k,
|
||||
&rs_a_use, &cs_a_use
|
||||
);
|
||||
get_packa_strides_mfringe_u8s8s32os32
|
||||
(
|
||||
&rs_a_use, &cs_a_use, gemm_MR, 1
|
||||
);
|
||||
|
||||
a_use = pack_a_buffer_s8s8s32os32;
|
||||
}
|
||||
|
||||
@@ -53,12 +53,18 @@ BLIS_INLINE void calculate_n_threads_per_gemm
|
||||
rntm_t* rntm_g
|
||||
)
|
||||
{
|
||||
*n_threads = bli_rntm_num_threads( rntm_g ); \
|
||||
*n_gemms_in_parallel = -1; \
|
||||
if( *n_threads == 1 ) *n_gemms_in_parallel = 1; \
|
||||
else if( *n_gemms_in_parallel < 1 ) *n_gemms_in_parallel = bli_min(*n_threads, batch_size); \
|
||||
/* ToDo: All the leftover threads might go under-utilized. Could be optimized further. */ \
|
||||
*n_threads_per_gemm = ( *n_threads ) / *n_gemms_in_parallel;
|
||||
*n_threads = bli_rntm_num_threads( rntm_g );
|
||||
*n_gemms_in_parallel = -1;
|
||||
if( *n_threads == 1 )
|
||||
{
|
||||
( *n_gemms_in_parallel ) = 1;
|
||||
}
|
||||
else if( *n_gemms_in_parallel < 1 )
|
||||
{
|
||||
( *n_gemms_in_parallel ) = bli_min( ( *n_threads ), batch_size );
|
||||
}
|
||||
/* ToDo: All the leftover threads might go under-utilized. Could be optimized further. */
|
||||
( *n_threads_per_gemm ) = ( *n_threads ) / ( *n_gemms_in_parallel );
|
||||
}
|
||||
|
||||
BLIS_INLINE dim_t next_factor
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2022 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -194,6 +194,8 @@ LPGEMV(uint8_t,int8_t,int32_t,u8s8s32os32)
|
||||
}
|
||||
else
|
||||
{
|
||||
dim_t gemm_MR = lcntx->blksz.MR;
|
||||
|
||||
// Compute the JC loop thread range for the current thread.
|
||||
dim_t jc_start, jc_end;
|
||||
thread_jc.n_way = ( thread_jc.n_way == 1 ) ?
|
||||
@@ -227,6 +229,10 @@ LPGEMV(uint8_t,int8_t,int32_t,u8s8s32os32)
|
||||
1, k,
|
||||
&rs_a_use, &cs_a_use
|
||||
);
|
||||
get_packa_strides_mfringe_u8s8s32os32
|
||||
(
|
||||
&rs_a_use, &cs_a_use, gemm_MR, 1
|
||||
);
|
||||
|
||||
a_use = pack_a_buffer;
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
An object-based framework for developing high-performance BLAS-like
|
||||
libraries.
|
||||
|
||||
Copyright (C) 2022 - 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
Copyright (C) 2022 - 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -35,6 +35,20 @@
|
||||
#ifndef BLIS_GEMM_INT8_PACKA
|
||||
#define BLIS_GEMM_INT8_PACKA
|
||||
|
||||
// The strides need to be updated based on the m_fringe value to account
|
||||
// for different schemas used to pack A fringe cases.
|
||||
BLIS_INLINE void get_packa_strides_mfringe_u8s8s32os32
|
||||
(
|
||||
dim_t* rs,
|
||||
dim_t* cs,
|
||||
dim_t MR,
|
||||
dim_t m_fringe
|
||||
)
|
||||
{
|
||||
( *rs ) = 4;
|
||||
( *cs ) = ( ( *cs ) / MR ) * m_fringe;
|
||||
}
|
||||
|
||||
typedef void (*packa_s32)
|
||||
(
|
||||
uint8_t*,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
r t n n r 1 128 64 1 128 128 *:none
|
||||
c n t n n 32 128 2 32 128 32 bf16bf16f32of32:bias=na,swish
|
||||
r n n n r 6 1 4 4 16 16 bf16s4f32of32:pre_op_scale=scalar,pre_op_scale_type=bf16,group_size=2
|
||||
r n n n r 6 1 4 4 16 16 bf16s4f32of32:pre_op_zp=vector,pre_op_scale=scalar,pre_op_scale_type=bf16,group_size=2
|
||||
|
||||
Reference in New Issue
Block a user