diff --git a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp index 428034a3ba..8e68f6cc88 100644 --- a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp +++ b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp @@ -15,23 +15,24 @@ namespace device_batched_gemm_softmax_gemm { // defines the problem specification for a GEMM operation struct Problem { - std::size_t M = 0; - std::size_t N = 0; - std::size_t K = 0; - std::size_t O = 0; - bool TransA = false; - bool TransB = false; - bool TransB1 = false; - bool TransC = false; - DataType ADataType = DataType::Half; - DataType BDataType = DataType::Half; - DataType B1DataType = DataType::Half; - DataType CDataType = DataType::Half; - std::string AElementOp = PassThrough; - std::string BElementOp = PassThrough; - std::string B1ElementOp = PassThrough; - std::string CElementOp = PassThrough; - std::string AccElementOp = Scale; + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + std::size_t O = 0; + bool TransA = false; + bool TransB = false; + bool TransB1 = false; + bool TransC = false; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType B1DataType = DataType::Half; + DataType CDataType = DataType::Half; + std::string AElementOp = PassThrough; + std::string BElementOp = PassThrough; + std::string B1ElementOp = PassThrough; + std::string CElementOp = PassThrough; + std::string AccElementOp = Scale; + bool MaskOutUpperTriangle = false; // returns the correct device op file for the operation std::string GetIncludeHeader() const; diff --git a/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp index b12c2e1a4a..6029ab0c7d 100644 --- a/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp @@ -259,10 +259,7 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( x.tile_desc.gemm1_n_per_block); x.update_prologue(prologue); x.update_epilogue(epilogue); - x.mask_out_upper_triangle = true; - result.push_back(x); - - x.mask_out_upper_triangle = false; + x.mask_out_upper_triangle = prob.MaskOutUpperTriangle; result.push_back(x); } return result; @@ -273,13 +270,20 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( std::vector> Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue) { + std::vector problems; + Problem prob; prob.TransA = false; prob.TransB = true; prob.TransB1 = false; prob.TransC = false; + problems.push_back(prob); - return {CreateOperations(prob, prologue, epilogue)}; + prob.MaskOutUpperTriangle = true; + problems.push_back(prob); + + return Transform(problems, + [&](const Problem& p) { return CreateOperations(p, prologue, epilogue); }); } static const char* const DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate = diff --git a/codegen/test/batched_gemm_softmax_gemm.cpp b/codegen/test/batched_gemm_softmax_gemm.cpp index 3f0b8bfe6a..0de8dbdd51 100644 --- a/codegen/test/batched_gemm_softmax_gemm.cpp +++ b/codegen/test/batched_gemm_softmax_gemm.cpp @@ -42,7 +42,7 @@ TEST_CASE(test_problem_kernel) prob.K = 1024; prob.O = 1024; prob.TransB = true; - check_all check1, check2; + check_all check; auto a = to_gpu(generate_buffer(1024 * 1024, 0)); auto b = to_gpu(generate_buffer(1024 * 1024, 1)); auto b1 = to_gpu(generate_buffer(1024 * 1024, 2)); @@ -77,10 +77,8 @@ TEST_CASE(test_problem_kernel) k.launch(nullptr, grid_size * block_size, block_size)( a.data(), b.data(), b1.data(), c.data()); - if(solution.GetTemplateParameter("MaskOutUpperTriangle")) - CHECK(report(solution, check1(rtc::from_gpu(c)))); - else - CHECK(report(solution, check2(rtc::from_gpu(c)))); + // NOTE: Solutions where MaskOutUpperTriangle is True don't produce consistent results + CHECK(report(solution, check(rtc::from_gpu(c)))); } } diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index a8da88be09..262e6bae46 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -279,6 +279,7 @@ static kernel hiprtc_compile_kernel(const std::vector& srcs, compile_o { options.flags += " -I. -O3"; options.flags += " -std=c++17"; + options.flags += " -DCK_CODE_GEN_RTC"; options.flags += " --offload-arch=" + get_device_name(); auto cos = compile_hip_src_with_hiprtc(srcs, options); if(cos.size() != 1) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index 3c79b92ec8..ef0b5286ac 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -125,6 +125,7 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); +#ifndef CK_CODE_GEN_RTC virtual std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b, @@ -145,6 +146,7 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator virtual std::unique_ptr MakeInvokerPointer() = 0; virtual int GetPreShuffleParameters() = 0; +#endif }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index b4ab96d397..e846b0630b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -614,7 +614,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return true; } -#ifndef __HIPCC_RTC__ static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_) { @@ -705,6 +704,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return true; } +#ifndef __HIPCC_RTC__ static bool IsSupportedArgument(const Argument& arg) { if(!ck::is_xdl_supported()) diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index 42b784d303..0593a24bd3 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -6,6 +6,7 @@ #include "ck/ck.hpp" #include "ck/utility/enable_if.hpp" #include "ck/utility/random_gen.hpp" +#include "ck/utility/functional.hpp" #include "ck/utility/type.hpp" #ifdef CK_USE_FNUZ_FP8 @@ -193,10 +194,10 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x) } } - typename std::conditional< + typename ck::conditional_t< sizeof(T) == 2, unsigned short int, - typename std::conditional::type>::type + typename ck::conditional_t> retval; if constexpr(we == 5 && is_half && !is_fnuz) @@ -539,10 +540,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10); - using T_bitwise = typename std::conditional< + using T_bitwise = typename ck::conditional_t< sizeof(T) == 2, unsigned short int, - typename std::conditional::type>::type; + typename ck::conditional_t>; T_bitwise x_bitwise = bit_cast(_x); unsigned long long x{x_bitwise}; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 2e3b09eae9..a0d29e5a0f 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -19,7 +19,7 @@ using float_t = float; #endif // __HIPCC_RTC__ namespace ck { -#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) using byte = unsigned char; #else using std::byte; @@ -1805,7 +1805,7 @@ struct non_native_vector_base< // implementation for f6x16 and f6x32 template -struct non_native_vector_base> +struct non_native_vector_base> { using data_t = typename nnvb_data_t_selector::type; // select data_t based on declared base type diff --git a/include/ck/utility/mxf4_utils.hpp b/include/ck/utility/mxf4_utils.hpp index 15e693bd0d..757d3914e3 100644 --- a/include/ck/utility/mxf4_utils.hpp +++ b/include/ck/utility/mxf4_utils.hpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#ifndef CK_CODE_GEN_RTC #pragma once #include "ck/utility/data_type.hpp" @@ -41,7 +42,7 @@ template <> __host__ __device__ inline float to_float(e8m0_bexp_t const scale, f4_t const data) { if(is_nan(scale, data)) - return std::numeric_limits::quiet_NaN(); + return NumericLimits::QuietNaN(); if(is_zero(scale, data)) return 0.0f; @@ -105,5 +106,5 @@ __host__ __device__ inline f4_t sat_convert_to_type_sr(float value, uint32 return res; } - } // namespace ck::utils +#endif diff --git a/include/ck/utility/mxf6_utils.hpp b/include/ck/utility/mxf6_utils.hpp index e3b37bedda..00b4f8e5d4 100644 --- a/include/ck/utility/mxf6_utils.hpp +++ b/include/ck/utility/mxf6_utils.hpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#ifndef CK_CODE_GEN_RTC #pragma once #include "ck/utility/data_type.hpp" @@ -138,7 +139,7 @@ template <> __host__ __device__ inline float to_float(e8m0_bexp_t const scale, f6_t const data) { if(is_nan(scale, data)) - return std::numeric_limits::quiet_NaN(); + return NumericLimits::QuietNaN(); if(is_zero(scale, data)) return 0.0f; @@ -164,7 +165,7 @@ template <> __host__ __device__ inline float to_float(e8m0_bexp_t const scale, bf6_t const data) { if(is_nan(scale, data)) - return std::numeric_limits::quiet_NaN(); + return NumericLimits::QuietNaN(); if(is_zero(scale, data)) return 0.0f; @@ -307,7 +308,6 @@ __host__ __device__ inline bf6_t sat_convert_to_type_sr(float value, uint if(std::isnan(value)) return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; - if(std::abs(value) > NumericLimits::Max()) // covers inf case as well return sign ? NumericUtils::data_max_negative_normal_mask : NumericUtils::data_max_positive_normal_mask; @@ -321,5 +321,5 @@ __host__ __device__ inline bf6_t sat_convert_to_type_sr(float value, uint return res; } - } // namespace ck::utils +#endif diff --git a/include/ck/utility/mxfp_utils.hpp b/include/ck/utility/mxfp_utils.hpp index e23836c87f..947d64b705 100644 --- a/include/ck/utility/mxfp_utils.hpp +++ b/include/ck/utility/mxfp_utils.hpp @@ -3,6 +3,11 @@ #pragma once +#include "ck/utility/data_type.hpp" + +#ifdef CK_CODE_GEN_RTC +#define UINT_MAX 4294967295 +#endif namespace ck::utils { union cvt @@ -380,5 +385,4 @@ inline T convert_to_type_sr(float value, uint32_t seed) auto val = sign | biased_exp << NumericUtils::mant | mant; return val; } - } // namespace ck::utils diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index cf862ae640..69d1631ae3 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -706,7 +706,7 @@ inline __host__ __device__ half_t type_convert(bf8_fnuz_t x) return utils::cast_from_f8(x); #endif } - +#ifndef CK_CODE_GEN_RTC // convert fp32 to fp4 with rounding to nearest even inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f) { @@ -927,7 +927,11 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0 inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) { constexpr int seed = 1254739; - uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#endif #if defined(__gfx950__) union { @@ -952,7 +956,11 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) { constexpr int seed = 1254739; - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#endif #if defined(__gfx950__) union { @@ -978,7 +986,11 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f) { constexpr int seed = 1254739; - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#endif #if defined(__gfx950__) union { @@ -1544,7 +1556,11 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0 inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) { constexpr int seed = 1254739; - uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#endif #if defined(__gfx950__) union { @@ -1584,8 +1600,13 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f float32_t float_vector; float float_array[32]; } float_values{x}; +#ifndef CK_CODE_GEN_RTC uint32_t rng = prand_generator(reinterpret_cast(&x), float_values.float_array[0]); +#else + uint32_t rng = + prand_generator(reinterpret_cast(&x), float_values.float_array[0]); +#endif #if defined(__gfx950__) return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale); #else @@ -1803,7 +1824,11 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1 inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) { constexpr int seed = 1254739; - uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#endif #if defined(__gfx950__) union { @@ -1845,8 +1870,13 @@ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1. float32_t float_vector; float float_array[32]; } float_values{x}; +#ifndef CK_CODE_GEN_RTC uint32_t rng = prand_generator(reinterpret_cast(&x), float_values.float_array[0]); +#else + uint32_t rng = + prand_generator(reinterpret_cast(&x), float_values.float_array[0]); +#endif #if defined(__gfx950__) return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale); #else @@ -1978,7 +2008,7 @@ inline __host__ __device__ float32_t type_convert(bf6x32_t return out.float_vector; #endif } - +#endif #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) template inline __host__ __device__ void array_convert(std::array& y,