mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
MIGraphX hipRTC fix (#1923)
* fixed hiprtc compilation issues from new additions, removed clashing mixed precision functionality from codegen(ignore the whole file) * fixed device op error: misplaced header guard * restrict virtual function use in device_gemm_multiple_d file for codegen hiprtc compilation * add CK_CODE_GEN_RTC flag for compilation, since this flag has wider coverage for hiprtc compilation * fixed conditional error in amd_ck_fp8.hpp * Add MaskOutUpperTriangle as a problem parameter to BatchedGemmSoftmaxGemm and disable tests with MaskOutUpperTriangle==True. Signed-off-by: Mirza Halilcevic <mirza.halilcevic@amd.com> --------- Signed-off-by: Mirza Halilcevic <mirza.halilcevic@amd.com> Co-authored-by: Mirza Halilcevic <mirza.halilcevic@amd.com>
This commit is contained in:
@@ -15,23 +15,24 @@ namespace device_batched_gemm_softmax_gemm {
|
||||
// defines the problem specification for a GEMM operation
|
||||
struct Problem
|
||||
{
|
||||
std::size_t M = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t K = 0;
|
||||
std::size_t O = 0;
|
||||
bool TransA = false;
|
||||
bool TransB = false;
|
||||
bool TransB1 = false;
|
||||
bool TransC = false;
|
||||
DataType ADataType = DataType::Half;
|
||||
DataType BDataType = DataType::Half;
|
||||
DataType B1DataType = DataType::Half;
|
||||
DataType CDataType = DataType::Half;
|
||||
std::string AElementOp = PassThrough;
|
||||
std::string BElementOp = PassThrough;
|
||||
std::string B1ElementOp = PassThrough;
|
||||
std::string CElementOp = PassThrough;
|
||||
std::string AccElementOp = Scale;
|
||||
std::size_t M = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t K = 0;
|
||||
std::size_t O = 0;
|
||||
bool TransA = false;
|
||||
bool TransB = false;
|
||||
bool TransB1 = false;
|
||||
bool TransC = false;
|
||||
DataType ADataType = DataType::Half;
|
||||
DataType BDataType = DataType::Half;
|
||||
DataType B1DataType = DataType::Half;
|
||||
DataType CDataType = DataType::Half;
|
||||
std::string AElementOp = PassThrough;
|
||||
std::string BElementOp = PassThrough;
|
||||
std::string B1ElementOp = PassThrough;
|
||||
std::string CElementOp = PassThrough;
|
||||
std::string AccElementOp = Scale;
|
||||
bool MaskOutUpperTriangle = false;
|
||||
|
||||
// returns the correct device op file for the operation
|
||||
std::string GetIncludeHeader() const;
|
||||
|
||||
@@ -259,10 +259,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
x.tile_desc.gemm1_n_per_block);
|
||||
x.update_prologue(prologue);
|
||||
x.update_epilogue(epilogue);
|
||||
x.mask_out_upper_triangle = true;
|
||||
result.push_back(x);
|
||||
|
||||
x.mask_out_upper_triangle = false;
|
||||
x.mask_out_upper_triangle = prob.MaskOutUpperTriangle;
|
||||
result.push_back(x);
|
||||
}
|
||||
return result;
|
||||
@@ -273,13 +270,20 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
std::vector<std::vector<Operation_Xdl_CShuffle>>
|
||||
Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue)
|
||||
{
|
||||
std::vector<Problem> problems;
|
||||
|
||||
Problem prob;
|
||||
prob.TransA = false;
|
||||
prob.TransB = true;
|
||||
prob.TransB1 = false;
|
||||
prob.TransC = false;
|
||||
problems.push_back(prob);
|
||||
|
||||
return {CreateOperations(prob, prologue, epilogue)};
|
||||
prob.MaskOutUpperTriangle = true;
|
||||
problems.push_back(prob);
|
||||
|
||||
return Transform(problems,
|
||||
[&](const Problem& p) { return CreateOperations(p, prologue, epilogue); });
|
||||
}
|
||||
|
||||
static const char* const DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate =
|
||||
|
||||
@@ -42,7 +42,7 @@ TEST_CASE(test_problem_kernel)
|
||||
prob.K = 1024;
|
||||
prob.O = 1024;
|
||||
prob.TransB = true;
|
||||
check_all<half> check1, check2;
|
||||
check_all<half> check;
|
||||
auto a = to_gpu(generate_buffer<half>(1024 * 1024, 0));
|
||||
auto b = to_gpu(generate_buffer<half>(1024 * 1024, 1));
|
||||
auto b1 = to_gpu(generate_buffer<half>(1024 * 1024, 2));
|
||||
@@ -77,10 +77,8 @@ TEST_CASE(test_problem_kernel)
|
||||
k.launch(nullptr, grid_size * block_size, block_size)(
|
||||
a.data(), b.data(), b1.data(), c.data());
|
||||
|
||||
if(solution.GetTemplateParameter<bool>("MaskOutUpperTriangle"))
|
||||
CHECK(report(solution, check1(rtc::from_gpu(c))));
|
||||
else
|
||||
CHECK(report(solution, check2(rtc::from_gpu(c))));
|
||||
// NOTE: Solutions where MaskOutUpperTriangle is True don't produce consistent results
|
||||
CHECK(report(solution, check(rtc::from_gpu(c))));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -279,6 +279,7 @@ static kernel hiprtc_compile_kernel(const std::vector<src_file>& srcs, compile_o
|
||||
{
|
||||
options.flags += " -I. -O3";
|
||||
options.flags += " -std=c++17";
|
||||
options.flags += " -DCK_CODE_GEN_RTC";
|
||||
options.flags += " --offload-arch=" + get_device_name();
|
||||
auto cos = compile_hip_src_with_hiprtc(srcs, options);
|
||||
if(cos.size() != 1)
|
||||
|
||||
@@ -125,6 +125,7 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
@@ -145,6 +146,7 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual int GetPreShuffleParameters() = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -614,7 +614,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
static constexpr bool
|
||||
IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
|
||||
{
|
||||
@@ -705,6 +704,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/enable_if.hpp"
|
||||
#include "ck/utility/random_gen.hpp"
|
||||
#include "ck/utility/functional.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
|
||||
#ifdef CK_USE_FNUZ_FP8
|
||||
@@ -193,10 +194,10 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
|
||||
}
|
||||
}
|
||||
|
||||
typename std::conditional<
|
||||
typename ck::conditional_t<
|
||||
sizeof(T) == 2,
|
||||
unsigned short int,
|
||||
typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type
|
||||
typename ck::conditional_t<sizeof(T) == 4, unsigned int, unsigned long long>>
|
||||
retval;
|
||||
|
||||
if constexpr(we == 5 && is_half && !is_fnuz)
|
||||
@@ -539,10 +540,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
|
||||
|
||||
constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
|
||||
|
||||
using T_bitwise = typename std::conditional<
|
||||
using T_bitwise = typename ck::conditional_t<
|
||||
sizeof(T) == 2,
|
||||
unsigned short int,
|
||||
typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
|
||||
typename ck::conditional_t<sizeof(T) == 4, unsigned int, unsigned long long>>;
|
||||
T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
|
||||
|
||||
unsigned long long x{x_bitwise};
|
||||
|
||||
@@ -19,7 +19,7 @@ using float_t = float;
|
||||
#endif // __HIPCC_RTC__
|
||||
|
||||
namespace ck {
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)
|
||||
using byte = unsigned char;
|
||||
#else
|
||||
using std::byte;
|
||||
@@ -1805,7 +1805,7 @@ struct non_native_vector_base<
|
||||
|
||||
// implementation for f6x16 and f6x32
|
||||
template <typename T, index_t N>
|
||||
struct non_native_vector_base<T, N, std::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>
|
||||
struct non_native_vector_base<T, N, ck::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>
|
||||
{
|
||||
using data_t =
|
||||
typename nnvb_data_t_selector<T>::type; // select data_t based on declared base type
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
@@ -41,7 +42,7 @@ template <>
|
||||
__host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
|
||||
{
|
||||
if(is_nan<f4_t>(scale, data))
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
return NumericLimits<float>::QuietNaN();
|
||||
|
||||
if(is_zero<f4_t>(scale, data))
|
||||
return 0.0f;
|
||||
@@ -105,5 +106,5 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace ck::utils
|
||||
#endif
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
@@ -138,7 +139,7 @@ template <>
|
||||
__host__ __device__ inline float to_float<f6_t>(e8m0_bexp_t const scale, f6_t const data)
|
||||
{
|
||||
if(is_nan<f6_t>(scale, data))
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
return NumericLimits<float>::QuietNaN();
|
||||
|
||||
if(is_zero<f6_t>(scale, data))
|
||||
return 0.0f;
|
||||
@@ -164,7 +165,7 @@ template <>
|
||||
__host__ __device__ inline float to_float<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
|
||||
{
|
||||
if(is_nan<bf6_t>(scale, data))
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
return NumericLimits<float>::QuietNaN();
|
||||
|
||||
if(is_zero<bf6_t>(scale, data))
|
||||
return 0.0f;
|
||||
@@ -307,7 +308,6 @@ __host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint
|
||||
if(std::isnan(value))
|
||||
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
|
||||
|
||||
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
|
||||
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
|
||||
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
|
||||
@@ -321,5 +321,5 @@ __host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace ck::utils
|
||||
#endif
|
||||
|
||||
@@ -3,6 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#ifdef CK_CODE_GEN_RTC
|
||||
#define UINT_MAX 4294967295
|
||||
#endif
|
||||
namespace ck::utils {
|
||||
|
||||
union cvt
|
||||
@@ -380,5 +385,4 @@ inline T convert_to_type_sr(float value, uint32_t seed)
|
||||
auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
|
||||
return val;
|
||||
}
|
||||
|
||||
} // namespace ck::utils
|
||||
|
||||
@@ -706,7 +706,7 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
|
||||
return utils::cast_from_f8<bf8_fnuz_t, half_t, negative_zero_nan>(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
// convert fp32 to fp4 with rounding to nearest even
|
||||
inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
|
||||
{
|
||||
@@ -927,7 +927,11 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0
|
||||
inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
@@ -952,7 +956,11 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
|
||||
inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
@@ -978,7 +986,11 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
|
||||
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
@@ -1544,7 +1556,11 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0
|
||||
inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
@@ -1584,8 +1600,13 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f
|
||||
float32_t float_vector;
|
||||
float float_array[32];
|
||||
} float_values{x};
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng =
|
||||
prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
|
||||
#else
|
||||
uint32_t rng =
|
||||
prand_generator<float, seed>(reinterpret_cast<size_t>(&x), float_values.float_array[0]);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale);
|
||||
#else
|
||||
@@ -1803,7 +1824,11 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1
|
||||
inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#else
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
@@ -1845,8 +1870,13 @@ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.
|
||||
float32_t float_vector;
|
||||
float float_array[32];
|
||||
} float_values{x};
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
uint32_t rng =
|
||||
prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
|
||||
#else
|
||||
uint32_t rng =
|
||||
prand_generator<float, seed>(reinterpret_cast<size_t>(&x), float_values.float_array[0]);
|
||||
#endif
|
||||
#if defined(__gfx950__)
|
||||
return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale);
|
||||
#else
|
||||
@@ -1978,7 +2008,7 @@ inline __host__ __device__ float32_t type_convert<float32_t, bf6x32_t>(bf6x32_t
|
||||
return out.float_vector;
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
template <typename Y, typename X, size_t NumElems>
|
||||
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
|
||||
|
||||
Reference in New Issue
Block a user