mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
gemm/Conv xdlops + dlops quantization (#625)
* Add conv perlayer quantization
* Add gemm_dlops quantization
* Support int8 for innerproduct
* Refine gemm dlops int8 kernel parameter
* Support gfx908(MI100) and gfx90a(MI200)
* clang-format
* Rename example number
* Support different layout for d tensor
* Add conv dlops perchannel quantization example
* Move to example 40
* Extract the common code for different platforms (dlops and xdlops)
* Move to subfolder. Prepare to add other ops of quantization
* Refine the quantization instance library
* Add conv dl instances and client example
* Remove unnecessary type
* Add gemm quantization instance
* Add external api and client example
* Refine num_bytes
* Separate different layouts into different cpp files
* Add more xdl instances
* Revert "Remove unnecessary type"
This reverts commit 820869182f.
* Remove CShuffleDataType in dlops
Let acc and CShuffleDataType be the same in xdlops
---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -134,7 +134,8 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -314,9 +315,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
|
||||
const auto M = in_gemmm_gemmk_desc.GetLength(I0);
|
||||
const auto K = in_gemmm_gemmk_desc.GetLength(I1);
|
||||
|
||||
const auto M = in_gemmm_gemmk_desc.GetLength(I0);
|
||||
const auto K = in_gemmm_gemmk_desc.GetLength(I1);
|
||||
const auto AK0 = K / K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
@@ -709,7 +709,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
namespace ctc = tensor_layout::convolution;
|
||||
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
|
||||
ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx908"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -834,6 +835,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check Gridwise GEMM
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.e_grid_desc_m_n_);
|
||||
|
||||
@@ -51,7 +51,7 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx1030__))
|
||||
defined(__gfx90a__) || defined(__gfx1030__))
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
|
||||
@@ -552,7 +552,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" ||
|
||||
ck::get_device_name() == "gfx1030")
|
||||
ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx1030")
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
// #include "ck/utility/get_id.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -17,18 +18,27 @@ struct Activation_Mul_Clamp
|
||||
|
||||
__host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
|
||||
{
|
||||
float x_fp32 = ck::type_convert<float>(x);
|
||||
activationOp_(x_fp32, x_fp32);
|
||||
float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
|
||||
y = ck::type_convert<int8_t>(y_fp32);
|
||||
float y_fp32 = ck::type_convert<float>(x);
|
||||
activationOp_(y_fp32, y_fp32);
|
||||
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
|
||||
y = ck::type_convert<int8_t>(y_fp32);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void operator()(float& y, const int32_t& x) const
|
||||
__device__ constexpr void operator()(int32_t& y, const int32_t& x) const
|
||||
{
|
||||
// We might type_convert to int8 after lambda in someplace
|
||||
float x_fp32 = ck::type_convert<float>(x);
|
||||
activationOp_(x_fp32, x_fp32);
|
||||
y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f);
|
||||
// CAUSION - We might type_convert to int8 in threadwise copy
|
||||
// eg. GridwiseGemmDlMultipleD_km_kn_mn
|
||||
float y_fp32 = ck::type_convert<float>(x);
|
||||
activationOp_(y_fp32, y_fp32);
|
||||
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
|
||||
y = ck::type_convert<int32_t>(y_fp32);
|
||||
}
|
||||
|
||||
__host__ constexpr void operator()(float& y, const float& x) const
|
||||
{
|
||||
// CAUSION - We might float in & float out in reference code
|
||||
activationOp_(y, x);
|
||||
y = math::clamp(requantScale_ * y, -128.f, 127.f);
|
||||
}
|
||||
|
||||
float requantScale_;
|
||||
@@ -51,6 +61,17 @@ struct Activation_Mul2_Clamp
|
||||
y = ck::type_convert<int8_t>(y_fp32);
|
||||
}
|
||||
|
||||
__device__ constexpr void
|
||||
operator()(int32_t& y, const int32_t& x, const float& requantScale) const
|
||||
{
|
||||
// CAUSION - We might type_convert to int8 in threadwise copy
|
||||
// eg. GridwiseGemmDlMultipleD_km_kn_mn
|
||||
float y_fp32 = ck::type_convert<float>(x);
|
||||
activationOp_(y_fp32, y_fp32);
|
||||
y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
|
||||
y = ck::type_convert<int32_t>(y_fp32);
|
||||
}
|
||||
|
||||
Activation activationOp_;
|
||||
};
|
||||
|
||||
@@ -72,6 +93,17 @@ struct Add_Activation_Mul_Clamp
|
||||
y = ck::type_convert<int8_t>(y_fp32);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
operator()(int32_t& y, const int32_t& x, const int32_t& bias) const
|
||||
{
|
||||
// CAUSION - We might type_convert to int8 in threadwise copy
|
||||
// eg. GridwiseGemmDlMultipleD_km_kn_mn
|
||||
float y_fp32 = ck::type_convert<float>(x + bias);
|
||||
activationOp_(y_fp32, y_fp32);
|
||||
y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
|
||||
y = ck::type_convert<int32_t>(y_fp32);
|
||||
}
|
||||
|
||||
float requantScale_;
|
||||
Activation activationOp_;
|
||||
};
|
||||
@@ -92,6 +124,17 @@ struct Add_Activation_Mul2_Clamp
|
||||
y = ck::type_convert<int8_t>(y_fp32);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const
|
||||
{
|
||||
// CAUSION - We might type_convert to int8 in threadwise copy
|
||||
// eg. GridwiseGemmDlMultipleD_km_kn_mn
|
||||
float y_fp32 = ck::type_convert<float>(x + bias);
|
||||
activationOp_(y_fp32, y_fp32);
|
||||
y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
|
||||
y = ck::type_convert<int32_t>(y_fp32);
|
||||
}
|
||||
|
||||
Activation activationOp_;
|
||||
};
|
||||
|
||||
|
||||
@@ -185,8 +185,10 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
|
||||
return b_grid_desc_k0_n0_n1_k1;
|
||||
}
|
||||
|
||||
// E desc for destination in blockwise copy
|
||||
template <typename CGridDesc_M_N_>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_& c_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
Reference in New Issue
Block a user