Files
blis/bench/UnitTests/test_thread_control.c
Varaganti, Kiran bb6545a46b Added new thread control API with global and thread-local variants
CPUPL-7578: New thread control API with global and thread-local variants

Summary: Add new BLIS thread control APIs that provide fine-grained control over threading with proper global and thread-local (TLS) semantics. Fix several correctness issues where set_num_threads() and set_ways() did not properly override each other's state.

New/Modified APIs:

bli_thread_set_num_threads() — Sets thread count globally (updates both global_rntm and tl_rntm)
bli_thread_set_num_threads_local() — Sets thread count for calling thread only (tl_rntm)
bli_thread_get_num_threads() — Returns effective thread count, deriving from ways if set
bli_thread_reset() — Resyncs tl_rntm from global_rntm
bli_thread_set_ways() — Sets loop factorization (jc, pc, ic, jr, ir)
bli_thread_get_is_parallel() — Returns whether parallelism is enabled
bli_thread_get_jc_nt/ic_nt/pc_nt/jr_nt/ir_nt() — Returns individual way values
b77_thread_set_num_threads_local_() — Fortran-compatible wrapper
Bug fixes:

bli_thread_set_num_threads() now clears ways (-1) and sets auto_factor=TRUE on both global_rntm and tl_rntm, so it properly overrides prior BLIS_JC_NT/BLIS_IC_NT environment settings
bli_thread_set_ways() now propagates to global_rntm (inside mutex) and clears stale num_threads on both global_rntm and tl_rntm, so get_num_threads() returns the product of ways instead of a stale value
Fix data race in bli_thread_init_rntm_from_global_rntm() — copy global_rntm under mutex before debug printing
Fix data race in set_num_threads_local() debug print
Test suite (43 tests, 106 assertions):

test_thread_control.c (OpenMP, 23 tests): environment inheritance, global propagation, thread-local isolation, local precedence, per-thread local, reset, nested parallel, edge cases, set_ways, is_parallel, concurrent updates, DGEMM with threads, interleaved settings, persistence, parallel DGEMM, thread pool, reset-to-sync, env ways vs set_num_threads, ways→set_nt→reset, ways→local→reset, round-trip, set_nt→set_ways override, set_ways propagation to new threads
test_thread_control_pthread.c (pthread, 20 tests): equivalent coverage plus concurrent set/reset race condition test, set_nt→set_ways override, set_ways propagation via pthread_create
Files changed (9 files, +2630/-29 lines):

bli_thread.c — Core API implementations and fixes
bli_thread.h — New function declarations
b77_thread.c — Fortran wrapper
test_thread_control.c — OpenMP test suite (23 tests)
test_thread_control_pthread.c — pthread test suite (20 tests)
TEST_THREAD_CONTROL_README.md — Documentation
AMD-Internal: CPUPL-7578
2026-03-06 12:16:17 +05:30

854 lines
34 KiB
C

/*
* Comprehensive test suite for BLIS thread control API
* Tests the new global and thread-local thread control variants:
* - bli_thread_set_num_threads() : Sets both global and thread-local
* - bli_thread_set_num_threads_local() : Sets thread-local only
* - bli_thread_get_num_threads() : Gets effective thread count
* - bli_thread_reset() : Resets thread-local to global value
* - bli_thread_set_ways() : Sets loop factorization
* - bli_thread_get_is_parallel() : Checks if parallelism is enabled
*
* Compile: gcc test_thread_control.c -fopenmp -lblis-mt -o test_thread_control
* Run: OMP_MAX_ACTIVE_LEVELS=2 ./test_thread_control [test_number]
*/
#include <omp.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
extern void bli_thread_set_num_threads(int num_threads);
extern void bli_thread_set_num_threads_local(int num_threads);
extern int bli_thread_get_num_threads(void);
extern void bli_thread_reset(void);
extern void bli_thread_set_ways(int jc, int pc, int ic, int jr, int ir);
extern int bli_thread_get_is_parallel(void);
extern int bli_thread_get_jc_nt(void);
extern int bli_thread_get_ic_nt(void);
extern int bli_thread_get_pc_nt(void);
extern int bli_thread_get_jr_nt(void);
extern int bli_thread_get_ir_nt(void);
extern void dgemm_(char* transa, char* transb, int* m, int* n, int* k,
double* alpha, double* a, int* lda, double* b, int* ldb,
double* beta, double* c, int* ldc);
/* Capacity of the per-thread result arrays used by the tests. */
#define MAX_THREADS 512
/* ANSI-coloured verdict labels printed by check_result(). */
#define PASS "\033[32mPASS\033[0m"
#define FAIL "\033[31mFAIL\033[0m"

/* Running assertion totals; only check_result() updates them. */
static int tests_passed = 0;
static int tests_failed = 0;

/* Print a blank line followed by `title` framed between two rules. */
void print_separator(const char* title) {
    static const char rule[] = "========================================";
    printf("\n%s\n %s\n%s\n", rule, title, rule);
}

/*
 * Record one assertion: a non-zero `condition` counts as a pass, zero as a
 * failure. The verdict and `test_name` are printed on one line.
 */
void check_result(const char* test_name, int condition) {
    const char* verdict;
    if (condition) {
        verdict = PASS;
        ++tests_passed;
    } else {
        verdict = FAIL;
        ++tests_failed;
    }
    printf("[%s] %s\n", verdict, test_name);
}

/* Print an informational line that does not affect the pass/fail totals. */
void print_info(const char* msg) {
    printf("[INFO] %s\n", msg);
}
// =============================================================================
// TEST 1: Environment variable inheritance
// =============================================================================
/*
 * TEST 1: print the OpenMP limits next to the thread count BLIS reports on
 * entry, and require that count to be positive.
 */
void test_1_env_inheritance(void) {
    print_separator("TEST 1: Environment Variable Inheritance");

    const int startup_nt = bli_thread_get_num_threads();

    printf("OMP_MAX_ACTIVE_LEVELS=%d, omp_get_max_threads()=%d\n",
           omp_get_max_active_levels(), omp_get_max_threads());
    printf("Initial bli_thread_get_num_threads() = %d\n", startup_nt);

    const int is_positive = (startup_nt > 0);
    check_result("Initial thread count > 0", is_positive);
}
// =============================================================================
// TEST 2: Global thread setting propagation
// =============================================================================
/*
 * TEST 2: set the thread count globally, then have each member of a 4-thread
 * team record what bli_thread_get_num_threads() returns. The per-thread check
 * is only asserted when OMP_MAX_ACTIVE_LEVELS >= 2; the calling thread is
 * always checked.
 */
void test_2_global_propagation(void) {
    print_separator("TEST 2: Global Setting Propagates to NEW Threads");
    if (omp_get_max_active_levels() < 2) {
        print_info("OMP_MAX_ACTIVE_LEVELS < 2: Thread spawning may be limited");
    }

    const int EXPECTED_NT = 16;
    int observed[MAX_THREADS] = {0};
    int team_size = 0;

    bli_thread_set_num_threads(EXPECTED_NT);
    printf("Set global to %d\n", EXPECTED_NT);

    #pragma omp parallel num_threads(4)
    {
        const int me = omp_get_thread_num();
        /* One thread publishes the real team size for the loops below. */
        #pragma omp single
        {
            team_size = omp_get_num_threads();
        }
        observed[me] = bli_thread_get_num_threads();
    }

    /* Report and tally in a single pass over the threads that ran. */
    int mismatches = 0;
    for (int t = 0; t < team_size; t++) {
        printf(" Thread %d sees: %d\n", t, observed[t]);
        if (observed[t] != EXPECTED_NT) {
            mismatches++;
        }
    }

    check_result("Main thread sees correct value", bli_thread_get_num_threads() == EXPECTED_NT);
    if (omp_get_max_active_levels() < 2) {
        /* With limited active levels, threads may not spawn or may reuse main thread TLS */
        print_info("Skipping thread propagation check (OMP_MAX_ACTIVE_LEVELS < 2)");
    } else {
        check_result("All threads see global setting", mismatches == 0);
    }
    bli_thread_reset();
}
// =============================================================================
// =============================================================================
// TEST 3: Local only affects calling thread
// =============================================================================
/*
 * TEST 3: after syncing a 4-thread team to the global value, apply a local
 * override on the calling thread and check that at least one team member
 * still reports the global value (asserted only when
 * OMP_MAX_ACTIVE_LEVELS >= 2).
 */
void test_3_local_only_affects_caller(void) {
    print_separator("TEST 3: Local Setting Only Affects Calling Thread");
    const int GLOBAL_NT = 8;
    const int LOCAL_NT = 24;

    /* Publish the global value, then have each team thread reset to it. */
    bli_thread_set_num_threads(GLOBAL_NT);
    #pragma omp parallel num_threads(4)
    {
        bli_thread_reset();
    }

    bli_thread_set_num_threads_local(LOCAL_NT);
    const int caller_nt = bli_thread_get_num_threads();
    check_result("Main thread sees local override", caller_nt == LOCAL_NT);

    int observed[MAX_THREADS] = {0};
    int team_size = 0;
    #pragma omp parallel num_threads(4)
    {
        const int me = omp_get_thread_num();
        #pragma omp single
        {
            team_size = omp_get_num_threads();
        }
        observed[me] = bli_thread_get_num_threads();
    }

    printf("Note: Thread 0 may reuse main thread's TLS\n");
    for (int t = 0; t < team_size; t++) {
        const char* note = "";
        if (observed[t] == LOCAL_NT) {
            note = " (reused)";
        }
        printf(" Thread %d sees: %d%s\n", t, observed[t], note);
    }

    if (omp_get_max_active_levels() < 2) {
        print_info("Skipping thread isolation check (OMP_MAX_ACTIVE_LEVELS < 2)");
    } else {
        /* Stop at the first thread that reports the global value. */
        int global_seen = 0;
        for (int t = 0; t < team_size && !global_seen; t++) {
            global_seen = (observed[t] == GLOBAL_NT);
        }
        check_result("Some threads see global value", global_seen);
    }
    bli_thread_reset();
}
// =============================================================================
// TEST 4: Local override precedence
// =============================================================================
/*
 * TEST 4: on a single thread, a local setting must win over the global one,
 * and bli_thread_reset() must bring the global value back.
 */
void test_4_local_precedence(void) {
    print_separator("TEST 4: Local Override Precedence and Reset");
    const int GLOBAL_NT = 16;
    const int LOCAL_NT = 32;
    int current;

    bli_thread_set_num_threads(GLOBAL_NT);
    current = bli_thread_get_num_threads();
    check_result("After global set", current == GLOBAL_NT);

    bli_thread_set_num_threads_local(LOCAL_NT);
    current = bli_thread_get_num_threads();
    check_result("Local overrides global", current == LOCAL_NT);

    bli_thread_reset();
    current = bli_thread_get_num_threads();
    check_result("Reset restores global", current == GLOBAL_NT);
}
// =============================================================================
// =============================================================================
// TEST 5: Per-thread local settings
// =============================================================================
/*
 * TEST 5: each member of a 3-thread team applies its own local thread count
 * and must read that same value back.
 *
 * Only threads that actually ran are verified. OpenMP may deliver fewer
 * threads than requested, and a slot that never ran would otherwise keep its
 * zero initializer and be reported as a failure.
 */
void test_5_per_thread_local(void) {
    print_separator("TEST 5: Per-Thread Local Settings");
    if (omp_get_max_active_levels() < 2) {
        print_info("OMP_MAX_ACTIVE_LEVELS < 2: Skipping (thread-local isolation requires nested parallelism)");
        return;
    }

    enum { TEAM_SIZE = 3 };
    bli_thread_set_num_threads(1);
    bli_thread_reset(); // Ensure clean state

    const int local_values[TEAM_SIZE] = {4, 12, 20};
    int seen_values[TEAM_SIZE] = {0};
    int num_launched = 0;

    #pragma omp parallel num_threads(TEAM_SIZE)
    {
        int tid = omp_get_thread_num();
        /* Record the real team size, as the other multi-thread tests do. */
        #pragma omp single
        num_launched = omp_get_num_threads();
        bli_thread_set_num_threads_local(local_values[tid]);
        seen_values[tid] = bli_thread_get_num_threads();
        printf("Thread %d: set %d, gets %d\n", tid, local_values[tid], seen_values[tid]);
    }

    if (num_launched < TEAM_SIZE) {
        print_info("Fewer threads launched than requested; checking launched threads only");
    }

    int all_correct = (num_launched > 0);
    for (int i = 0; i < num_launched; i++) {
        if (seen_values[i] != local_values[i]) {
            all_correct = 0;
        }
    }
    check_result("Each thread sees its own local setting", all_correct);
}
// =============================================================================
// =============================================================================
// TEST 6: Reset in child threads
// =============================================================================
/*
 * TEST 6: every member of a 3-thread team applies a distinct local override,
 * calls bli_thread_reset(), and must then report the global value again.
 * Skipped when OMP_MAX_ACTIVE_LEVELS < 2.
 */
void test_6_reset_in_children(void) {
    print_separator("TEST 6: Reset in Child Threads");
    if (omp_get_max_active_levels() < 2) {
        print_info("OMP_MAX_ACTIVE_LEVELS < 2: Skipping (requires nested parallelism)");
        return;
    }

    const int GLOBAL_NT = 8;
    bli_thread_set_num_threads(GLOBAL_NT);

    int post_reset[MAX_THREADS] = {0};
    int team_size = 0;
    #pragma omp parallel num_threads(3)
    {
        const int me = omp_get_thread_num();
        #pragma omp single
        {
            team_size = omp_get_num_threads();
        }
        /* 100 + id can never be mistaken for the global value. */
        bli_thread_set_num_threads_local(100 + me);
        bli_thread_reset();
        post_reset[me] = bli_thread_get_num_threads();
    }

    int stale = 0;
    for (int t = 0; t < team_size; t++) {
        printf(" Thread %d after reset: %d (expected %d)\n", t, post_reset[t], GLOBAL_NT);
        if (post_reset[t] != GLOBAL_NT) {
            stale++;
        }
    }
    check_result("Reset restores global in all threads", stale == 0);
}
// =============================================================================
// TEST 7: Nested parallel regions
// =============================================================================
/*
 * TEST 7: three outer threads each apply a local thread count and then open a
 * nested 2-thread region whose members record what they observe.
 *
 * Checks:
 *  - every outer thread that ran reads back its own local value;
 *  - the non-primary inner threads do not report the parent's local value.
 *
 * Inner thread 0 is excluded from the second check. In OpenMP the thread that
 * encounters a parallel construct becomes thread 0 of the new team, so inner
 * thread 0 is the outer thread itself and shares its thread-local state. The
 * original check counted that as an unexpected inheritance.
 */
void test_7_nested_parallel(void) {
    print_separator("TEST 7: Nested Parallel Regions");
    if (omp_get_max_active_levels() < 2) {
        print_info("Skipping: need OMP_MAX_ACTIVE_LEVELS>=2");
        return;
    }

    enum { OUTER = 3, INNER = 2 };
    const int GLOBAL_NT = 8;
    bli_thread_set_num_threads(GLOBAL_NT);

    const int outer_values[OUTER] = {2, 3, 4};
    int outer_sees[OUTER] = {0};
    int inner_sees[OUTER][INNER] = {{0}};
    int outer_launched = 0;

    #pragma omp parallel num_threads(OUTER)
    {
        int ptid = omp_get_thread_num();
        /* Record how many outer threads really ran. */
        #pragma omp single
        outer_launched = omp_get_num_threads();
        bli_thread_set_num_threads_local(outer_values[ptid]);
        outer_sees[ptid] = bli_thread_get_num_threads();
        printf("Outer[%d]: local=%d\n", ptid, outer_sees[ptid]);
        #pragma omp parallel num_threads(INNER)
        {
            int ctid = omp_get_thread_num();
            inner_sees[ptid][ctid] = bli_thread_get_num_threads();
            printf(" Inner[%d.%d]: sees=%d\n", ptid, ctid, inner_sees[ptid][ctid]);
        }
    }

    // Verify the outer threads that ran see their local value
    int outer_correct = (outer_launched > 0);
    for (int p = 0; p < outer_launched; p++) {
        if (outer_sees[p] != outer_values[p]) outer_correct = 0;
    }
    check_result("Outer threads see their local values", outer_correct);

    // Newly created inner threads should see global or default, not the
    // parent's local value.
    print_info("Note: Inner threads don't inherit parent's local (expected TLS behavior)");
    int inner_valid = 1;
    for (int p = 0; p < outer_launched; p++) {
        // c starts at 1: inner thread 0 is the parent thread itself.
        for (int c = 1; c < INNER; c++) {
            if (inner_sees[p][c] == outer_values[p]) {
                printf(" Unexpected: Inner[%d.%d] inherited parent local\n", p, c);
                inner_valid = 0;
            }
        }
    }
    check_result("Inner threads have independent TLS (don't inherit parent)", inner_valid);
}
// =============================================================================
// TEST 8: Edge cases
// =============================================================================
/*
 * TEST 8: boundary inputs. A request for 0 threads (global or local) is
 * expected to read back as 1, and a request for 1000 is expected to be kept.
 */
void test_8_edge_cases(void) {
    print_separator("TEST 8: Edge Cases");
    int nt;

    bli_thread_set_num_threads(0);
    nt = bli_thread_get_num_threads();
    check_result("Zero becomes 1", nt == 1);

    bli_thread_set_num_threads_local(0);
    nt = bli_thread_get_num_threads();
    check_result("Local zero becomes 1", nt == 1);

    bli_thread_set_num_threads(1000);
    nt = bli_thread_get_num_threads();
    check_result("Large value accepted", nt == 1000);

    bli_thread_reset();
}
// =============================================================================
// TEST 9: bli_thread_set_ways() API
// =============================================================================
/*
 * TEST 9: after bli_thread_set_ways(), the reported thread count must equal
 * the product of the five ways.
 */
void test_9_set_ways(void) {
    print_separator("TEST 9: bli_thread_set_ways() API");

    bli_thread_set_ways(2, 1, 2, 2, 1);
    const int nt_first = bli_thread_get_num_threads();
    check_result("Ways (2*1*2*2*1=8)", nt_first == 8);

    bli_thread_set_ways(4, 1, 4, 1, 1);
    const int nt_second = bli_thread_get_num_threads();
    check_result("Ways (4*1*4*1*1=16)", nt_second == 16);
}
// =============================================================================
// TEST 10: bli_thread_get_is_parallel() API
// =============================================================================
/*
 * TEST 10: bli_thread_get_is_parallel() must return exactly 0 with one thread
 * and exactly 1 with four.
 */
void test_10_is_parallel(void) {
    print_separator("TEST 10: bli_thread_get_is_parallel() API");
    int flag;

    bli_thread_set_num_threads(1);
    bli_thread_reset(); /* start from a clean thread-local state */
    flag = bli_thread_get_is_parallel();
    check_result("1 thread = not parallel", flag == 0);

    bli_thread_set_num_threads(4);
    flag = bli_thread_get_is_parallel();
    check_result("4 threads = parallel", flag == 1);
}
// =============================================================================
// TEST 11: Concurrent global updates (stress test)
// =============================================================================
/*
 * TEST 11: four threads repeatedly set the global thread count to
 * (thread id + 1) and read it back. Every value read, and the final value
 * seen by the caller, must lie in 1..4.
 */
void test_11_concurrent_global_updates(void) {
    print_separator("TEST 11: Concurrent Global Updates (Stress Test)");
    bli_thread_set_num_threads(1);
    bli_thread_reset(); /* start from a clean thread-local state */

    int out_of_range = 0;
    #pragma omp parallel num_threads(4) reduction(+:out_of_range)
    {
        /* The thread id is fixed for the life of the region. */
        const int requested = omp_get_thread_num() + 1;
        for (int iter = 0; iter < 100; iter++) {
            bli_thread_set_num_threads(requested);
            const int seen = bli_thread_get_num_threads();
            /* Any writer's value (1, 2, 3 or 4) is acceptable. */
            const int in_range = (seen >= 1 && seen <= 4);
            if (!in_range) {
                out_of_range += 1;
            }
        }
    }

    const int last_nt = bli_thread_get_num_threads();
    check_result("All values in expected range (1-4)", out_of_range == 0);
    check_result("Final value valid", last_nt >= 1 && last_nt <= 4);
    print_info("Note: Run with -fsanitize=thread to detect actual data races");
}
// =============================================================================
// TEST 12: DGEMM with different thread settings
// =============================================================================
/*
 * TEST 12: run a 100x100 DGEMM of two all-ones matrices under several global
 * thread counts. Every entry of the product must equal n.
 *
 * The whole result matrix is verified, not only C[0], so a wrong value in any
 * thread's partition is caught. A failed allocation is reported as a failure
 * instead of being dereferenced.
 */
void test_12_dgemm_with_threads(void) {
    print_separator("TEST 12: DGEMM with Different Thread Settings");
    int n = 100;
    double alpha = 1.0, beta = 0.0;
    const size_t count = (size_t)n * (size_t)n;

    double *A = calloc(count, sizeof *A);
    double *B = calloc(count, sizeof *B);
    double *C = calloc(count, sizeof *C);
    if (A == NULL || B == NULL || C == NULL) {
        print_info("Matrix allocation failed; DGEMM not run");
        free(A); free(B); free(C);
        check_result("DGEMM correct with various threads", 0);
        return;
    }
    for (size_t i = 0; i < count; i++) { A[i] = 1.0; B[i] = 1.0; }

    const int thread_counts[] = {1, 2, 4, 8};
    const size_t num_settings = sizeof thread_counts / sizeof thread_counts[0];
    int all_correct = 1;
    for (size_t t = 0; t < num_settings; t++) {
        bli_thread_set_num_threads(thread_counts[t]);
        memset(C, 0, count * sizeof *C);
        dgemm_("N", "N", &n, &n, &n, &alpha, A, &n, B, &n, &beta, C, &n);

        /* A sum of n ones is exactly representable, so == is safe here. */
        int correct = 1;
        for (size_t i = 0; i < count; i++) {
            if (C[i] != (double)n) { correct = 0; break; }
        }
        printf("DGEMM with %d threads: %s\n", thread_counts[t], correct ? "PASS" : "FAIL");
        if (!correct) all_correct = 0;
    }

    free(A); free(B); free(C);
    check_result("DGEMM correct with various threads", all_correct);
}
// =============================================================================
// TEST 13: Interleaved global and local settings
// =============================================================================
/*
 * TEST 13: alternate global sets, a local set and a reset on one thread,
 * recording the reported thread count after each step. The expected trace is
 * 4, 8, 12, 16, 16.
 */
void test_13_interleaved_settings(void) {
    print_separator("TEST 13: Interleaved Global and Local Settings");
    bli_thread_reset(); /* start from a clean thread-local state */

    static const int expected[5] = {4, 8, 12, 16, 16};
    int trace[5] = {0};

    bli_thread_set_num_threads(4);
    trace[0] = bli_thread_get_num_threads();
    bli_thread_set_num_threads(8);
    trace[1] = bli_thread_get_num_threads();
    bli_thread_set_num_threads_local(12);
    trace[2] = bli_thread_get_num_threads();
    bli_thread_set_num_threads(16);
    trace[3] = bli_thread_get_num_threads();
    bli_thread_reset();
    trace[4] = bli_thread_get_num_threads();

    printf("Sequence: %d->%d->%d->%d->%d (expected 4->8->12->16->16)\n",
           trace[0], trace[1], trace[2], trace[3], trace[4]);

    int matches = 1;
    for (int step = 0; step < 5; step++) {
        if (trace[step] != expected[step]) {
            matches = 0;
        }
    }
    check_result("Sequence correct", matches);
}
// =============================================================================
// TEST 14: Thread count persists across OMP regions
// =============================================================================
/*
 * TEST 14: a local setting made on the calling thread must still be in effect
 * after an unrelated (empty) parallel region has run.
 */
void test_14_persistence_across_regions(void) {
    print_separator("TEST 14: Thread Count Persists Across Regions");
    bli_thread_set_num_threads_local(42);

    #pragma omp parallel num_threads(2)
    {
        /* intentionally empty: only the region entry/exit matters */
    }

    const int after_region = bli_thread_get_num_threads();
    check_result("tl_rntm persists", after_region == 42);
    bli_thread_reset();
}
// =============================================================================
// TEST 15: Parallel DGEMM with per-thread settings
// =============================================================================
/*
 * TEST 15: two OpenMP threads each apply a different local BLIS thread count
 * (2 and 4) and run an independent 100x100 DGEMM of all-ones matrices into
 * their own output buffer. Every entry of each product must equal n.
 *
 * Allocation failures are reported instead of dereferenced, the whole result
 * matrix is verified, and a thread that never launched is not counted as a
 * wrong result.
 */
void test_15_parallel_dgemm_different_threads(void) {
    print_separator("TEST 15: Parallel DGEMM with Per-Thread Settings");
    int n = 100;
    double alpha = 1.0, beta = 0.0;
    const size_t count = (size_t)n * (size_t)n;

    double *A = calloc(count, sizeof *A);
    double *B = calloc(count, sizeof *B);
    double *C1 = calloc(count, sizeof *C1);
    double *C2 = calloc(count, sizeof *C2);
    if (A == NULL || B == NULL || C1 == NULL || C2 == NULL) {
        print_info("Matrix allocation failed; DGEMM not run");
        free(A); free(B); free(C1); free(C2);
        check_result("Parallel DGEMM correct", 0);
        return;
    }
    for (size_t i = 0; i < count; i++) { A[i] = 1.0; B[i] = 1.0; }

    bli_thread_set_num_threads(1);
    bli_thread_reset(); // Ensure clean state

    int results[2] = {0};
    int launched[2] = {0};
    #pragma omp parallel num_threads(2)
    {
        int tid = omp_get_thread_num();
        launched[tid] = 1;
        bli_thread_set_num_threads_local(tid == 0 ? 2 : 4);
        double *C = (tid == 0) ? C1 : C2;
        memset(C, 0, count * sizeof *C);
        dgemm_("N", "N", &n, &n, &n, &alpha, A, &n, B, &n, &beta, C, &n);

        /* A sum of n ones is exactly representable, so == is safe here. */
        int ok = 1;
        for (size_t i = 0; i < count; i++) {
            if (C[i] != (double)n) { ok = 0; break; }
        }
        results[tid] = ok;
        printf("Thread %d: BLIS=%d, C[0]=%f\n", tid, bli_thread_get_num_threads(), C[0]);
    }

    /* Thread 0 always runs; thread 1 is judged only if it was launched. */
    int all_ok = 1;
    for (int t = 0; t < 2; t++) {
        if (launched[t] && !results[t]) all_ok = 0;
    }

    free(A); free(B); free(C1); free(C2);
    check_result("Parallel DGEMM correct", all_ok);
}
// =============================================================================
// TEST 16: Thread pool reuse behavior (informational)
// =============================================================================
/*
 * TEST 16 (informational): give each member of a 4-thread team its own local
 * value, change the global value afterwards, and print what a second 4-thread
 * team reports. Nothing is asserted about the printed values; the test always
 * records a pass.
 */
void test_16_thread_pool_behavior(void) {
    print_separator("TEST 16: Thread Pool Reuse Behavior");
    bli_thread_set_num_threads(4);

    /* First region: every thread installs a distinct local value. */
    int warmup[MAX_THREADS] = {0};
    #pragma omp parallel num_threads(4)
    {
        const int me = omp_get_thread_num();
        bli_thread_set_num_threads_local(me + 1);
        warmup[me] = bli_thread_get_num_threads();
    }
    (void)warmup; /* recorded but intentionally not inspected */

    print_info("OMP may reuse threads - they keep existing tl_rntm");
    bli_thread_set_num_threads(32);
    printf("Set global to 32 AFTER thread pool created\n");

    /* Second region: observe what each thread reports now. */
    int later[MAX_THREADS] = {0};
    #pragma omp parallel num_threads(4)
    {
        const int me = omp_get_thread_num();
        later[me] = bli_thread_get_num_threads();
    }

    printf("Thread values: %d, %d, %d, %d\n",
           later[0], later[1], later[2], later[3]);
    print_info("Reused threads may NOT see 32 - this is expected");
    print_info("Solution: call bli_thread_reset() in threads to sync");
    check_result("Test documents pool behavior", 1);
}
// =============================================================================
// TEST 17: Use reset() to sync threads with global
// =============================================================================
/*
 * TEST 17: after the global thread count is set to 64, every thread of a
 * 4-thread team calls bli_thread_reset() and must then report 64.
 *
 * Only threads that actually ran are verified. OpenMP may deliver fewer than
 * four threads, and unlaunched slots keep their zero initializer.
 */
void test_17_reset_to_sync_global(void) {
    print_separator("TEST 17: Use reset() to Sync Threads with Global");
    if (omp_get_max_active_levels() < 2) {
        print_info("OMP_MAX_ACTIVE_LEVELS < 2: Skipping (requires nested parallelism)");
        return;
    }

    bli_thread_set_num_threads(64);
    printf("Set global to 64\n");

    int values[MAX_THREADS] = {0};
    int num_launched = 0;
    #pragma omp parallel num_threads(4)
    {
        int tid = omp_get_thread_num();
        /* Record the real team size, as the other multi-thread tests do. */
        #pragma omp single
        num_launched = omp_get_num_threads();
        bli_thread_reset();
        values[tid] = bli_thread_get_num_threads();
    }

    printf("After reset(): %d, %d, %d, %d\n", values[0], values[1], values[2], values[3]);
    if (num_launched < 4) {
        print_info("Fewer threads launched than requested; checking launched threads only");
    }

    int all_64 = (num_launched > 0);
    for (int i = 0; i < num_launched; i++) {
        if (values[i] != 64) all_64 = 0;
    }
    check_result("All threads sync to 64 after reset()", all_64);
}
// =============================================================================
// Ways / num_threads interaction tests (TEST 18-23); main() follows TEST 23
// =============================================================================
// TEST 18: Interaction between BLIS_*_NT env vars and bli_thread_set_num_threads()
// =============================================================================
/*
 * TEST 18: bli_thread_set_num_threads() must override an earlier ways
 * configuration.
 *
 * The ways are installed with bli_thread_set_ways() to stand in for
 * BLIS_JC_NT / BLIS_IC_NT style environment settings. After
 * bli_thread_set_num_threads(8) the reported count must be 8 and the jc, ic
 * and jr ways must all read back as -1 (unset).
 */
void test_18_env_ways_vs_set_num_threads(void) {
    print_separator("TEST 18: BLIS_*_NT env vars vs bli_thread_set_num_threads()");

    struct { int nt, jc, ic, jr; } before, after;

    /* Phase 1: explicit ways, 2*1*4*2*1 = 16 threads in total. */
    bli_thread_set_ways(2, 1, 4, 2, 1);
    before.nt = bli_thread_get_num_threads();
    before.jc = bli_thread_get_jc_nt();
    before.ic = bli_thread_get_ic_nt();
    before.jr = bli_thread_get_jr_nt();

    printf("After bli_thread_set_ways(2,1,4,2,1):\n");
    printf(" num_threads=%d (derived from ways: 2*1*4*2*1=16)\n", before.nt);
    printf(" jc_nt=%d, ic_nt=%d, jr_nt=%d\n", before.jc, before.ic, before.jr);
    check_result("Initial ways set correctly (jc=2)", before.jc == 2);
    check_result("Initial ways set correctly (ic=4)", before.ic == 4);
    check_result("Initial num_threads = 16", before.nt == 16);

    /* Phase 2: a plain thread count replaces the ways. */
    bli_thread_set_num_threads(8);
    after.nt = bli_thread_get_num_threads();
    after.jc = bli_thread_get_jc_nt();
    after.ic = bli_thread_get_ic_nt();
    after.jr = bli_thread_get_jr_nt();

    printf("\nAfter bli_thread_set_num_threads(8):\n");
    printf(" bli_thread_get_num_threads() = %d\n", after.nt);
    printf(" jc_nt=%d, ic_nt=%d, jr_nt=%d\n", after.jc, after.ic, after.jr);
    check_result("num_threads changed to 8", after.nt == 8);
    check_result("jc_nt cleared to -1", after.jc == -1);
    check_result("ic_nt cleared to -1", after.ic == -1);
    check_result("jr_nt cleared to -1", after.jr == -1);
    printf("\nVerified: bli_thread_set_num_threads() correctly overrides ways\n");

    /* Leave a single-threaded configuration behind. */
    bli_thread_reset();
    bli_thread_set_num_threads(1);
}
// =============================================================================
// TEST 19: set_ways then set_num_threads then reset
// =============================================================================
/*
 * TEST 19: set_ways -> set_num_threads -> reset on one thread.
 *
 * Because set_num_threads() also updates the global state, the reset in the
 * last step is expected to leave nt=8 and jc=-1 in place.
 */
void test_19_ways_then_set_nt_then_reset(void) {
    print_separator("TEST 19: set_ways -> set_num_threads -> reset");
    int nt, jc;

    /* Step 1: explicit ways (16 threads). */
    bli_thread_set_ways(2, 1, 4, 2, 1);
    nt = bli_thread_get_num_threads();
    jc = bli_thread_get_jc_nt();
    printf("After set_ways(2,1,4,2,1): nt=%d, jc=%d\n", nt, jc);
    check_result("Ways give 16 threads", nt == 16);
    check_result("jc=2", jc == 2);

    /* Step 2: a plain thread count overrides the ways. */
    bli_thread_set_num_threads(8);
    nt = bli_thread_get_num_threads();
    jc = bli_thread_get_jc_nt();
    printf("After set_num_threads(8): nt=%d, jc=%d\n", nt, jc);
    check_result("num_threads = 8", nt == 8);
    check_result("jc cleared to -1", jc == -1);

    /* Step 3: reset pulls from the global state, which step 2 updated. */
    bli_thread_reset();
    nt = bli_thread_get_num_threads();
    jc = bli_thread_get_jc_nt();
    printf("After reset(): nt=%d, jc=%d\n", nt, jc);
    check_result("After reset, num_threads = 8 (from global)", nt == 8);
    check_result("After reset, jc still -1 (global was cleared)", jc == -1);

    /* Leave a single-threaded configuration behind. */
    bli_thread_set_num_threads(1);
}
// =============================================================================
// TEST 20: set_ways then set_num_threads_local then reset
// =============================================================================
/*
 * TEST 20: set_ways -> set_num_threads_local -> reset on one thread.
 *
 * The local override must not touch the global state, so the reset in the
 * last step is expected to bring back the ways from step 1 (nt=16, jc=2).
 */
void test_20_ways_then_local_then_reset(void) {
    print_separator("TEST 20: set_ways -> set_num_threads_local -> reset");
    int nt, jc;

    /* Step 1: explicit ways (16 threads). */
    bli_thread_set_ways(2, 1, 4, 2, 1);
    nt = bli_thread_get_num_threads();
    printf("After set_ways(2,1,4,2,1): nt=%d\n", nt);
    check_result("Ways give 16 threads", nt == 16);

    /* Step 2: thread-local override. */
    bli_thread_set_num_threads_local(8);
    nt = bli_thread_get_num_threads();
    jc = bli_thread_get_jc_nt();
    printf("After set_num_threads_local(8): nt=%d, jc=%d\n", nt, jc);
    check_result("Local num_threads = 8", nt == 8);
    check_result("Local jc cleared to -1", jc == -1);

    /* Step 3: reset discards the local override and restores the ways. */
    bli_thread_reset();
    nt = bli_thread_get_num_threads();
    jc = bli_thread_get_jc_nt();
    printf("After reset(): nt=%d, jc=%d\n", nt, jc);
    check_result("After reset, nt is NOT 8 (local cleared)", nt != 8);
    check_result("After reset, nt=16 (ways restored from global)", nt == 16);
    check_result("After reset, jc=2 (ways restored from global)", jc == 2);

    /* Leave a single-threaded configuration behind. */
    bli_thread_set_num_threads(1);
}
// =============================================================================
// TEST 21: set_num_threads then set_ways then set_num_threads
// =============================================================================
/*
 * TEST 21: set_num_threads -> set_ways -> set_num_threads round trip.
 * Each call must replace what the previous one installed: jc reads -1 after a
 * plain thread count and 2 after the explicit ways.
 */
void test_21_nt_ways_nt_roundtrip(void) {
    print_separator("TEST 21: set_num_threads -> set_ways -> set_num_threads");
    int nt, jc;

    /* Step 1: plain thread count. */
    bli_thread_set_num_threads(8);
    nt = bli_thread_get_num_threads();
    check_result("Step 1: nt=8", nt == 8);
    jc = bli_thread_get_jc_nt();
    check_result("Step 1: jc=-1 (auto factor)", jc == -1);

    /* Step 2: explicit ways, 2*1*2*2*1 = 8 threads. */
    bli_thread_set_ways(2, 1, 2, 2, 1);
    nt = bli_thread_get_num_threads();
    jc = bli_thread_get_jc_nt();
    printf("After set_ways(2,1,2,2,1): nt=%d, jc=%d\n", nt, jc);
    check_result("Step 2: nt=8 (from ways)", nt == 8);
    check_result("Step 2: jc=2", jc == 2);

    /* Step 3: a plain thread count again; the ways must be cleared. */
    bli_thread_set_num_threads(4);
    nt = bli_thread_get_num_threads();
    jc = bli_thread_get_jc_nt();
    printf("After set_num_threads(4): nt=%d, jc=%d\n", nt, jc);
    check_result("Step 3: nt=4", nt == 4);
    check_result("Step 3: jc cleared", jc == -1);

    /* Leave a single-threaded configuration behind. */
    bli_thread_set_num_threads(1);
}
/*
 * TEST 22: set_num_threads followed by two set_ways calls.
 *
 * The second set_ways uses a different product (6) than the earlier thread
 * count (8), so a stale num_threads would show up as nt == 8 in step 3.
 */
void test_22_set_nt_then_set_ways(void) {
    print_separator("TEST 22: set_num_threads then set_ways overrides nt");
    int nt, jc, ic;

    /* Step 1: plain thread count of 8. */
    bli_thread_set_num_threads(8);
    nt = bli_thread_get_num_threads();
    printf("After set_num_threads(8): nt=%d\n", nt);
    check_result("Step 1: nt=8", nt == 8);

    /* Step 2: ways with the same total, 4*1*2*1*1 = 8. */
    bli_thread_set_ways(4, 1, 2, 1, 1);
    nt = bli_thread_get_num_threads();
    jc = bli_thread_get_jc_nt();
    ic = bli_thread_get_ic_nt();
    printf("After set_ways(4,1,2,1,1): nt=%d, jc=%d, ic=%d\n",
           nt, jc, ic);
    check_result("Step 2: nt=8 (from 4*1*2*1*1)", nt == 8);
    check_result("Step 2: jc=4", jc == 4);
    check_result("Step 2: ic=2", ic == 2);

    /* Step 3: ways with a different total, 2*1*3*1*1 = 6. */
    bli_thread_set_ways(2, 1, 3, 1, 1);
    nt = bli_thread_get_num_threads();
    jc = bli_thread_get_jc_nt();
    ic = bli_thread_get_ic_nt();
    printf("After set_ways(2,1,3,1,1): nt=%d, jc=%d, ic=%d\n",
           nt, jc, ic);
    check_result("Step 3: nt=6 (from 2*1*3*1*1, not stale 8)", nt == 6);
    check_result("Step 3: jc=2", jc == 2);
    check_result("Step 3: ic=3", ic == 3);

    /* Leave a single-threaded configuration behind. */
    bli_thread_set_num_threads(1);
}
/*
 * TEST 23: after set_num_threads(12) followed by set_ways(3,1,2,1,1), the
 * values read inside a parallel region must reflect the ways (nt=6, jc=3,
 * ic=2) and not the earlier thread count.
 *
 * NOTE(review): the region is opened with num_threads(1). In OpenMP the
 * encountering thread becomes thread 0 of the new team, so a one-thread team
 * is the calling thread itself and no new thread is created. The "child"
 * values are therefore read from the same thread that called set_ways().
 * Confirm that propagation to a freshly created thread is covered elsewhere
 * (for example by a pthread_create-based test).
 */
void test_23_set_ways_propagates_to_new_threads(void) {
    print_separator("TEST 23: set_ways propagates to new OMP threads via global_rntm");
    int nt;

    /* Step 1: plain thread count of 12. */
    bli_thread_set_num_threads(12);
    nt = bli_thread_get_num_threads();
    printf("After set_num_threads(12): nt=%d\n", nt);
    check_result("Step 1: nt=12", nt == 12);

    /* Step 2: explicit ways, 3*1*2*1*1 = 6. */
    bli_thread_set_ways(3, 1, 2, 1, 1);
    nt = bli_thread_get_num_threads();
    printf("After set_ways(3,1,2,1,1): nt=%d\n", nt);
    check_result("Step 2: nt=6 (from 3*1*2*1*1)", nt == 6);

    /* Step 3: read the settings from inside a one-thread parallel region. */
    struct { int nt, jc, ic; } child = { -1, -1, -1 };
    #pragma omp parallel num_threads(1)
    {
        child.nt = bli_thread_get_num_threads();
        child.jc = bli_thread_get_jc_nt();
        child.ic = bli_thread_get_ic_nt();
    }
    printf("Child thread: nt=%d, jc=%d, ic=%d\n",
           child.nt, child.jc, child.ic);
    check_result("Step 3: child nt=6 (from ways, not stale 12)", child.nt == 6);
    check_result("Step 3: child jc=3", child.jc == 3);
    check_result("Step 3: child ic=2", child.ic == 2);

    /* Leave a single-threaded configuration behind. */
    bli_thread_set_num_threads(1);
}
// =============================================================================
// Main
// =============================================================================
int main(int argc, char** argv) {
printf("BLIS Thread Control API Test Suite\n");
printf("===================================\n");
printf("OMP_MAX_ACTIVE_LEVELS=%d, omp_get_max_threads()=%d\n",
omp_get_max_active_levels(), omp_get_max_threads());
if (argc == 1) {
test_1_env_inheritance(); test_2_global_propagation();
test_3_local_only_affects_caller(); test_4_local_precedence();
test_5_per_thread_local(); test_6_reset_in_children();
test_7_nested_parallel(); test_8_edge_cases();
test_9_set_ways(); test_10_is_parallel();
test_11_concurrent_global_updates(); test_12_dgemm_with_threads();
test_13_interleaved_settings(); test_14_persistence_across_regions();
test_15_parallel_dgemm_different_threads(); test_16_thread_pool_behavior();
test_17_reset_to_sync_global();
test_18_env_ways_vs_set_num_threads();
test_19_ways_then_set_nt_then_reset();
test_20_ways_then_local_then_reset();
test_21_nt_ways_nt_roundtrip();
test_22_set_nt_then_set_ways();
test_23_set_ways_propagates_to_new_threads();
} else {
int test_num = atoi(argv[1]);
switch (test_num) {
case 0: /* fall through to run all */
test_1_env_inheritance(); test_2_global_propagation();
test_3_local_only_affects_caller(); test_4_local_precedence();
test_5_per_thread_local(); test_6_reset_in_children();
test_7_nested_parallel(); test_8_edge_cases();
test_9_set_ways(); test_10_is_parallel();
test_11_concurrent_global_updates(); test_12_dgemm_with_threads();
test_13_interleaved_settings(); test_14_persistence_across_regions();
test_15_parallel_dgemm_different_threads(); test_16_thread_pool_behavior();
test_17_reset_to_sync_global();
test_18_env_ways_vs_set_num_threads();
test_19_ways_then_set_nt_then_reset();
test_20_ways_then_local_then_reset();
test_21_nt_ways_nt_roundtrip();
test_22_set_nt_then_set_ways();
test_23_set_ways_propagates_to_new_threads();
break;
case 1: test_1_env_inheritance(); break;
case 2: test_2_global_propagation(); break;
case 3: test_3_local_only_affects_caller(); break;
case 4: test_4_local_precedence(); break;
case 5: test_5_per_thread_local(); break;
case 6: test_6_reset_in_children(); break;
case 7: test_7_nested_parallel(); break;
case 8: test_8_edge_cases(); break;
case 9: test_9_set_ways(); break;
case 10: test_10_is_parallel(); break;
case 11: test_11_concurrent_global_updates(); break;
case 12: test_12_dgemm_with_threads(); break;
case 13: test_13_interleaved_settings(); break;
case 14: test_14_persistence_across_regions(); break;
case 15: test_15_parallel_dgemm_different_threads(); break;
case 16: test_16_thread_pool_behavior(); break;
case 17: test_17_reset_to_sync_global(); break;
case 18: test_18_env_ways_vs_set_num_threads(); break;
case 19: test_19_ways_then_set_nt_then_reset(); break;
case 20: test_20_ways_then_local_then_reset(); break;
case 21: test_21_nt_ways_nt_roundtrip(); break;
case 22: test_22_set_nt_then_set_ways(); break;
case 23: test_23_set_ways_propagates_to_new_threads(); break;
default: printf("Invalid test number: %d\n", test_num); return 1;
}
}
printf("\n========================================\n");
printf(" SUMMARY\n");
printf("========================================\n");
printf("Passed: %d\n", tests_passed);
printf("Failed: %d\n", tests_failed);
printf("Total: %d\n", tests_passed + tests_failed);
return tests_failed > 0 ? 1 : 0;
}