mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
add int8 gemm multiply multiply a8w8 (#1591)
* add int8 gemm multiply multiply a8w8
* uncomment
* clang-format-12
* Add example_gemm_multiply_multiply_xdl_int8
* Remove shell scripts
* update preprocess number for mi308; bring back printout in ckprofiler
* format
---------
Co-authored-by: chenjun <junchen2@amd.com>
Co-authored-by: Haocong WANG <haocwang@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>
[ROCm/composable_kernel commit: 37f7afed1e]
This commit is contained in:
@@ -237,7 +237,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
Args... args)
|
||||
{
|
||||
#if CK_TIME_KERNEL
|
||||
#define MEDIAN 1
|
||||
#define MEDIAN 0
|
||||
if(stream_config.time_kernel_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
@@ -275,6 +275,14 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
#else
|
||||
float total_time = 0;
|
||||
#endif
|
||||
hipEvent_t start, stop;
|
||||
|
||||
hip_check_error(hipEventCreate(&start));
|
||||
hip_check_error(hipEventCreate(&stop));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
|
||||
|
||||
for(int i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
if constexpr(!TimePreprocess)
|
||||
@@ -282,13 +290,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
preprocess();
|
||||
}
|
||||
|
||||
hipEvent_t start, stop;
|
||||
// hipEvent_t start, stop;
|
||||
|
||||
hip_check_error(hipEventCreate(&start));
|
||||
hip_check_error(hipEventCreate(&stop));
|
||||
// hip_check_error(hipEventCreate(&start));
|
||||
// hip_check_error(hipEventCreate(&stop));
|
||||
|
||||
hip_check_error(hipDeviceSynchronize());
|
||||
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
|
||||
// hip_check_error(hipDeviceSynchronize());
|
||||
// hip_check_error(hipEventRecord(start, stream_config.stream_id_));
|
||||
// calculate preprocess time
|
||||
if constexpr(TimePreprocess)
|
||||
{
|
||||
@@ -299,25 +307,34 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
hip_check_error(hipGetLastError());
|
||||
// end real kernel
|
||||
|
||||
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
|
||||
hip_check_error(hipEventSynchronize(stop));
|
||||
float cur_time = 0;
|
||||
hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
|
||||
#if MEDIAN
|
||||
times.insert(cur_time);
|
||||
#else
|
||||
total_time += cur_time;
|
||||
#endif
|
||||
// hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
|
||||
// hip_check_error(hipEventSynchronize(stop));
|
||||
// float cur_time = 0;
|
||||
// hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
|
||||
// #if MEDIAN
|
||||
// times.insert(cur_time);
|
||||
// #else
|
||||
// total_time += cur_time;
|
||||
// #endif
|
||||
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
|
||||
// std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
|
||||
|
||||
printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
|
||||
static_cast<const void*>(gemm_args.p_a_grid),
|
||||
static_cast<const void*>(gemm_args.p_b_grid));
|
||||
}
|
||||
}
|
||||
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
|
||||
hip_check_error(hipEventSynchronize(stop));
|
||||
float cur_time = 0;
|
||||
hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
|
||||
#if MEDIAN
|
||||
times.insert(cur_time);
|
||||
#else
|
||||
total_time += cur_time;
|
||||
#endif
|
||||
|
||||
#if MEDIAN
|
||||
auto mid = times.begin();
|
||||
@@ -333,7 +350,11 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
return (*mid + *mid_next) / 2;
|
||||
}
|
||||
#else
|
||||
return total_time / nrepeat;
|
||||
// return total_time / nrepeat;
|
||||
hipDeviceProp_t deviceProps;
|
||||
hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
|
||||
float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
|
||||
return (total_time - preprocess_offset * nrepeat) / nrepeat;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
|
||||
@@ -272,6 +272,26 @@ struct MultiplyMultiply
|
||||
|
||||
e = ck::type_convert<ck::bhalf_t>(x0_f);
|
||||
}
|
||||
|
||||
// Specialization for an int `c` combined with two ck::half_t multipliers `d0` and `d1`.
// Writes e = half_t(float(c) * float(d0) * float(d1)); `c`, `d0` and `d1` are only read.
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::half_t, int, ck::half_t, ck::half_t>(
|
||||
ck::half_t& e, const int& c, const ck::half_t& d0, const ck::half_t& d1) const
|
||||
{
|
||||
// Every operand is converted to float first, so the product is formed in float.
// NOTE(review): presumably type_convert<float>(int) is a plain cast, in which case
// |c| above 2^24 is rounded before the multiply - confirm against type_convert.
const float x0_f =
|
||||
ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
|
||||
|
||||
// The only narrowing step: the float product is converted to ck::half_t on store.
e = ck::type_convert<ck::half_t>(x0_f);
|
||||
}
|
||||
|
||||
// Specialization for an int `c` combined with two float multipliers `d0` and `d1`.
// Writes e = bhalf_t(float(c) * d0 * d1); `c`, `d0` and `d1` are only read.
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
|
||||
ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
|
||||
{
|
||||
// d0 and d1 are already float; they still go through type_convert<float>, matching
// the form used by the ck::half_t specialization of this operator.
const float x0_f =
|
||||
ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
|
||||
|
||||
// The only narrowing step: the float product is converted to ck::bhalf_t on store.
e = ck::type_convert<ck::bhalf_t>(x0_f);
|
||||
}
|
||||
};
|
||||
|
||||
struct MultiplyAddFastGelu
|
||||
|
||||
@@ -327,12 +327,12 @@ struct intrin_mfma_i32_16x16x32i8<16, 16>
|
||||
__device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_i32_16x16x32i8(bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<int32x4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
__builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
|
||||
bit_cast<int64_t>(reg_b),
|
||||
reg_c.template AsType<int32x4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user