[CK_TILE] Fixing Type Conversions in PassThroughPack8 (#2769)

* Change the return type of run_gemm_combinations in the basic tests

* Change the return type of run_gemm_combinations in the universal tests

* Add universal GEMM tests for bf16 x pk_i4 and fp16 x pk_i4

* Add universal GEMM test for fp8 x pk_i4

* Add basic GEMM tests for bf16 x pk_i4, fp16 x pk_i4 and fp8 x pk_i4.

* Add missing GemmTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t>

* Add missing GemmTypeConfig<ck_tile::bf16_t, ck_tile::pk_int4_t, ck_tile::bf16_t>

* No need for utility in test_ck_tile_elementwise_1d

* Fix conversion from pk_int4x4_t to bf16x8_t in PassThroughPack8

* Avoid union-based type punning in float_to_bf16_truc_raw to make it constexpr compliant

* For consistency also make float_to_bf16_truc_nan_raw constexpr compliant by removing the union

* Use a static_cast to bfloat16_t only when CK_TILE_USE_LLVM_BUILTIN_BF16 is enforced

* Convert from float to bf16 during compilation rather than using magic values

* Fix conversion from pk_int4x4_t to fp8x8_t in PassThroughPack8

* Comment out the basic test for fp16 x pk_i4 as it does not pass

* Add missing GemmTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t>

* Fix conversion from pk_int4x4_t to bf8x8_t in PassThroughPack8

* Add basic and universal GEMM tests for bf8 x pk_i4

* Switch back to amd_assembly_i4_to_fp8x8 in PassThroughPack8 as it works now

* Switch back to amd_assembly_i4_to_bf8x8 in PassThroughPack8 as it works now

* Remove the inefficient fallbacks for fp8 and bf8 in elementwise/unary_element_wise_operation.hpp

* Use explicit macros for enabling and disabling the the constexpr lookup based converters

* Fix two failing tests

* Avoid union-based type punning in float_to_bf16_rtn_raw to make it constexpr compliant

* Use float_to_bf16_rtn_raw instead of float_to_bf16 to create the bf16 lookup table for use in conversions from pk_int4 to bf16

* On ROCm 7.0.1 we need an explicit cast to from uint16_t to bf16_t
This commit is contained in:
SamiAario-AMD
2025-09-29 13:34:47 +03:00
committed by GitHub
parent e8842e3c1f
commit 0f10e6d921
16 changed files with 198 additions and 55 deletions

View File

@@ -1,6 +1,3 @@
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_ck_tile_elementwise_1d test_elementwise_1d.cpp)
if(result EQUAL 0)
target_link_libraries(test_ck_tile_elementwise_1d PRIVATE utility)
endif()
endif()
endif()

View File

@@ -2,4 +2,11 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_gemm_pipeline_basic_run_test.inc"
int main() { return run_gemm_combinations<ck_tile::bf16_t>(); }
int main()
{
bool is_success = true;
is_success = run_gemm_combinations<ck_tile::bf16_t>() && is_success;
is_success =
run_gemm_combinations<ck_tile::bf16_t, ck_tile::pk_int4_t, ck_tile::bf16_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}

View File

@@ -2,4 +2,12 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_gemm_pipeline_basic_run_test.inc"
int main() { return run_gemm_combinations<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(); }
int main()
{
bool is_success = true;
is_success =
run_gemm_combinations<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>() && is_success;
is_success =
run_gemm_combinations<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}

View File

@@ -2,4 +2,13 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_gemm_pipeline_basic_run_test.inc"
int main() { return run_gemm_combinations<ck_tile::half_t>(); }
int main()
{
bool is_success = true;
is_success = run_gemm_combinations<ck_tile::half_t>() && is_success;
#if 0
is_success =
run_gemm_combinations<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>() && is_success;
#endif
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}

View File

@@ -2,4 +2,12 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_gemm_pipeline_basic_run_test.inc"
int main() { return run_gemm_combinations<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(); }
int main()
{
bool is_success = true;
is_success =
run_gemm_combinations<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>() && is_success;
is_success =
run_gemm_combinations<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}

View File

@@ -225,7 +225,7 @@ bool run_gemm_test(int argc, char* argv[])
}
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
int run_gemm_combinations()
bool run_gemm_combinations()
{
// Define possible values for each parameter
std::vector<std::string> m_values = {"128", "1024"};
@@ -304,5 +304,5 @@ int run_gemm_combinations()
}
}
}
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
return is_success;
}

View File

@@ -263,6 +263,15 @@ struct GemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmTypeConfig<ck_tile::bf16_t, ck_tile::pk_int4_t, ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
{
@@ -281,6 +290,15 @@ struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
@@ -290,6 +308,15 @@ struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
{

View File

@@ -6,4 +6,11 @@
#include "test_gemm_pipeline_smoke_run_test.inc"
#include "test_gemm_pipeline_universal_run_test.inc"
int main() { return run_gemm_combinations<ck_tile::bf16_t>(); }
int main()
{
bool is_success = true;
is_success = run_gemm_combinations<ck_tile::bf16_t>() && is_success;
is_success =
run_gemm_combinations<ck_tile::bf16_t, ck_tile::pk_int4_t, ck_tile::bf16_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}

View File

@@ -6,4 +6,12 @@
#include "test_gemm_pipeline_smoke_run_test.inc"
#include "test_gemm_pipeline_universal_run_test.inc"
int main() { return run_gemm_combinations<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(); }
int main()
{
bool is_success = true;
is_success =
run_gemm_combinations<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>() && is_success;
is_success =
run_gemm_combinations<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}

View File

@@ -6,4 +6,11 @@
#include "test_gemm_pipeline_smoke_run_test.inc"
#include "test_gemm_pipeline_universal_run_test.inc"
int main() { return run_gemm_combinations<ck_tile::half_t>(); }
int main()
{
bool is_success = true;
is_success = run_gemm_combinations<ck_tile::half_t>() && is_success;
is_success =
run_gemm_combinations<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}

View File

@@ -6,4 +6,12 @@
#include "test_gemm_pipeline_smoke_run_test.inc"
#include "test_gemm_pipeline_universal_run_test.inc"
int main() { return run_gemm_combinations<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(); }
int main()
{
bool is_success = true;
is_success =
run_gemm_combinations<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>() && is_success;
is_success =
run_gemm_combinations<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}

View File

@@ -1,16 +1,15 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstddef>
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <string>
#include "ck_tile/host.hpp"
#include "test_gemm_pipeline_smoke_util.hpp"
#include "test_gemm_pipeline_smoke_run_test.inc"
#include "test_gemm_pipeline_universal_run_test.inc"
int main() { return run_gemm_combinations<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t>(); }
int main()
{
bool is_success = true;
is_success =
run_gemm_combinations<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}

View File

@@ -1,16 +1,15 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstddef>
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <string>
#include "ck_tile/host.hpp"
#include "test_gemm_pipeline_smoke_util.hpp"
#include "test_gemm_pipeline_smoke_run_test.inc"
#include "test_gemm_pipeline_universal_run_test.inc"
int main() { return run_gemm_combinations<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(); }
int main()
{
bool is_success = true;
is_success =
run_gemm_combinations<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>() && is_success;
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}

View File

@@ -357,5 +357,5 @@ int run_gemm_combinations()
}
}
}
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
return is_success;
}