add int8 gemm multiply multiply a8w8 (#1591)

* add int8 gemm multiply multiply a8w8

* uncomment

* clang-format-12

* Add example_gemm_multiply_multiply_xdl_int8

* Remove shell scripts

* update preprocess number for mi308; bring back printout in ckprofiler

* format

---------

Co-authored-by: chenjun <junchen2@amd.com>
Co-authored-by: Haocong WANG <haocwang@amd.com>
Co-authored-by: carlushuang <carlus.huang@amd.com>

[ROCm/composable_kernel commit: 37f7afed1e]
This commit is contained in:
valarLip
2024-10-26 16:39:34 +08:00
committed by GitHub
parent d99a3611fc
commit 59e7fe3ac8
16 changed files with 794 additions and 28 deletions

View File

@@ -237,7 +237,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
Args... args)
{
#if CK_TIME_KERNEL
#define MEDIAN 1
#define MEDIAN 0
if(stream_config.time_kernel_)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
@@ -275,6 +275,14 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
#else
float total_time = 0;
#endif
hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start));
hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
for(int i = 0; i < nrepeat; ++i)
{
if constexpr(!TimePreprocess)
@@ -282,13 +290,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
preprocess();
}
hipEvent_t start, stop;
// hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start));
hip_check_error(hipEventCreate(&stop));
// hip_check_error(hipEventCreate(&start));
// hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
// hip_check_error(hipDeviceSynchronize());
// hip_check_error(hipEventRecord(start, stream_config.stream_id_));
// calculate preprocess time
if constexpr(TimePreprocess)
{
@@ -299,25 +307,34 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
hip_check_error(hipGetLastError());
// end real kernel
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));
float cur_time = 0;
hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
#if MEDIAN
times.insert(cur_time);
#else
total_time += cur_time;
#endif
// hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
// hip_check_error(hipEventSynchronize(stop));
// float cur_time = 0;
// hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
// #if MEDIAN
// times.insert(cur_time);
// #else
// total_time += cur_time;
// #endif
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
// std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
static_cast<const void*>(gemm_args.p_a_grid),
static_cast<const void*>(gemm_args.p_b_grid));
}
}
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));
float cur_time = 0;
hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
#if MEDIAN
times.insert(cur_time);
#else
total_time += cur_time;
#endif
#if MEDIAN
auto mid = times.begin();
@@ -333,7 +350,11 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
return (*mid + *mid_next) / 2;
}
#else
return total_time / nrepeat;
// return total_time / nrepeat;
hipDeviceProp_t deviceProps;
hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
return (total_time - preprocess_offset * nrepeat) / nrepeat;
#endif
}
else

View File

@@ -272,6 +272,26 @@ struct MultiplyMultiply
e = ck::type_convert<ck::bhalf_t>(x0_f);
}
template <>
__host__ __device__ constexpr void operator()<ck::half_t, int, ck::half_t, ck::half_t>(
ck::half_t& e, const int& c, const ck::half_t& d0, const ck::half_t& d1) const
{
const float x0_f =
ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
e = ck::type_convert<ck::half_t>(x0_f);
}
template <>
__host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
{
const float x0_f =
ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
e = ck::type_convert<ck::bhalf_t>(x0_f);
}
};
struct MultiplyAddFastGelu

View File

@@ -327,12 +327,12 @@ struct intrin_mfma_i32_16x16x32i8<16, 16>
__device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
{
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_16x16x32i8(bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}],
0,
0,
0);
__builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}],
0,
0,
0);
}
};