gemm/Conv xdlops + dlops quantization (#625)

* Add conv per-layer quantization
* Add gemm_dlops quantization
* Support int8 for inner_product
* Refine gemm dlops int8 kernel parameters
* Support gfx908 (MI100) and gfx90a (MI200)
* clang-format
* Renumber example
* Support different layouts for the D tensor
* Add conv dlops per-channel quantization example
* Move to example 40
* Extract the code shared across platforms (dlops and xdlops)
* Move to subfolder, in preparation for other quantization ops
* Refine the quantization instance library
* Add conv dl instances and client example
* Remove unnecessary type
* Add gemm quantization instances
* Add external API and client example
* Refine num_bytes
* Separate different layouts into different .cpp files
* Add more xdl instances
* Revert "Remove unnecessary type" (reverts commit 820869182f)
* Remove CShuffleDataType in dlops; let Acc and CShuffleDataType be the same type in xdlops

Co-authored-by: zjing14 <zhangjing14@gmail.com>

[ROCm/composable_kernel commit:16dc18e0f9]
@@ -134,7 +134,8 @@ __global__ void
         const Block2CTileMap block_2_ctile_map,
         const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
+   defined(__gfx90a__) || defined(__gfx908__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
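For context: each work-group divides its block index by num_blocks_per_batch to find the batch (group) it serves, then offsets the tensor base pointers by that batch's stride; readfirstlane only makes the quotient uniform across the wavefront on the GPU. A minimal host-side sketch of that index math, with made-up grid and stride numbers (batch_stride_a is illustrative, not the kernel's actual member):

    #include <cstdio>

    int main()
    {
        const int grid_size   = 240; // total work-groups launched
        const int batch_count = 4;   // number of groups in the grouped conv

        const int num_blocks_per_batch = grid_size / batch_count;

        for(int block_id = 0; block_id < grid_size; block_id += 60)
        {
            // same quotient the kernel takes via __builtin_amdgcn_readfirstlane
            const int g_idx = block_id / num_blocks_per_batch;

            const long batch_stride_a = 1024; // element stride between batches (illustrative)
            const long a_offset       = static_cast<long>(g_idx) * batch_stride_a;

            std::printf("block %3d -> batch %d, A offset %ld\n", block_id, g_idx, a_offset);
        }
    }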
@@ -314,9 +315,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
         const auto in_gemmm_gemmk_desc =
             matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);

-        const auto M = in_gemmm_gemmk_desc.GetLength(I0);
-        const auto K = in_gemmm_gemmk_desc.GetLength(I1);
-
+        const auto M = in_gemmm_gemmk_desc.GetLength(I0);
+        const auto K = in_gemmm_gemmk_desc.GetLength(I1);
         const auto AK0 = K / K1;

         return transform_tensor_descriptor(
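The padded A descriptor is then re-expressed as K0 x M x K1, which only works if padding has made K divisible by K1. A small worked example of the AK0 = K / K1 split, with illustrative sizes:

    #include <cassert>
    #include <cstdio>

    int main()
    {
        const int K  = 192; // padded GEMM K (e.g. C*Y*X for NHWC conv-as-GEMM)
        const int K1 = 4;   // vector/packing factor along K

        assert(K % K1 == 0); // the matrix padder must make K divisible by K1
        const int AK0 = K / K1;

        std::printf("K=%d -> AK0=%d, K1=%d (descriptor K0 x M x K1)\n", K, AK0, K1);
    }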
@@ -709,7 +709,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
         namespace ctc = tensor_layout::convolution;

         // check device
-        if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
+        if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
+             ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx908"))
         {
             return false;
         }
@@ -834,6 +835,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
         {
             return false;
         }

+        // check Gridwise GEMM
         return GridwiseGemm::CheckValidity(
             arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.e_grid_desc_m_n_);
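Note that this runtime allow-list must agree with the compile-time #if guard in the kernel above; both now name the same four targets. A standalone sketch of the same allow-list pattern (the stub string stands in for what ck::get_device_name() returns):

    #include <array>
    #include <cstdio>
    #include <string>

    // Returns true iff the gfx target is one the DL quantization path supports.
    static bool device_supported(const std::string& name)
    {
        const std::array<const char*, 4> supported{"gfx906", "gfx1030", "gfx90a", "gfx908"};
        for(const char* s : supported)
            if(name == s)
                return true;
        return false;
    }

    int main()
    {
        std::printf("gfx90a supported: %d\n", device_supported("gfx90a"));
        std::printf("gfx803 supported: %d\n", device_supported("gfx803"));
    }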
@@ -51,7 +51,7 @@ __global__ void
         const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx1030__))
+    defined(__gfx90a__) || defined(__gfx1030__))

     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
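Since the LDS byte budget is divided by sizeof(ABDataType), the int8 path fits four times as many elements into the same shared-memory allocation as an int32 path would. A toy version of the arithmetic (the byte count here is made up for illustration):

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const std::size_t lds_bytes = 16384;          // stands in for GetSharedMemoryNumberOfByte()
        const std::size_t elem_size = sizeof(int8_t); // ABDataType = int8_t on the quantized path

        std::printf("shared_block_size = %zu elements\n", lds_bytes / elem_size);
    }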
@@ -552,7 +552,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
         if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" ||
-           ck::get_device_name() == "gfx1030")
+           ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx1030")
         {
             return GridwiseGemm::CheckValidity(
                 arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
@@ -1,6 +1,7 @@
 #pragma once

 #include "ck/utility/data_type.hpp"
+// #include "ck/utility/get_id.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -17,18 +18,27 @@ struct Activation_Mul_Clamp

     __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
     {
-        float x_fp32 = ck::type_convert<float>(x);
-        activationOp_(x_fp32, x_fp32);
-        float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
-        y = ck::type_convert<int8_t>(y_fp32);
+        float y_fp32 = ck::type_convert<float>(x);
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
+        y = ck::type_convert<int8_t>(y_fp32);
     }

-    __host__ __device__ constexpr void operator()(float& y, const int32_t& x) const
+    __device__ constexpr void operator()(int32_t& y, const int32_t& x) const
     {
-        // We might type_convert to int8 after lambda in someplace
-        float x_fp32 = ck::type_convert<float>(x);
-        activationOp_(x_fp32, x_fp32);
-        y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
+        // CAUTION - we might type_convert to int8 in threadwise copy,
+        // e.g. GridwiseGemmDlMultipleD_km_kn_mn
+        float y_fp32 = ck::type_convert<float>(x);
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
+        y = ck::type_convert<int32_t>(y_fp32);
     }

+    __host__ constexpr void operator()(float& y, const float& x) const
+    {
+        // CAUTION - we might pass float in and out in reference code
+        activationOp_(y, x);
+        y = math::clamp(requantScale_ * y, -128.f, 127.f);
+    }
+
     float requantScale_;
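This operator is the heart of the per-layer quantization epilogue: widen the int32 accumulator to float, apply the activation, multiply by the requantization scale, clamp to the int8 range, and convert back. A plain-C++ restatement, with ReLU standing in for the Activation functor and a made-up scale (the exact rounding mode of ck::type_convert is CK's detail, not shown here):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // int32 accumulator -> float -> activation -> scale -> clamp -> int8
    static int8_t requant(int32_t acc, float requant_scale)
    {
        float y = static_cast<float>(acc);                              // type_convert<float>(x)
        y       = std::max(y, 0.f);                                     // activationOp_ (ReLU here)
        y       = std::min(std::max(requant_scale * y, -128.f), 127.f); // math::clamp
        return static_cast<int8_t>(y);                                  // type_convert<int8_t>
    }

    int main()
    {
        // acc = 5000 with scale 0.02 -> 100; acc = 20000 saturates at 127
        std::printf("%d %d\n", requant(5000, 0.02f), requant(20000, 0.02f));
    }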
@@ -51,6 +61,17 @@ struct Activation_Mul2_Clamp
         y = ck::type_convert<int8_t>(y_fp32);
     }

+    __device__ constexpr void
+    operator()(int32_t& y, const int32_t& x, const float& requantScale) const
+    {
+        // CAUTION - we might type_convert to int8 in threadwise copy,
+        // e.g. GridwiseGemmDlMultipleD_km_kn_mn
+        float y_fp32 = ck::type_convert<float>(x);
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
+        y = ck::type_convert<int32_t>(y_fp32);
+    }
+
     Activation activationOp_;
 };
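Activation_Mul2_Clamp differs from Activation_Mul_Clamp only in taking requantScale as an operand rather than a member, so the scale can be broadcast from a D tensor with one value per output channel: this is the per-channel quantization path. An illustrative restatement with made-up per-channel scales:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<int32_t> acc{5000, 5000, 5000};
        std::vector<float> scale{0.02f, 0.01f, 0.05f}; // one requantScale per output channel

        for(std::size_t n = 0; n < acc.size(); ++n)
        {
            float y = std::max(static_cast<float>(acc[n]), 0.f);     // ReLU stand-in
            y       = std::min(std::max(scale[n] * y, -128.f), 127.f);
            std::printf("channel %zu -> %d\n", n, static_cast<int>(static_cast<int8_t>(y)));
        }
    }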
@@ -72,6 +93,17 @@ struct Add_Activation_Mul_Clamp
         y = ck::type_convert<int8_t>(y_fp32);
     }

+    __host__ __device__ constexpr void
+    operator()(int32_t& y, const int32_t& x, const int32_t& bias) const
+    {
+        // CAUTION - we might type_convert to int8 in threadwise copy,
+        // e.g. GridwiseGemmDlMultipleD_km_kn_mn
+        float y_fp32 = ck::type_convert<float>(x + bias);
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
+        y = ck::type_convert<int32_t>(y_fp32);
+    }
+
     float requantScale_;
     Activation activationOp_;
 };
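In the fused Add_* variants the bias is added to the raw int32 accumulator before the float conversion, which presumes, under the usual int8 GEMM convention, that the bias was pre-quantized with the product of the input and weight scales so that it lives in the accumulator domain. A sketch of that convention with made-up scales:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const float s_a = 0.1f, s_w = 0.05f; // input and weight scales (illustrative)
        const float bias_f = 1.5f;           // float bias from the model

        // acc represents value / (s_a * s_w), so quantize bias with the same scale
        const int32_t bias_q = static_cast<int32_t>(std::lround(bias_f / (s_a * s_w)));
        std::printf("bias_q = %d (represents %.2f)\n", bias_q, bias_q * s_a * s_w);
    }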
@@ -92,6 +124,17 @@ struct Add_Activation_Mul2_Clamp
         y = ck::type_convert<int8_t>(y_fp32);
     }

+    __host__ __device__ constexpr void
+    operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const
+    {
+        // CAUTION - we might type_convert to int8 in threadwise copy,
+        // e.g. GridwiseGemmDlMultipleD_km_kn_mn
+        float y_fp32 = ck::type_convert<float>(x + bias);
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
+        y = ck::type_convert<int32_t>(y_fp32);
+    }
+
     Activation activationOp_;
 };
@@ -185,8 +185,10 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
         return b_grid_desc_k0_n0_n1_k1;
     }

     // E desc for destination in blockwise copy
+    template <typename CGridDesc_M_N_>
     __host__ __device__ static constexpr auto
-    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
+    MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_& c_grid_desc_m_n)
     {
         const auto M = c_grid_desc_m_n.GetLength(I0);
         const auto N = c_grid_desc_m_n.GetLength(I1);
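Deducing the descriptor type here lets a single maker serve the E descriptor and each D descriptor even when their concrete types differ, as long as they expose GetLength. A minimal illustration of that change in intent, with stand-in types that are not CK's:

    #include <cstdio>

    struct DescA { int GetLength(int i) const { return i == 0 ? 128 : 256; } };
    struct DescB { int GetLength(int i) const { return i == 0 ? 64 : 512; } };

    // previously fixed to one class-level descriptor type; now deduced per call
    template <typename CGridDesc_M_N_>
    void print_m_n(const CGridDesc_M_N_& d)
    {
        std::printf("M=%d N=%d\n", d.GetLength(0), d.GetLength(1));
    }

    int main()
    {
        print_m_n(DescA{});
        print_m_n(DescB{});
    }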
@@ -135,6 +135,28 @@ __device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const h
                   c);
 }

+template <>
+__device__ void inner_product<int8_t, int8_t, int32_t>(const int8_t& a, const int8_t& b, int32_t& c)
+{
+    c += type_convert<int32_t>(a) * type_convert<int32_t>(b);
+}
+
+template <>
+__device__ void
+inner_product<int8x2_t, int8x2_t, int32_t>(const int8x2_t& a, const int8x2_t& b, int32_t& c)
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+
+    inner_product(vector_type<int8_t, 2>{a}.AsType<int8_t>()[I0],
+                  vector_type<int8_t, 2>{b}.AsType<int8_t>()[I0],
+                  c);
+
+    inner_product(vector_type<int8_t, 2>{a}.AsType<int8_t>()[I1],
+                  vector_type<int8_t, 2>{b}.AsType<int8_t>()[I1],
+                  c);
+}
+
 template <>
 __device__ void
 inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
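Reference semantics of the new specializations: widen each int8 lane to int32 and accumulate, with the int8x4 overload recursing into int8x2 halves and those into scalars. A scalar restatement with plain arrays standing in for CK's vector types:

    #include <cstdint>
    #include <cstdio>

    // Widening int8 dot product accumulated into an int32, matching the
    // scalar specialization's c += int32(a) * int32(b) per lane.
    static void inner_product_i8(const int8_t* a, const int8_t* b, int n, int32_t& c)
    {
        for(int i = 0; i < n; ++i)
            c += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
    }

    int main()
    {
        const int8_t a[4] = {127, -128, 3, 4};
        const int8_t b[4] = {2, 2, 2, 2};
        int32_t c = 0;
        inner_product_i8(a, b, 4, c); // 254 - 256 + 6 + 8 = 12
        std::printf("c = %d\n", c);
    }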