mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Fixing Type Conversions in PassThroughPack8 (#2769)
* Change the return type of run_gemm_combinations in the basic tests * Change the return type of run_gemm_combinations in the universal tests * Add universal GEMM tests for bf16 x pk_i4 and fp16 x pk_i4 * Add universal GEMM test for fp8 x pk_i4 * Add basic GEMM tests for bf16 x pk_i4, fp16 x pk_i4 and fp8 x pk_i4. * Add missing GemmTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t> * Add missing GemmTypeConfig<ck_tile::bf16_t, ck_tile::pk_int4_t, ck_tile::bf16_t> * No need for utility in test_ck_tile_elementwise_1d * Fix conversion from pk_int4x4_t to bf16x8_t in PassThroughPack8 * Avoid union-based type punning in float_to_bf16_truc_raw to make it constexpr compliant * For consistency also make float_to_bf16_truc_nan_raw constexpr compliant by removing the union * Use a static_cast to bfloat16_t only when CK_TILE_USE_LLVM_BUILTIN_BF16 is enforced * Convert from float to bf16 during compilation rather than using magic values * Fix conversion from pk_int4x4_t to fp8x8_t in PassThroughPack8 * Comment out the basic test for fp16 x pk_i4 as it does not pass * Add missing GemmTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t> * Fix conversion from pk_int4x4_t to bf8x8_t in PassThroughPack8 * Add basic and universal GEMM tests for bf8 x pk_i4 * Switch back to amd_assembly_i4_to_fp8x8 in PassThroughPack8 as it works now * Switch back to amd_assembly_i4_to_bf8x8 in PassThroughPack8 as it works now * Remove the inefficient fallbacks for fp8 and bf8 in elementwise/unary_element_wise_operation.hpp * Use explicit macros for enabling and disabling the the constexpr lookup based converters * Fix two failing tests * Avoid union-based type punning in float_to_bf16_rtn_raw to make it constexpr compliant * Use float_to_bf16_rtn_raw instead of float_to_bf16 to create the bf16 lookup table for use in conversions from pk_int4 to bf16 * On ROCm 7.0.1 we need an explicit cast to from uint16_t to bf16_t
This commit is contained in:
@@ -1,6 +1,3 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_ck_tile_elementwise_1d test_elementwise_1d.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_ck_tile_elementwise_1d PRIVATE utility)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
Reference in New Issue
Block a user