Upgrade to ROCm7.0.1 compiler. (#2909)

* upgrade default docker to rocm7.0.1

* turn on build and test on gfx950 by default

* use rocm-dev instead of rocm

* link libhiprtc for codegen targets

* resolving codegen compilation errors: removed calls to other std functions, resolved issues with int32_t: needed the correct header, put use of e8m0 into header guards

---------

Co-authored-by: Astha Rai <astha.rai713@gmail.com>
This commit is contained in:
Illia Silin
2025-09-24 10:00:53 -07:00
committed by GitHub
parent fe0a47a011
commit 8fe3838c65
15 changed files with 50 additions and 47 deletions

View File

@@ -2,7 +2,7 @@
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
#include "ck/utility/amd_ck_fp8.hpp"
#include "ck/utility/e8m0.hpp"
#include "ck/utility/statically_indexed_array.hpp"
@@ -325,12 +325,14 @@ struct scalar_type<bf8_ocp_t>
static constexpr index_t vector_size = 1;
};
#ifndef CK_CODE_GEN_RTC
template <>
struct scalar_type<e8m0_bexp_t>
{
using type = e8m0_bexp_t::type;
static constexpr index_t vector_size = 1;
};
#endif
template <>
struct scalar_type<f4x2_pk_t>
@@ -483,8 +485,10 @@ inline const char* get_type_name()
return "f8";
else if constexpr(is_same_v<T, bf8_t>)
return "bf8";
#ifndef CK_CODE_GEN_RTC
else if constexpr(is_same_v<T, e8m0_bexp_t>)
return "e8m0";
#endif
else if constexpr(is_same_v<T, float>)
return "fp32";
#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)

View File

@@ -13,7 +13,7 @@ template <typename T, typename Enable = void>
struct PrintAsType;
template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type>
struct PrintAsType<T, typename enable_if<is_floating_point<T>::value>::type>
{
using type = float;
__host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast<type>(p)); }
@@ -30,7 +30,7 @@ struct PrintAsType<ck::half_t, void>
};
template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type>
struct PrintAsType<T, typename enable_if<is_integral<T>::value>::type>
{
using type = int;
__host__ __device__ static void Print(const T& p) { printf("%d ", static_cast<type>(p)); }

View File

@@ -1294,6 +1294,7 @@ struct nnvb_data_t_selector<bf8_ocp_t>
using type = bf8_ocp_t::data_type;
};
#ifndef CK_CODE_GEN_RTC
template <>
struct nnvb_data_t_selector<f8_fnuz_t>
{
@@ -1311,6 +1312,7 @@ struct nnvb_data_t_selector<e8m0_bexp_t>
{
using type = e8m0_bexp_t::type;
};
#endif
template <>
struct nnvb_data_t_selector<f6x16_pk_t>
@@ -2270,8 +2272,10 @@ using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
using bf6x16x2_t = typename vector_type<bf6x16_pk_t, 2>::type;
using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;
#ifndef CK_CODE_GEN_RTC
// e8m0
using e8m0x4_bexp_t = typename vector_type<e8m0_bexp_t, 4>::type;
#endif
// pack int4
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;

View File

@@ -3,6 +3,7 @@
#pragma once
#ifndef CK_CODE_GEN_RTC
#include "ck/utility/type.hpp"
namespace ck {
@@ -78,3 +79,4 @@ __host__ __device__ inline constexpr int32_t get_exponent_value<e8m0_bexp_t>(e8m
} // namespace utils
} // namespace ck
#endif

View File

@@ -273,8 +273,8 @@ template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ Y cast_to_f8(X x, uint32_t rng)
{
// check datatypes
constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<X, float>::value;
constexpr bool is_half = is_same<X, half_t>::value;
constexpr bool is_float = is_same<X, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted.");
return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
@@ -284,8 +284,8 @@ template <typename X, typename Y, bool negative_zero_nan>
__host__ __device__ Y cast_from_f8(X x)
{
// check datatype
constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<Y, float>::value;
constexpr bool is_half = is_same<Y, half_t>::value;
constexpr bool is_float = is_same<Y, float>::value;
static_assert(is_half || is_float, "only half and float are supported.");
return run_cast_from_f8<X, Y, negative_zero_nan>(x);

View File

@@ -10,10 +10,6 @@
#include "type.hpp"
#include "tuple.hpp"
#ifdef CK_CODE_GEN_RTC
#define INT32_MAX 2147483647
#endif
namespace ck {
// magic number division

View File

@@ -522,8 +522,6 @@ struct NumericLimits<bf6_t>
}
};
#endif
template <>
struct NumericLimits<e8m0_bexp_t>
{
@@ -551,5 +549,6 @@ struct NumericLimits<e8m0_bexp_t>
return e8m0_bexp_t(binary_142);
}
};
#endif
} // namespace ck

View File

@@ -10,6 +10,7 @@ struct NumericUtils
{
};
#ifndef CK_CODE_GEN_RTC
template <>
struct NumericUtils<e8m0_bexp_t>
{
@@ -24,6 +25,7 @@ struct NumericUtils<e8m0_bexp_t>
using bitwise_type = uint8_t;
};
#endif
template <>
struct NumericUtils<float>

View File

@@ -15,7 +15,7 @@ namespace ck {
// Pseudo random number generator
// version for fp32
template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<float, T>{}, bool> = false>
template <typename T, uint32_t seed_t, ck::enable_if_t<is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
uint32_t x = bit_cast<uint32_t>(val);
@@ -31,7 +31,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
// version for fp16
template <typename T, uint32_t seed_t, ck::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
template <typename T, uint32_t seed_t, ck::enable_if_t<is_same<_Float16, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
uint16_t x = bit_cast<uint16_t>(val);
@@ -48,7 +48,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
// return 0 if data is not fp16 or fp32
template <typename T,
uint32_t seed_t,
ck::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
ck::enable_if_t<!(is_same<float, T>{} || is_same<_Float16, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{
ck::ignore = id;