Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-13 01:36:06 +00:00)
Use DynamicBuffer instead of raw pointer (#32)
* Use DynamicBuffer to hold raw pointers (to global and LDS memory)
* Add a workaround for a compiler issue (inefficient ISA) with ds_write for int8x4, int8x8, and int8x16
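The heart of the change: raw `const Float*`/`Float*` pointers to global and LDS memory are wrapped in a DynamicBuffer that carries its address space and element count, and element access goes through `Get`/`Set`. Below is a minimal sketch of that idea, reconstructed only from the calls visible in this diff (`make_dynamic_buffer<AddressSpace::...>(ptr, size)`, `buf.template Get<X>(offset, is_valid)`, `buf.template Set<X>(offset, is_valid, x)`); the real composable_kernel class carries more machinery, and `index_t` is simplified to `long` here.

enum class AddressSpace
{
    Generic,
    Global,
    Lds,
    Vgpr
};

template <AddressSpace BufferAddressSpace, typename T>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;
    long element_space_size_;

    static constexpr AddressSpace GetAddressSpace() { return BufferAddressSpace; }

    // vectorized load, predicated on coordinate validity; the real Global
    // specialization can lower this to an AMD buffer-load intrinsic
    template <typename X>
    X Get(long offset, bool is_valid) const
    {
        return is_valid ? *reinterpret_cast<const X*>(&p_data_[offset]) : X{};
    }

    // vectorized store, predicated on coordinate validity
    template <typename X>
    void Set(long offset, bool is_valid, const X& x)
    {
        if(is_valid)
        {
            *reinterpret_cast<X*>(&p_data_[offset]) = x;
        }
    }
};

template <AddressSpace As, typename T>
DynamicBuffer<As, T> make_dynamic_buffer(T* p, long element_space_size)
{
    return DynamicBuffer<As, T>{p, element_space_size};
}

Because the address space now travels with the buffer type, the copy classes in this diff can drop their `AddressSpace SrcAddressSpace` / `AddressSpace DstAddressSpace` template parameters and let the buffer decide how an access is lowered.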
@@ -146,16 +146,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
 
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+                                                   FloatAB,
+                                                   FloatAB,
+                                                   FloatC,
                                                    remove_reference_t<AGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<BGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<CGlobalDesc>,
-                                                   FloatC*,
                                                    remove_reference_t<CBlockClusterDesc>,
-                                                   integral_constant<bool, true>,
-                                                   integral_constant<bool, true>>;
+                                                   true,
+                                                   true>;
 
         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -163,28 +163,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          a_k_m_global_desc,
                                           p_a_global,
-                                          b_k_n_global_desc,
                                           p_b_global,
-                                          c_m0_m1_n0_n1_global_desc,
                                           p_c_global,
-                                          c_block_cluster_desc,
-                                          integral_constant<bool, true>{},
-                                          integral_constant<bool, true>{});
+                                          a_k_m_global_desc,
+                                          b_k_n_global_desc,
+                                          c_m0_m1_n0_n1_global_desc,
+                                          c_block_cluster_desc);
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+                                                   FloatAB,
+                                                   FloatAB,
+                                                   FloatC,
                                                    remove_reference_t<AGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<BGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<CGlobalDesc>,
-                                                   FloatC*,
                                                    remove_reference_t<CBlockClusterDesc>,
-                                                   integral_constant<bool, true>,
-                                                   integral_constant<bool, false>>;
+                                                   true,
+                                                   false>;
 
         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -192,28 +190,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          a_k_m_global_desc,
                                           p_a_global,
-                                          b_k_n_global_desc,
                                           p_b_global,
-                                          c_m0_m1_n0_n1_global_desc,
                                           p_c_global,
-                                          c_block_cluster_desc,
-                                          integral_constant<bool, true>{},
-                                          integral_constant<bool, false>{});
+                                          a_k_m_global_desc,
+                                          b_k_n_global_desc,
+                                          c_m0_m1_n0_n1_global_desc,
+                                          c_block_cluster_desc);
     }
     else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+                                                   FloatAB,
+                                                   FloatAB,
+                                                   FloatC,
                                                    remove_reference_t<AGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<BGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<CGlobalDesc>,
-                                                   FloatC*,
                                                    remove_reference_t<CBlockClusterDesc>,
-                                                   integral_constant<bool, false>,
-                                                   integral_constant<bool, true>>;
+                                                   false,
+                                                   true>;
 
         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -221,28 +217,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          a_k_m_global_desc,
                                           p_a_global,
-                                          b_k_n_global_desc,
                                           p_b_global,
-                                          c_m0_m1_n0_n1_global_desc,
                                           p_c_global,
-                                          c_block_cluster_desc,
-                                          integral_constant<bool, false>{},
-                                          integral_constant<bool, true>{});
+                                          a_k_m_global_desc,
+                                          b_k_n_global_desc,
+                                          c_m0_m1_n0_n1_global_desc,
+                                          c_block_cluster_desc);
    }
    else
    {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+                                                   FloatAB,
+                                                   FloatAB,
+                                                   FloatC,
                                                    remove_reference_t<AGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<BGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<CGlobalDesc>,
-                                                   FloatC*,
                                                    remove_reference_t<CBlockClusterDesc>,
-                                                   integral_constant<bool, false>,
-                                                   integral_constant<bool, false>>;
+                                                   false,
+                                                   false>;
 
         ave_time = launch_and_time_kernel(kernel,
                                           nrepeat,
@@ -250,15 +244,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          a_k_m_global_desc,
                                           p_a_global,
-                                          b_k_n_global_desc,
                                           p_b_global,
-                                          c_m0_m1_n0_n1_global_desc,
                                           p_c_global,
-                                          c_block_cluster_desc,
-                                          integral_constant<bool, false>{},
-                                          integral_constant<bool, false>{});
+                                          a_k_m_global_desc,
+                                          b_k_n_global_desc,
+                                          c_m0_m1_n0_n1_global_desc,
+                                          c_block_cluster_desc);
    }
 
    return ave_time;
@@ -277,13 +269,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
 
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
+                                                   FloatAB,
+                                                   FloatAB,
+                                                   FloatC,
                                                    remove_reference_t<AGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<BGlobalDesc>,
-                                                   const FloatAB*,
                                                    remove_reference_t<CGlobalDesc>,
-                                                   FloatC*,
                                                    remove_reference_t<CBlockClusterDesc>,
                                                    true,
                                                    true>;
@@ -295,23 +287,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
                                           p_a_global,
-                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
                                           p_b_global,
-                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           p_c_global,
+                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
-                                                   remove_reference_t<AGlobalDesc>,
                                                    FloatAB,
-                                                   remove_reference_t<BGlobalDesc>,
                                                    FloatAB,
-                                                   remove_reference_t<CGlobalDesc>,
                                                    FloatC,
+                                                   remove_reference_t<AGlobalDesc>,
+                                                   remove_reference_t<BGlobalDesc>,
+                                                   remove_reference_t<CGlobalDesc>,
                                                    remove_reference_t<CBlockClusterDesc>,
                                                    true,
                                                    false>;
@@ -323,23 +315,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
                                           p_a_global,
-                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
                                           p_b_global,
-                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           p_c_global,
+                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
-                                                   remove_reference_t<AGlobalDesc>,
                                                    FloatAB,
-                                                   remove_reference_t<BGlobalDesc>,
                                                    FloatAB,
-                                                   remove_reference_t<CGlobalDesc>,
                                                    FloatC,
+                                                   remove_reference_t<AGlobalDesc>,
+                                                   remove_reference_t<BGlobalDesc>,
+                                                   remove_reference_t<CGlobalDesc>,
                                                    remove_reference_t<CBlockClusterDesc>,
                                                    false,
                                                    true>;
@@ -351,23 +343,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
                                           p_a_global,
-                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
                                           p_b_global,
-                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           p_c_global,
+                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
    else
    {
        const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
-                                                   remove_reference_t<AGlobalDesc>,
                                                    FloatAB,
-                                                   remove_reference_t<BGlobalDesc>,
                                                    FloatAB,
-                                                   remove_reference_t<CGlobalDesc>,
                                                    FloatC,
+                                                   remove_reference_t<AGlobalDesc>,
+                                                   remove_reference_t<BGlobalDesc>,
+                                                   remove_reference_t<CGlobalDesc>,
                                                    remove_reference_t<CBlockClusterDesc>,
                                                    false,
                                                    false>;
@@ -379,12 +371,12 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
                                           dim3(BlockSize),
                                           0,
                                           0,
-                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
                                           p_a_global,
-                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
                                           p_b_global,
-                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           p_c_global,
+                                          (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
+                                          (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
                                           (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
    }
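The four branches above exist so that HasMainKBlockLoop and HasDoubleTailKBlockLoop are compile-time constants inside the kernel; the runtime flags only pick which instantiation to launch. The pattern, reduced to its skeleton (names illustrative):

template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
void launch_gemm(); // stands in for building and timing a kernel_dynamic_gemm_v1 instance

inline void dispatch_gemm(bool has_main_k_block_loop, bool has_double_tail_k_block_loop)
{
    if(has_main_k_block_loop && has_double_tail_k_block_loop)
        launch_gemm<true, true>();
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
        launch_gemm<true, false>();
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
        launch_gemm<false, true>();
    else
        launch_gemm<false, false>();
}

Inside the kernel the two booleans then drive `if constexpr`, so the loop structure is resolved at compile time with no divergent runtime branch.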
@@ -29,8 +29,6 @@ template <index_t BlockSize,
           index_t DstVectorDim,
           index_t SrcScalarPerVector,
           index_t DstScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           index_t SrcScalarStrideInVector,
           index_t DstScalarStrideInVector,
           index_t ThreadTransferSrcResetCoordinateAfterRun,
@@ -79,24 +77,25 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
        }
    }
 
-    template <typename SrcIteratorHacks>
+    template <typename SrcBuffer, typename SrcIteratorHacks>
    __device__ void RunRead(const SrcDesc& src_desc,
-                            const SrcData* p_src,
+                            const SrcBuffer& src_buf,
                            const SrcIteratorHacks& src_iterator_hacks)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
-            threadwise_transfer_.RunRead(src_desc, p_src, src_iterator_hacks);
+            threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
        }
    }
 
-    __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
+    template <typename DstBuffer>
+    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
-            threadwise_transfer_.RunWrite(dst_desc, p_dst);
+            threadwise_transfer_.RunWrite(dst_desc, dst_buf);
        }
    }
@@ -152,8 +151,6 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
                                                 DstScalarPerVector,
                                                 SrcScalarStrideInVector,
                                                 DstScalarStrideInVector,
-                                                 SrcAddressSpace,
-                                                 DstAddressSpace,
                                                 ThreadTransferSrcResetCoordinateAfterRun,
                                                 ThreadTransferDstResetCoordinateAfterRun>;
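With the buffers carrying their address space, the blockwise transfer no longer needs the deleted SrcAddressSpace/DstAddressSpace template parameters: RunRead/RunWrite are templated on the buffer type and can validate it directly. A sketch of that check, reusing the AddressSpace/DynamicBuffer stand-ins from the first sketch (the v3 threadwise transfer later in this diff does exactly this with static_assert):

template <typename SrcDesc, typename SrcBuffer>
void run_read_checked(const SrcDesc& src_desc, const SrcBuffer& src_buf)
{
    // the buffer type knows where it lives; reject anything that is not
    // global memory or LDS instead of trusting an extra template parameter
    static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global ||
                      SrcBuffer::GetAddressSpace() == AddressSpace::Lds,
                  "wrong! RunRead expects a global- or LDS-backed buffer");

    (void)src_desc;
    (void)src_buf; // per-thread reads would be issued against src_buf here
}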
@@ -115,8 +115,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                           const BBlockBuffer& b_block_buf,
                           CThreadBuffer& c_thread_buf) const
    {
-        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
+        auto a_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());
 
        constexpr auto threadwise_gemm =
            ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
@@ -176,8 +178,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                                          Sequence<0, 1, 2>,
                                          2,
                                          AThreadCopyScalarPerVector_M1,
-                                          AddressSpace::Generic,
-                                          AddressSpace::Vgpr,
                                          1>;
 
    using BThreadCopy =
@@ -189,8 +189,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                                          Sequence<0, 1, 2>,
                                          2,
                                          BThreadCopyScalarPerVector_N1,
-                                          AddressSpace::Generic,
-                                          AddressSpace::Vgpr,
                                          1>;
 
    CIndex c_thread_origin_data_idx_;
@@ -211,6 +209,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
 //   3. C:
 //     1. CThreadDesc is known at compile-time
 //     2. CThreadBuffer is StaticBuffer
+// Also assume:
+//   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
 template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
@@ -312,8 +312,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                           const BBlockBuffer& b_block_buf,
                           CThreadBuffer& c_thread_buf) const
    {
-        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
+        auto a_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf =
+            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());
 
        constexpr auto threadwise_gemm =
            ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
@@ -481,8 +483,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                                          Sequence<0, 1, 2>,
                                          2,
                                          AThreadCopyScalarPerVector_M1,
-                                          AddressSpace::Generic,
-                                          AddressSpace::Vgpr,
                                          1>;
 
    using BThreadCopy =
@@ -494,8 +494,6 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                                          Sequence<0, 1, 2>,
                                          2,
                                          BThreadCopyScalarPerVector_N1,
-                                          AddressSpace::Generic,
-                                          AddressSpace::Vgpr,
                                          1>;
 
    CIndex c_thread_origin_data_idx_;
@@ -49,8 +49,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
                                                    Sequence<0, 1>,
                                                    1,
                                                    ThreadGemmADataPerRead_K,
-                                                    AddressSpace::Generic,
-                                                    AddressSpace::Vgpr,
                                                    1>;
 
    __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()
@@ -140,7 +138,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
        static_assert(WPerThread % WoPerThreadSubC == 0, "");
 
        // thread A buffer for GEMM
-        StaticBuffer<FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;
 
        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<FloatA,
                                                                    FloatB,
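The thread-local scratch buffers change in the same spirit: StaticBuffer and make_static_buffer now name their address space (VGPRs, i.e. registers) as the first template argument. A minimal stand-in, with the Number<>-style compile-time sizes of the real code approximated by std::integral_constant:

#include <type_traits>

template <AddressSpace BufferAddressSpace, typename T, long N>
struct StaticBuffer
{
    using type = T;

    T data_[N]; // expected to live in registers when BufferAddressSpace is Vgpr

    static constexpr AddressSpace GetAddressSpace() { return BufferAddressSpace; }

    T& operator[](long i) { return data_[i]; }
    const T& operator[](long i) const { return data_[i]; }
};

template <AddressSpace As, typename T, long N>
constexpr StaticBuffer<As, T, N> make_static_buffer(std::integral_constant<long, N>)
{
    return {};
}

// usage mirroring the diff, e.g.
//   auto a_thread_buf =
//       make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
// where GetElementSpaceSize() returns a compile-time Number<> in the real code.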
@@ -14,54 +14,62 @@ namespace ck {
 
 #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
 template <typename GridwiseGemm,
-          typename AGlobalDesc,
           typename FloatA,
-          typename BGlobalDesc,
           typename FloatB,
-          typename CGlobalDesc,
           typename FloatC,
+          typename AGlobalDesc,
+          typename BGlobalDesc,
+          typename CGlobalDesc,
           typename CBlockClusterDesc,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
-__global__ void kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
-                                       const FloatA* __restrict__ p_a_global,
-                                       const BGlobalDesc b_k_n_global_desc,
-                                       const FloatB* __restrict__ p_b_global,
-                                       const CGlobalDesc c_m0_m1_n0_n1_global_desc,
-                                       FloatC* __restrict__ p_c_global,
-                                       const CBlockClusterDesc c_block_cluster_desc)
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
+                               const FloatB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc a_k_m_global_desc,
+                               const BGlobalDesc b_k_n_global_desc,
+                               const CGlobalDesc c_m0_m1_n0_n1_global_desc,
+                               const CBlockClusterDesc c_block_cluster_desc)
 {
-    GridwiseGemm{}.Run(a_k_m_global_desc,
-                       p_a_global,
-                       b_k_n_global_desc,
-                       p_b_global,
-                       c_m0_m1_n0_n1_global_desc,
-                       p_c_global,
-                       c_block_cluster_desc,
-                       integral_constant<bool, HasMainKBlockLoop>{},
-                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
+    GridwiseGemm::Run(p_a_global,
+                      p_b_global,
+                      p_c_global,
+                      a_k_m_global_desc,
+                      b_k_n_global_desc,
+                      c_m0_m1_n0_n1_global_desc,
+                      c_block_cluster_desc,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
 // pass tensor descriptor by __CONSTANT__ void pointer
 // __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
 // non-modifiable parameter address space, so compiler can enable corresponding optimization
 template <typename GridwiseGemm,
-          typename AGlobalDesc,
           typename FloatA,
-          typename BGlobalDesc,
           typename FloatB,
-          typename CGlobalDesc,
           typename FloatC,
+          typename AGlobalDesc,
+          typename BGlobalDesc,
+          typename CGlobalDesc,
           typename CBlockClusterDesc,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
-__global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
-                                       const FloatA* __restrict__ p_a_global,
-                                       const void __CONSTANT__* p_b_k_n_global_desc,
-                                       const FloatB* __restrict__ p_b_global,
-                                       const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
-                                       FloatC* __restrict__ p_c_global,
-                                       const void __CONSTANT__* p_c_block_cluster_desc)
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
+                               const FloatB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const void __CONSTANT__* p_a_k_m_global_desc,
+                               const void __CONSTANT__* p_b_k_n_global_desc,
+                               const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
+                               const void __CONSTANT__* p_c_block_cluster_desc)
 {
     // first cast void __CONSTANT__ void* to void*
     // second cast void* to Desc*
@@ -76,15 +84,15 @@ __global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_d
     const auto c_block_cluster_desc =
         *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
 
-    GridwiseGemm{}.Run(a_k_m_global_desc,
-                       p_a_global,
-                       b_k_n_global_desc,
-                       p_b_global,
-                       c_m0_m1_n0_n1_global_desc,
-                       p_c_global,
-                       c_block_cluster_desc,
-                       integral_constant<bool, HasMainKBlockLoop>{},
-                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
+    GridwiseGemm::Run(p_a_global,
+                      p_b_global,
+                      p_c_global,
+                      a_k_m_global_desc,
+                      b_k_n_global_desc,
+                      c_m0_m1_n0_n1_global_desc,
+                      c_block_cluster_desc,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
 }
 #endif
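The __CONSTANT__ path above deserves a second look: the host serializes each descriptor into a device buffer and passes its address as `const void __CONSTANT__*` (CK's macro for the AMDGPU constant address space), and the kernel recovers the typed descriptor with a two-step cast. A sketch, with names following the kernel above and a no-op fallback for the macro:

#ifndef __CONSTANT__
#define __CONSTANT__ // maps to the constant address-space qualifier under HIP/Clang
#endif

template <typename AGlobalDesc>
__device__ AGlobalDesc load_desc(const void __CONSTANT__* p_a_k_m_global_desc)
{
    // first cast __CONSTANT__ void* to void*, then cast void* to the concrete
    // descriptor type; the qualifier tells the compiler the pointee cannot be
    // modified, enabling scalar-load optimizations
    return *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
}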
@@ -161,22 +169,29 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
    }
 
    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        const CBlockClusterDesc& c_block_cluster_desc,
-                        FloatAB* __restrict__ p_shared_block,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k_m_global_desc,
+                               const BGlobalDesc& b_k_n_global_desc,
+                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CBlockClusterDesc& c_block_cluster_desc,
+                               FloatAB* __restrict__ p_shared_block,
+                               integral_constant<bool, HasMainKBlockLoop>,
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
 
+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_a_global, a_k_m_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_b_global, b_k_n_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
+
        const auto K = a_k_m_global_desc.GetLength(I0);
        const auto M = a_k_m_global_desc.GetLength(I1);
        const auto N = b_k_n_global_desc.GetLength(I1);
@@ -226,8 +241,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                                               1,
                                               ABlockTransferSrcScalarPerVector,
                                               ABlockTransferDstScalarPerVector_M,
-                                               AddressSpace::Global,
-                                               AddressSpace::Lds,
                                               1,
                                               1,
                                               AThreadTransferSrcResetCoordinateAfterRun,
@@ -255,8 +268,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                                               1,
                                               BBlockTransferSrcScalarPerVector,
                                               BBlockTransferDstScalarPerVector_N,
-                                               AddressSpace::Global,
-                                               AddressSpace::Lds,
                                               1,
                                               1,
                                               BThreadTransferSrcResetCoordinateAfterRun,
@@ -331,8 +342,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
 
        // register allocation for output
-        auto c_thread_buf =
-            make_static_buffer<FloatAcc>(c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
+        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
+            c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
 
        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
                                           decltype(c_m0_m1_n0_n1_thread_desc),
@@ -353,25 +364,23 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
        constexpr auto b_k_n_global_move_slice_window_iterator_hack =
            BGlobalMoveSliceWindowIteratorHacks{};
 
-        FloatAB* p_a_block_even = p_a_block_double;
-        FloatAB* p_b_block_even = p_b_block_double;
+        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
 
-        FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
-        FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
-
-        auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
-        auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
-
-        auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
-        auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
+        auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
+        auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
+            p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());
 
        // LDS double buffer: preload data into LDS
        {
-            a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
 
-            a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double);
+            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
+            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
        }
 
        if constexpr(HasMainKBlockLoop)
@@ -394,16 +403,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
 
                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
+                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
                b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
 
                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
 
                // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
+                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
+                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
 
                // odd iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
@@ -417,16 +426,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
 
                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
+                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
                b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
 
                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
 
                // LDS double buffer: store next data to LDS
-                a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
-                b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
+                a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
+                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
 
                k_block_data_begin += 2 * KPerBlock;
            } while(k_block_data_begin < K - 2 * KPerBlock);
@@ -445,15 +454,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
            __syncthreads();
 
            // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
 
            // LDS double buffer: GEMM on 2nd-last data
            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
 
            // LDS double buffer: store last data to LDS
-            a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
-            b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
+            a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
+            b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
 
            __syncthreads();
@@ -488,8 +497,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                                                 CThreadTransferSrcDstAccessOrder,
                                                 CThreadTransferSrcDstVectorDim,
                                                 CThreadTransferDstScalarPerVector,
-                                                 AddressSpace::Vgpr,
-                                                 AddressSpace::Global,
                                                 CGlobalMemoryDataOperation,
                                                 1,
                                                 true>{
@@ -502,32 +509,32 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                make_tuple(I0, I0, I0, I0),
                c_thread_buf,
                c_m0_m1_n0_n1_global_desc,
-                p_c_global,
+                c_global_buf,
                c_m0_m1_n0_n1_global_tensor_iterator_hacks);
        }
    }
 
    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc& b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        const CBlockClusterDesc& c_block_cluster_desc,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k_m_global_desc,
+                               const BGlobalDesc& b_k_n_global_desc,
+                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CBlockClusterDesc& c_block_cluster_desc,
+                               integral_constant<bool, HasMainKBlockLoop>,
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
 
        __shared__ FloatAB p_shared_block[shared_block_size];
 
-        Run(a_k_m_global_desc,
-            p_a_global,
-            b_k_n_global_desc,
+        Run(p_a_global,
            p_b_global,
-            c_m0_m1_n0_n1_global_desc,
            p_c_global,
+            a_k_m_global_desc,
+            b_k_n_global_desc,
+            c_m0_m1_n0_n1_global_desc,
            c_block_cluster_desc,
            p_shared_block,
            integral_constant<bool, HasMainKBlockLoop>{},
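For orientation, the control flow the main loop above implements is a classic LDS ping-pong: while the GEMM consumes one LDS tile, the next tile is prefetched from global memory into the other. Stripped of descriptors and buffer types, the schedule looks like this (the stubs are illustrative stand-ins):

void load_tile(int k, int which) { /* global -> LDS copy of the K-slice at k */ (void)k; (void)which; }
void compute_tile(int which) { /* blockwise GEMM on LDS tile `which` */ (void)which; }
void barrier() { /* __syncthreads() on the GPU */ }

void gemm_main_loop(int K, int KPerBlock)
{
    load_tile(0, 0); // preload into the "even" buffer

    int k = 0;
    do
    {
        barrier();
        load_tile(k + KPerBlock, 1);     // prefetch next tile into "odd"
        compute_tile(0);                 // GEMM on "even"

        barrier();
        load_tile(k + 2 * KPerBlock, 0); // prefetch the tile after into "even"
        compute_tile(1);                 // GEMM on "odd"

        k += 2 * KPerBlock;
    } while(k < K - 2 * KPerBlock);
    // tail: one or two remaining tiles, handled by the HasDoubleTailKBlockLoop paths
}

The RunRead/RunWrite split in the copy objects is what makes the overlap possible: RunRead issues the global loads early (into registers), and RunWrite commits them to LDS only once the target buffer is free.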
@@ -84,6 +84,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
 
+        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_a_global, a_e_k_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
+
        constexpr auto E = EPerBlock * 3 * 3;
 
        // const auto E = a_e_k_global_desc.GetLength(I0);
@@ -192,8 +199,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                               1,
                                               ABlockTransferSrcScalarPerVector,
                                               ABlockTransferDstScalarPerVector_K,
-                                               AddressSpace::Global,
-                                               AddressSpace::Lds,
                                               1,
                                               1,
                                               AThreadTransferSrcResetCoordinateAfterRun,
@@ -216,19 +221,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                               BBlockTransferSrcAccessOrder,
                                               BBlockTransferSrcVectorDim,
                                               BBlockTransferSrcScalarPerVector,
-                                               AddressSpace::Global,
-                                               AddressSpace::Vgpr,
                                               InMemoryDataOperation::Set,
                                               1,
                                               true>(b_e_n_ho_wo_global_desc,
                                                     make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
 
-        FloatAB* p_a_block = p_shared_block;
-
-        auto a_block_buf = make_dynamic_buffer(p_a_block);
+        auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(p_shared_block,
+                                                                  a_e_k_desc.GetElementSpaceSize());
 
        // register allocation for output
-        StaticBuffer<FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_thread_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
+            c_thread_buf;
 
        // initialize output thread tensor
        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
@@ -250,21 +253,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
            BGlobalMoveSliceWindowIteratorHacks{};
 
        // double register buffer for b
-        StaticBuffer<FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> b_thread_even_buf,
-            b_thread_odd_buf;
+        StaticBuffer<AddressSpace::Vgpr, FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
+            b_thread_even_buf, b_thread_odd_buf;
 
        // LDS double buffer: preload data
        {
-            a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks);
 
            b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                      p_b_global,
+                                      b_global_buf,
                                      b_e_n_ho_wo_thread_desc,
                                      make_tuple(I0, I0, I0, I0),
                                      b_thread_even_buf,
                                      b_e_n_ho_wo_global_iterator_hacks);
 
-            a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block);
+            a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
        }
 
        __syncthreads();
@@ -282,7 +285,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                          b_thread_slice_copy_step);
 
                b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                          p_b_global,
+                                          b_global_buf,
                                          b_e_n_ho_wo_thread_desc,
                                          make_tuple(I0, I0, I0, I0),
                                          b_thread_odd_buf,
@@ -298,7 +301,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                          b_thread_slice_copy_step);
 
                b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                          p_b_global,
+                                          b_global_buf,
                                          b_e_n_ho_wo_thread_desc,
                                          make_tuple(I0, I0, I0, I0),
                                          b_thread_even_buf,
@@ -321,7 +324,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                          b_thread_slice_copy_step);
 
                b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
-                                          p_b_global,
+                                          b_global_buf,
                                          b_e_n_ho_wo_thread_desc,
                                          make_tuple(I0, I0, I0, I0),
                                          b_thread_odd_buf,
@@ -358,8 +361,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                               CThreadTransferSrcDstAccessOrder,
                                               CThreadTransferSrcDstVectorDim,
                                               CThreadTransferDstScalarPerVector,
-                                               AddressSpace::Vgpr,
-                                               AddressSpace::Global,
                                               CGlobalMemoryDataOperation,
                                               1,
                                               true>(
@@ -370,7 +371,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                make_tuple(I0, I0, I0, I0),
                c_thread_buf,
                c_k_n_ho_wo_global_desc,
-                p_c_global,
+                c_global_buf,
                c_k_n_ho_wo_global_tensor_iterator_hacks);
        }
    }
@@ -54,8 +54,6 @@ template <typename SrcData,
           typename DimAccessOrder,
           index_t DstVectorDim,
           index_t DstScalarPerVector,
-          AddressSpace SrcAddressSpace,
-          AddressSpace DstAddressSpace,
           InMemoryDataOperation DstInMemOp,
           index_t DstScalarStrideInVector,
           bool DstResetCoordinateAfterRun,
@@ -72,7 +70,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
 
    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3(
        const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
-        : dst_slice_origin_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx))
+        : dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx))
    {
        static_assert(SrcDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc need to known at compile-time");
@@ -80,15 +78,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
 
    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {
-        dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
+        dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
    }
 
-    template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstIteratorHacks>
+    template <typename SrcSliceOriginIdx,
+              typename SrcBuffer,
+              typename DstBuffer,
+              typename DstIteratorHacks>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
-                        DstData* p_dst,
+                        DstBuffer& dst_buf,
                        const DstIteratorHacks& dst_iterator_hacks)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime(),
@@ -191,12 +192,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
            return dst_data_idx;
        }();
 
-        // copy data
        typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
 
        using dst_vector_t =
            typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
 
+        // copy data from src_buf into dst_vector
        static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
            constexpr index_t src_offset = src_desc.CalculateOffset(
                src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
@@ -205,37 +206,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
                type_convert<DstData>{}(src_buf[Number<src_offset>{}]);
        });
 
-        const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
-            dst_desc, dst_slice_origin_coord_);
+        const bool is_dst_valid =
+            coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
 
-        if constexpr(SrcAddressSpace == AddressSpace::Vgpr &&
-                     DstAddressSpace == AddressSpace::Global)
-        {
-#if CK_USE_AMD_BUFFER_ADDRESSING
-            amd_buffer_store_v2<DstData, DstScalarPerVector>(
-                dst_vector.template AsType<dst_vector_t>()(Number<0>{}),
-                p_dst,
-                dst_slice_origin_coord_.GetOffset(),
-                is_dst_valid,
-                dst_desc.GetElementSpaceSize());
-#else
-            if(is_dst_valid)
-            {
-                *reinterpret_cast<dst_vector_t*>(
-                    &(p_dst[dst_slice_origin_coord_.GetOffset()])) =
-                    dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
-            }
-#endif
-        }
-        else
-        {
-            if(is_dst_valid)
-            {
-                *reinterpret_cast<dst_vector_t*>(
-                    &(p_dst[dst_slice_origin_coord_.GetOffset()])) =
-                    dst_vector.template AsType<dst_vector_t>()[Number<0>{}];
-            }
-        }
+        // copy data from dst_vector into dst_buf
+        dst_buf.template Set<dst_vector_t>(
+            dst_coord_.GetOffset(),
+            is_dst_valid,
+            dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
 
        constexpr auto move_on_dim = [&]() constexpr
        {
@@ -259,15 +237,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
                {
                    if constexpr(forward_sweep[i])
                    {
-                        move_dynamic_tensor_coordinate(dst_desc,
-                                                       dst_slice_origin_coord_,
-                                                       dst_forward_iterators[dim_access_order[i]]);
+                        move_dynamic_tensor_coordinate(
+                            dst_desc, dst_coord_, dst_forward_iterators[dim_access_order[i]]);
                    }
                    else
                    {
-                        move_dynamic_tensor_coordinate(dst_desc,
-                                                       dst_slice_origin_coord_,
-                                                       dst_backward_iterators[dim_access_order[i]]);
+                        move_dynamic_tensor_coordinate(
+                            dst_desc, dst_coord_, dst_backward_iterators[dim_access_order[i]]);
                    }
                }
            });
@@ -279,11 +255,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
            const auto dst_reset_iterator =
                make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
 
-            move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, dst_reset_iterator);
+            move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
        }
    }
 
-    __device__ void Run(const SrcData* p_src, const DstDesc& dst_desc, DstData* p_dst)
+    template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
+    __device__ void Run(const SrcDesc&,
+                        const SrcSliceOriginIdx&,
+                        const SrcBuffer& src_buf,
+                        const DstDesc& dst_desc,
+                        DstBuffer& dst_buf)
    {
        constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
@@ -293,7 +274,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
        make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                   generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
 
-        Run(p_src, dst_desc, p_dst, dst_iterator_hacks);
+        Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_iterator_hacks);
    }
 
    __device__ static constexpr auto GetDstCoordinateResetStep()
@@ -371,18 +352,22 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
        const auto adjusted_step =
            make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
 
-        move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, adjusted_step);
+        move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }
 
    private:
-    DstCoord dst_slice_origin_coord_;
+    DstCoord dst_coord_;
 }; // namespace ck
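The deleted #if CK_USE_AMD_BUFFER_ADDRESSING branches did not disappear; they moved behind the buffer interface. Conceptually, the global-memory flavor of Get hides whether a load is issued through the AMD buffer-addressing intrinsic or as a guarded dereference — a hedged sketch of that member, using the intrinsic name and signature as they appear in the removed code:

template <typename T, typename X>
__device__ X get_global(const T* p, long offset, bool is_valid, long element_space_size)
{
#if CK_USE_AMD_BUFFER_ADDRESSING
    // buffer_load through a descriptor covering element_space_size elements;
    // invalid / out-of-range lanes return 0 in hardware
    return amd_buffer_load_v2<T, sizeof(X) / sizeof(T)>(p, offset, is_valid, element_space_size);
#else
    (void)element_space_size;
    return is_valid ? *reinterpret_cast<const X*>(&p[offset]) : X{0};
#endif
}

The same collapse happens on the store side (amd_buffer_store_v2 vs. a guarded reinterpret_cast write), which is why each transfer class here shrinks by an entire if-constexpr/else ladder.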
// Assume:
|
||||
// 1. src_desc is not known at compile-time
|
||||
// 2. dst_desc is known at compile-time
|
||||
// 3. src_slice_origin_idx is not known at compile-time
|
||||
// 4. dst_slice_origin_idx is known at compile-time and it's 0
|
||||
// 1. src:
|
||||
// 1. SrcDesc is not known at compile-time
|
||||
// 2. SrcBuffer is DynamicBuffer
|
||||
// 3. src_slice_origin_idx is not known at compile-time
|
||||
// 2. dst:
|
||||
// 1. DstDesc is known at compile-time
|
||||
// 2. DstBuffer is StaticBuffer
|
||||
// 3. dst_slice_origin_idx is known at compile-time
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
@@ -391,8 +376,6 @@ template <typename SrcData,
|
||||
typename DimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace,
|
||||
index_t SrcScalarStrideInVector,
|
||||
bool SrcResetCoordinateAfterRun,
|
||||
typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
@@ -408,7 +391,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
|
||||
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_idx)
|
||||
: src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx))
|
||||
: src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx))
|
||||
{
|
||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
@@ -416,12 +399,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
|
||||
__device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename DstBuffer, typename DstSliceOriginIdx, typename SrcIteratorHacks>
|
||||
template <typename SrcBuffer,
|
||||
typename DstBuffer,
|
||||
typename DstSliceOriginIdx,
|
||||
typename SrcIteratorHacks>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcData* p_src,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc&,
|
||||
const DstSliceOriginIdx&,
|
||||
DstBuffer& dst_buf,
|
||||
@@ -525,41 +511,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
return src_data_idx;
|
||||
}();
|
||||
|
||||
// copy data
|
||||
static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for vgpr dst");
|
||||
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
|
||||
|
||||
using src_vector_t =
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
|
||||
|
||||
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
|
||||
src_desc, src_slice_origin_coord_);
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
|
||||
|
||||
if constexpr(SrcAddressSpace == AddressSpace::Global)
|
||||
{
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
|
||||
p_src,
|
||||
src_slice_origin_coord_.GetOffset(),
|
||||
is_src_valid,
|
||||
src_desc.GetElementSpaceSize());
|
||||
#else
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_slice_origin_coord_.GetOffset()])
|
||||
: src_vector_t{0};
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_slice_origin_coord_.GetOffset()])
|
||||
: src_vector_t{0};
|
||||
}
|
||||
// copy data from src_buf into src_vector
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
|
||||
|
||||
// copy data from src_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset =
|
||||
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
|
||||
@@ -590,15 +554,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_dynamic_tensor_coordinate(src_desc,
|
||||
src_slice_origin_coord_,
|
||||
src_forward_iterators[dim_access_order[i]]);
|
||||
move_dynamic_tensor_coordinate(
|
||||
src_desc, src_coord_, src_forward_iterators[dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_dynamic_tensor_coordinate(src_desc,
|
||||
src_slice_origin_coord_,
|
||||
src_backward_iterators[dim_access_order[i]]);
|
||||
move_dynamic_tensor_coordinate(
|
||||
src_desc, src_coord_, src_backward_iterators[dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -610,13 +572,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
const auto src_reset_iterator =
|
||||
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
|
||||
|
||||
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, src_reset_iterator);
|
||||
move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstBuffer, typename DstSliceOriginIdx>
|
||||
template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcData* p_src,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc&,
|
||||
const DstSliceOriginIdx&,
|
||||
DstBuffer& dst_buf)
|
||||
@@ -629,7 +591,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||
|
||||
Run(src_desc, p_src, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks);
|
||||
Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
@@ -707,17 +669,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
const auto adjusted_step =
|
||||
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
|
||||
|
||||
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
|
||||
move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
private:
|
||||
SrcCoord src_slice_origin_coord_;
|
||||
SrcCoord src_coord_;
|
||||
}; // namespace ck
|
||||
|
||||
// Assume:
|
||||
// 1. src_desc and dst_desc are not known at compile-time
|
||||
// 2. src_slice_origin and dst_slice_origin are not known at compile-time,
|
||||
// 3. Use thread buffer
|
||||
// 2. SrcBuffer and DstBuffer are DynamicBuffer
|
||||
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
|
||||
// 4. Use thread buffer
|
||||
template <typename SliceLengths,
|
||||
InMemoryDataOperation DstInMemOp,
|
||||
typename SrcData,
|
||||
@@ -732,8 +695,6 @@ template <typename SliceLengths,
|
||||
index_t DstScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
index_t DstScalarStrideInVector,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace,
|
||||
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
|
||||
// RunRead(), will be fused with MoveSrcSliceWindow to
|
||||
// save addr computation
|
||||
@@ -755,16 +716,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
const Index& src_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin)
|
||||
: src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
|
||||
dst_slice_origin_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
|
||||
: src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
|
||||
dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
|
||||
{
|
||||
static_assert(SrcAddressSpace == AddressSpace::Global or
|
||||
SrcAddressSpace == AddressSpace::Lds,
|
||||
"wrong!");
|
||||
static_assert(DstAddressSpace == AddressSpace::Global or
|
||||
DstAddressSpace == AddressSpace::Lds,
|
||||
"wrong!");
|
||||
|
||||
// TODO: fix this
|
||||
static_assert(is_same<SrcData, DstData>::value,
|
||||
"wrong! current implementation assume SrcData and DstData are same type");
|
||||
@@ -772,19 +726,27 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcIteratorHacks>
|
||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
||||
__device__ void RunRead(const SrcDesc& src_desc,
|
||||
const SrcData* p_src,
|
||||
const SrcBuffer& src_buf,
|
||||
const SrcIteratorHacks& src_iterator_hacks)
|
||||
{
|
||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global or
|
||||
SrcBuffer::GetAddressSpace() == AddressSpace::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<SrcData>>>::value,
|
||||
"wrong! SrcBuffer and SrcData data type are inconsistent");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
@@ -869,37 +831,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
return src_data_idx;
|
||||
}();
|
||||
|
||||
// copy data from src_buf to src_tmp_vector
|
||||
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
|
||||
|
||||
using src_vector_t = typename decltype(src_tmp_vector)::type;
|
||||
|
||||
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
|
||||
src_desc, src_slice_origin_coord_);
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
|
||||
|
||||
if constexpr(SrcAddressSpace == AddressSpace::Global)
|
||||
{
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
|
||||
p_src,
|
||||
src_slice_origin_coord_.GetOffset(),
|
||||
is_src_valid,
|
||||
src_desc.GetElementSpaceSize());
|
||||
#else
|
||||
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_slice_origin_coord_.GetOffset()])
|
||||
: src_vector_t{0};
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_slice_origin_coord_.GetOffset()])
|
||||
: src_vector_t{0};
|
||||
}
|
||||
// copy data from src_buf to src_tmp_vector
|
||||
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
|
||||
|
||||
// copy data from src_tmp_vector to buffer_
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
@@ -933,16 +874,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_dynamic_tensor_coordinate(
|
||||
src_desc,
|
||||
src_slice_origin_coord_,
|
||||
src_forward_iterators[src_dim_access_order[i]]);
|
||||
src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_dynamic_tensor_coordinate(
|
||||
src_desc,
|
||||
src_slice_origin_coord_,
|
||||
src_backward_iterators[src_dim_access_order[i]]);
|
||||
src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -954,14 +891,23 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
const auto src_reset_iterator =
|
||||
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
|
||||
|
||||
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, src_reset_iterator);
|
||||
move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstIteratorHacks>
|
||||
__device__ void
|
||||
RunWrite(const DstDesc& dst_desc, DstData* p_dst, const DstIteratorHacks& dst_iterator_hacks)
|
||||
template <typename DstBuffer, typename DstIteratorHacks>
|
||||
__device__ void RunWrite(const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf,
|
||||
const DstIteratorHacks& dst_iterator_hacks)
|
||||
{
|
||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Global or
|
||||
DstBuffer::GetAddressSpace() == AddressSpace::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<DstData>>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
@@ -1050,13 +996,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
return dst_data_idx;
|
||||
}();
|
||||
|
||||
// copy data
|
||||
// hardcoding for ds_write
|
||||
// TODO refactor transfer_data() to encapsulate this
|
||||
static_assert(DstAddressSpace == AddressSpace::Lds &&
|
||||
DstInMemOp == InMemoryDataOperation::Set,
|
||||
"wrong! hardcoded for ds_write");
|
||||
|
||||
vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;
|
||||
|
||||
// copy data from buffer_ to dst_tmp_vector
|
||||
@@ -1070,8 +1009,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
using dst_vector_t = typename decltype(dst_tmp_vector)::type;
|
||||
|
||||
// copy data from dst_tmp_vector to dst_buf
|
||||
*reinterpret_cast<dst_vector_t*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
|
||||
dst_tmp_vector.template AsType<dst_vector_t>()[Number<0>{}];
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_tmp_vector.template AsType<dst_vector_t>()[Number<0>{}]);
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
@@ -1097,16 +1041,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                if constexpr(forward_sweep[i])
                {
                    move_dynamic_tensor_coordinate(
                        dst_desc,
                        dst_slice_origin_coord_,
                        dst_forward_iterators[dst_dim_access_order[i]]);
                        dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
                }
                else
                {
                    move_dynamic_tensor_coordinate(
                        dst_desc,
                        dst_slice_origin_coord_,
                        dst_backward_iterators[dst_dim_access_order[i]]);
                        dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
                }
            }
        });
@@ -1118,11 +1058,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
            const auto dst_reset_iterator =
                make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());

            move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, dst_reset_iterator);
            move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
        }
    }

    __device__ void RunRead(const SrcDesc& src_desc, const SrcData* p_src)
    template <typename SrcBuffer>
    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
    {
        constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();

@@ -1132,10 +1073,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

        RunRead(src_desc, p_src, src_iterator_hacks);
        RunRead(src_desc, src_buf, src_iterator_hacks);
    }

    __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst)
    template <typename DstBuffer>
    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();

@@ -1145,7 +1087,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

        RunWrite(dst_desc, p_dst, dst_iterator_hacks);
        RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
    }

    __device__ static constexpr auto GetSrcCoordinateResetStep()
@@ -1285,7 +1227,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        const auto adjusted_step =
            make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);

        move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
        move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
@@ -1304,7 +1246,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        const auto adjusted_step = make_dynamic_tensor_coordinate_iterator(
            src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);

        move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
        move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }
    // dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
@@ -1319,7 +1261,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        const auto adjusted_step =
            make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);

        move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, adjusted_step);
        move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }

    private:
@@ -1328,10 +1270,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3

    static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

    StaticBuffer<SrcData, buffer_size_> buffer_;
    StaticBuffer<AddressSpace::Vgpr, SrcData, buffer_size_> buffer_;

    SrcCoord src_slice_origin_coord_;
    DstCoord dst_slice_origin_coord_;
    SrcCoord src_coord_;
    DstCoord dst_coord_;
};
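With RunRead and RunWrite now templated on SrcBuffer/DstBuffer, call sites wrap their raw pointers once and the transfer never dereferences a pointer directly. A minimal sketch of the new calling convention (the descriptor, pointer, and transfer-object names here are illustrative, not from this commit):

    // wrap raw pointers in address-space-tagged buffers, then run the transfer
    const auto src_buf = make_dynamic_buffer<AddressSpace::Global>(
        p_src_global, src_desc.GetElementSpaceSize());
    auto dst_buf = make_dynamic_buffer<AddressSpace::Lds>(
        p_dst_lds, dst_desc.GetElementSpaceSize());

    threadwise_transfer.RunRead(src_desc, src_buf);  // global memory -> VGPR buffer_
    threadwise_transfer.RunWrite(dst_desc, dst_buf); // VGPR buffer_ -> LDS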
// Assume:
@@ -1356,8 +1298,6 @@ template <
          typename DimAccessOrder,
          index_t SrcVectorDim,
          index_t SrcScalarPerVector,
          AddressSpace SrcAddressSpace,
          AddressSpace DstAddressSpace,
          index_t SrcScalarStrideInVector,
          typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
@@ -1480,7 +1420,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
            move_dynamic_tensor_coordinate(
                src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);

            // copy data from src_buf into src_tmp_buffer
            vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;

            using src_vector_t = typename decltype(src_tmp_vector)::type;
@@ -1488,9 +1427,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
            const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                src_desc, src_data_coord);

            // copy data from src_buf into src_tmp_vector
            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                is_src_valid ? src_buf.template Get<src_vector_t>(src_data_coord.GetOffset())
                             : src_vector_t{0};
                src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);

            // copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
            // DstData)
@@ -323,7 +323,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
    }
    else if constexpr(N == 2)
    {
#if !CK_WORKAROUND_SWDEV_XXXXXX
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
        return __llvm_amdgcn_raw_buffer_load_i8x2(
            src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else
@@ -335,7 +335,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
    }
    else if constexpr(N == 4)
    {
#if !CK_WORKAROUND_SWDEV_XXXXXX
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
        return __llvm_amdgcn_raw_buffer_load_i8x4(
            src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else
@@ -347,7 +347,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
    }
    else if constexpr(N == 8)
    {
#if !CK_WORKAROUND_SWDEV_XXXXXX
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
        vector_type<int8_t, 8> tmp;

        tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
@@ -369,7 +369,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
    }
    else if constexpr(N == 16)
    {
#if !CK_WORKAROUND_SWDEV_XXXXXX
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
        vector_type<int8_t, 16> tmp;

        tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
@@ -483,7 +483,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
    }
    else if constexpr(N == 2)
    {
#if !CK_WORKAROUND_SWDEV_XXXXXX
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
        __llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
                                            dst_wave_buffer_resource,
                                            dst_thread_addr_offset,
@@ -499,7 +499,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
    }
    else if constexpr(N == 4)
    {
#if !CK_WORKAROUND_SWDEV_XXXXXX
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
        __llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
                                            dst_wave_buffer_resource,
                                            dst_thread_addr_offset,
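The rename gives the int8 buffer load/store workaround its own macro, distinct from the new ds_write workaround introduced later in this commit. Both are #ifndef-guarded and default to 1 (see the config hunk below), so if a compiler fix lands, the fast path can presumably be re-enabled per build without editing the header, e.g.:

    // hypothetical build invocation; -D overrides the #ifndef-guarded default
    hipcc -DCK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE=0 ...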
@@ -1,72 +0,0 @@
#ifndef CK_BUFFER_HPP
#define CK_BUFFER_HPP

#include "statically_indexed_array.hpp"

namespace ck {

template <typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    __host__ __device__ constexpr StaticBuffer() : base{} {}

    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};

template <typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
    return StaticBuffer<T, N>{};
}

template <typename T>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;

    __host__ __device__ constexpr DynamicBuffer(T* p_data) : p_data_{p_data} {}

    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ constexpr const auto Get(index_t i) const
    {
        return *reinterpret_cast<const X*>(&p_data_[i]);
    }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ void Set(index_t i, const X& x)
    {
        *reinterpret_cast<X*>(&p_data_[i]) = x;
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <typename T>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p)
{
    return DynamicBuffer<T>{p};
}

} // namespace ck
#endif
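This removes the old combined buffer header: StaticBuffer and DynamicBuffer move into separate files (static_buffer.hpp and dynamic_buffer.hpp, added below), and both gain an AddressSpace template parameter; DynamicBuffer additionally carries its element space size so global accesses can be routed through the buffer-addressing intrinsics. Roughly, construction changes like this (a sketch based on the two factory functions, with p_data and element_space_size as placeholder names):

    // before: a bare pointer wrapper, address space tracked elsewhere
    auto buf_old = make_dynamic_buffer(p_data);

    // after: address space and extent are part of the buffer
    auto buf_new = make_dynamic_buffer<AddressSpace::Global>(p_data, element_space_size);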
@@ -8,7 +8,6 @@
#include "container_element_picker.hpp"
#include "data_type.hpp"
#include "float_type.hpp"
#include "buffer.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional3.hpp"
@@ -25,6 +24,8 @@
#include "type.hpp"
#include "utility.hpp"
#include "magic_division.hpp"
#include "static_buffer.hpp"
#include "dynamic_buffer.hpp"

#if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp"
@@ -143,8 +143,13 @@
#endif

// workaround for compiler crash when using buffer load/store for i8
#ifndef CK_WORKAROUND_SWDEV_XXXXXX
#define CK_WORKAROUND_SWDEV_XXXXXX 1
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
#endif

// workaround for compiler issue (inefficient ISA) of ds_write for int8
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif

namespace ck {
@@ -154,6 +159,7 @@ enum AddressSpace
    Generic,
    Global,
    Lds,
    Sgpr,
    Vgpr
};
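Vgpr is the new enumerator here, so register-resident storage can be tagged with an address space just like global and LDS memory; the v3 transfer above uses it for its intermediate buffer:

    // from ThreadwiseDynamicTensorSliceTransfer_v3 in this commit
    StaticBuffer<AddressSpace::Vgpr, SrcData, buffer_size_> buffer_;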
composable_kernel/include/utility/dynamic_buffer.hpp (new file, 173 lines)
@@ -0,0 +1,173 @@
#ifndef CK_DYNAMIC_BUFFER_HPP
#define CK_DYNAMIC_BUFFER_HPP

namespace ck {

#include "amd_buffer_addressing_v2.hpp"

template <AddressSpace BufferAddressSpace, typename T, typename ElementSpaceSize>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;
    ElementSpaceSize element_space_size_;

    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data}, element_space_size_{element_space_size}
    {
    }

    __host__ __device__ static constexpr AddressSpace GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ constexpr const auto Get(index_t i, bool is_valid_offset) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector =
            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;

        constexpr index_t scalar_per_x_vector =
            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        if constexpr(GetAddressSpace() == AddressSpace::Global)
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
            return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
                p_data_, i, is_valid_offset, element_space_size_);
#else
            return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
#endif
        }
        else
        {
            return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
        }
    }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector =
            scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;

        constexpr index_t scalar_per_x_vector =
            scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        if constexpr(GetAddressSpace() == AddressSpace::Global)
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
            amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
                x, p_data_, i, is_valid_offset, element_space_size_);
#else
            if(is_valid_offset)
            {
                *reinterpret_cast<X*>(&p_data_[i]) = x;
            }
#endif
        }
        else if constexpr(GetAddressSpace() == AddressSpace::Lds)
        {
            if(is_valid_offset)
            {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
                *reinterpret_cast<X*>(&p_data_[i]) = x;
#else
                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
                // ISA, so I try to let the compiler emit IR "store<i32, 4>", which would be
                // lowered to ds_write_b128
                // TODO: remove this after compiler fix
                if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
                                     int8_t>::value)
                {
                    static_assert(
                        (is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
                         is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
                        (is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
                         is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
                        (is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
                         is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
                        "wrong! not implemented for this combination, please add implementation");

                    if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
                                 is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
                    {
                        // HACK: casting the pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int32_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int32_t*>(&x);
                    }
                    if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
                                 is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
                    {
                        // HACK: casting the pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int32x2_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int32x2_t*>(&x);
                    }
                    if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
                                 is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
                    {
                        // HACK: casting the pointer of x is bad
                        // TODO: remove this after compiler fix
                        *reinterpret_cast<int32x4_t*>(&p_data_[i]) =
                            *reinterpret_cast<const int32x4_t*>(&x);
                    }
                }
                else
                {
                    *reinterpret_cast<X*>(&p_data_[i]) = x;
                }
#endif
            }
        }
        else
        {
            if(is_valid_offset)
            {
                *reinterpret_cast<X*>(&p_data_[i]) = x;
            }
        }
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <AddressSpace BufferAddressSpace = AddressSpace::Generic,
          typename T,
          typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
}

} // namespace ck
#endif
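Compared with the old buffer.hpp, Get and Set now take the validity flag themselves, so the out-of-bounds predication that used to sit at every call site moves inside the buffer (and, for global memory, down into amd_buffer_load_v2/amd_buffer_store_v2). A minimal usage sketch, assuming int32_t data and placeholder names for the pointer, extent, offset, and flag:

    // predicated vector accesses through the new interface (sketch)
    auto buf = make_dynamic_buffer<AddressSpace::Global>(p_data, element_space_size);

    // returns the loaded int32x4_t, or a zero-filled one when the offset is invalid
    auto v = buf.template Get<int32x4_t>(offset, is_valid_offset);

    // writes only when the offset is valid
    buf.template Set<int32x4_t>(offset, is_valid_offset, v);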
composable_kernel/include/utility/static_buffer.hpp (new file, 33 lines)
@@ -0,0 +1,33 @@
#ifndef CK_STATIC_BUFFER_HPP
#define CK_STATIC_BUFFER_HPP

#include "statically_indexed_array.hpp"

namespace ck {

template <AddressSpace BufferAddressSpace, typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    __host__ __device__ constexpr StaticBuffer() : base{} {}

    __host__ __device__ static constexpr AddressSpace GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};

template <AddressSpace BufferAddressSpace = AddressSpace::Generic, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
    return StaticBuffer<BufferAddressSpace, T, N>{};
}

} // namespace ck
#endif
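make_static_buffer keeps its Number<N> tag argument; only the address-space parameter is new, and it defaults to Generic. A small sketch (element type and size are illustrative):

    // a 16-element register-resident thread buffer; T cannot be deduced from
    // Number<16>{}, so address space and element type are spelled out
    auto thread_buf = make_static_buffer<AddressSpace::Vgpr, float>(Number<16>{});

    static_assert(decltype(thread_buf)::IsStaticBuffer(), "");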