Grouped conv bwd data skip B LDS

This commit is contained in:
Bartlomiej Kocot
2025-09-01 14:21:30 +00:00
parent 508e7912f9
commit dd310e435a
2 changed files with 909 additions and 54 deletions

View File

@@ -18,6 +18,7 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_multiple_d_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -77,7 +78,8 @@ template <typename GridwiseGemm,
InMemoryDataOperationEnum OutElementOp,
bool HasMainKBlockLoopInAllGemm,
bool NoMainKBlockLoopInAllGemm,
bool CTranspose>
bool CTranspose,
bool SkipBLds>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -101,7 +103,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
const long_index_t a_batch_offset =
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))
@@ -149,30 +150,70 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
group_id = index_t((left + right) / 2);
}
if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
// If constexpr to be compatible with skip LDS gridwise gemm
if constexpr(SkipBLds)
{
GridwiseGemm::template Run<HasMainKBlockLoopInAllGemm, OutElementOp>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset + b_n_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_,
KBatch,
k_idx);
if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
{
GridwiseGemm::template Run<HasMainKBlockLoopInAllGemm>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset + b_n_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_);
}
else
{
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
{
GridwiseGemm::template Run<true>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset + b_n_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_);
}
else
{
GridwiseGemm::template Run<false>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset + b_n_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_);
}
}
}
else
{
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm)
{
GridwiseGemm::template Run<true, OutElementOp>(
GridwiseGemm::template Run<HasMainKBlockLoopInAllGemm, OutElementOp>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset + b_n_offset,
p_ds_grid_grp,
@@ -191,22 +232,44 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
}
else
{
GridwiseGemm::template Run<false, OutElementOp>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset + b_n_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_,
KBatch,
k_idx);
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
{
GridwiseGemm::template Run<true, OutElementOp>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset + b_n_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_,
KBatch,
k_idx);
}
else
{
GridwiseGemm::template Run<false, OutElementOp>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset + b_n_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_,
KBatch,
k_idx);
}
}
}
#else
@@ -284,7 +347,8 @@ template <index_t NDimSpatial,
typename AComputeType = ADataType,
typename BComputeType = AComputeType,
index_t MaxTransposeTransferInScalarPerVector = 1,
index_t MaxTransposeTransferOutScalarPerVector = 1>
index_t MaxTransposeTransferOutScalarPerVector = 1,
bool SkipBLds = false>
struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
: public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
ALayout, // output image
@@ -321,7 +385,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
static constexpr GemmSpecialization GemmSpec = GemmSpecialization::MNKPadding;
static constexpr bool IsSplitKSupported =
(CDEBlockTransferScalarPerVector_NPerBlock % 2 == 0 || sizeof(EDataType) % 4 == 0) &&
std::is_same_v<remove_cvref_t<CDEElementwiseOp>, element_wise::PassThrough>;
std::is_same_v<remove_cvref_t<CDEElementwiseOp>, element_wise::PassThrough> && !SkipBLds;
// TODO: Add support for different A and B data types.
using ABDataType = ADataType;
@@ -342,9 +406,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
(isATensorColMajor == false) && (is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_NGKDHW<ELayout, BLayout, ALayout>());
static constexpr bool CTranspose =
(NeedTransposeKernel == false) && (is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
is_same_v<ELayout, tensor_layout::convolution::NGCDHW>);
static constexpr bool CTranspose = (NeedTransposeKernel == false) &&
(is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
is_same_v<ELayout, tensor_layout::convolution::NGCDHW>) &&
!SkipBLds;
using ALayoutAfterTranspose = std::conditional_t<
is_NGCHW_NGKHW<ELayout, BLayout, ALayout>() && NeedTransposeKernel,
@@ -463,7 +528,26 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>;
static constexpr index_t BBlockBufferSize = 1;
#define GridwiseGemmMultiDSkipBLdsTemplateParams \
BlockSize, ABDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
InMemoryDataOperationEnum::Set, element_wise::PassThrough, element_wise::PassThrough, \
element_wise::PassThrough, MPerBlock, NPerBlock, KPerBlock / AK1, MPerXDL, NPerXDL, AK1, \
MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \
ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \
ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \
ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \
BBlockTransferSrcScalarPerVector, false, BBlockBufferSize, CShuffleMXdlPerWavePerShuffle, \
CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock
using GridwiseGemm =
std::conditional_t<SkipBLds,
GridwiseGemm_xdlops_skip_b_lds_multiple_d_cshuffle<
GridwiseGemmMultiDSkipBLdsTemplateParams>,
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>>;
using GridwiseGemmCTranspose = std::conditional_t<
CTranspose,
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmCTransposeTemplateParameters>,
@@ -1199,7 +1283,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
ElementOp,
has_main_loop,
no_main_loop,
CTranspose>;
CTranspose,
SkipBLds>;
return launch_and_time_kernel_with_preprocess(
stream_config,
@@ -1238,7 +1323,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
ElementOp,
has_main_loop,
no_main_loop,
CTranspose>;
CTranspose,
SkipBLds>;
return launch_and_time_kernel_with_preprocess(
stream_config,
@@ -1600,16 +1686,35 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// Gridwise GEMM size
for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++)
{
if(!GridwiseGemmCTranspose::CheckValidity(
arg.a_grid_desc_m_k_container_[i],
arg.b_grid_desc_n_k_container_[i],
arg.ds_grid_desc_m_n_container_[i],
arg.e_grid_desc_m_n_container_[i],
arg.gemm_kernel_args_[i / MaxGroupedGemmGroupsNum][i % MaxGroupedGemmGroupsNum]
.block_2_ctile_map_,
arg.k_batch_))
if constexpr(SkipBLds)
{
return false;
if(!GridwiseGemmCTranspose::CheckValidity(
arg.gemm_kernel_args_[i / MaxGroupedGemmGroupsNum]
[i % MaxGroupedGemmGroupsNum]
.a_grid_desc_ak0_m_ak1_,
arg.gemm_kernel_args_[i / MaxGroupedGemmGroupsNum]
[i % MaxGroupedGemmGroupsNum]
.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_m_n_container_[i],
arg.e_grid_desc_m_n_container_[i]))
{
return false;
}
}
else
{
if(!GridwiseGemmCTranspose::CheckValidity(
arg.a_grid_desc_m_k_container_[i],
arg.b_grid_desc_n_k_container_[i],
arg.ds_grid_desc_m_n_container_[i],
arg.e_grid_desc_m_n_container_[i],
arg.gemm_kernel_args_[i / MaxGroupedGemmGroupsNum]
[i % MaxGroupedGemmGroupsNum]
.block_2_ctile_map_,
arg.k_batch_))
{
return false;
}
}
}

View File

@@ -0,0 +1,750 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
namespace ck {
template <index_t BlockSize,
typename ABDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t MPerXdl,
index_t NPerXdl,
index_t K1Value,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
bool ABlockLdsExtraM,
index_t BBlockTransferSrcScalarPerVector,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockBufferSize,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
bool DoElementwiseBeforeCShuffle = false>
struct GridwiseGemm_xdlops_skip_b_lds_multiple_d_cshuffle
{
// Compile-time index constants used to address descriptor dimensions and tuple elements.
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
// K1 is the innermost K length; it wraps the K1Value template parameter as a compile-time value.
static constexpr auto K1 = Number<K1Value>{};
// Number of lanes assumed per wave when decomposing the thread id (see GetWaveIdx).
// NOTE(review): hard-coded to 64 -- confirm this kernel is never built for 32-lane wavefronts.
static constexpr index_t WaveSize = 64;
// Number of waves along M / N inside one block tile.
static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
// XDL instruction selector for the A/B data type and per-XDL tile shape.
static constexpr auto xdlops_gemm = XdlopsGemm<ABDataType, MPerXdl, NPerXdl, K1>{};
// K0 slices each thread owns per block tile: K0PerBlock divided by xdlops_gemm.K0PerXdlops.
static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;
// Thread-group type handed to the block-wide copy helpers in Run.
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<DsDataType::Size()>{});
}
using DsGridPointer = decltype(MakeDsGridPointer());
// Descriptor of the A tile as stored in LDS (destination of the A blockwise copy).
// Lengths are [K0PerBlock * BBlockBufferSize, MPerBlock, K1].
// With ABlockLdsExtraM the K0 stride is (MPerBlock + 1) * K1, i.e. one extra K1-wide
// row per K0 slice; otherwise the descriptor is built aligned to K1.
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
    constexpr auto a_tile_lengths =
        make_tuple(Number<K0PerBlock * BBlockBufferSize>{}, Number<MPerBlock>{}, K1);

    if constexpr(ABlockLdsExtraM)
    {
        constexpr auto a_tile_strides = make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1);

        return make_naive_tensor_descriptor(a_tile_lengths, a_tile_strides);
    }
    else
    {
        constexpr auto lds_alignment = K1;

        return make_naive_tensor_descriptor_aligned(a_tile_lengths, lds_alignment);
    }
}
// Returns true when the given problem can be handled by this gridwise GEMM.
//
// a_grid_desc_k0_m_k1 : A descriptor, lengths [K0, M, K1]
// b_grid_desc_k0_n_k1 : B descriptor, lengths [K0, N, K1]
// ds_grid_desc_m_n    : tuple of D descriptors, each expected to be [M, N]
// e_grid_desc_m_n     : E descriptor, expected to be [M, N]
//
// Checks performed:
//   * A, B and E agree on M, N, K0 and both K1 lengths equal the compile-time K1;
//   * M, N and K0 are multiples of MPerBlock, NPerBlock and K0PerBlock;
//   * the number of K0 block steps is a multiple of BBlockBufferSize;
//   * every D descriptor is [M, N].
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
template <typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename DsGridDesc_M_N,
          typename EGridDesc_M_N>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
              const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
              const DsGridDesc_M_N& ds_grid_desc_m_n,
              const EGridDesc_M_N& e_grid_desc_m_n)
{
    static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                  "wrong! K1 need to be known at compile-time");
    static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                      (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                  "Invalid tuning param!");

    const auto gemm_m  = a_grid_desc_k0_m_k1.GetLength(I1);
    const auto gemm_n  = b_grid_desc_k0_n_k1.GetLength(I1);
    const auto gemm_k0 = a_grid_desc_k0_m_k1.GetLength(I0);

    // A/B/E must describe the same GEMM problem.
    const bool e_shape_matches =
        gemm_m == e_grid_desc_m_n.GetLength(I0) && gemm_n == e_grid_desc_m_n.GetLength(I1);
    const bool k_shape_matches = gemm_k0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
                                 K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
                                 K1 == b_grid_desc_k0_n_k1.GetLength(I2);
    if(!(e_shape_matches && k_shape_matches))
    {
        return false;
    }

    // No padding is done here: every dimension has to be tiled exactly.
    const bool tiles_divide_evenly =
        gemm_m % MPerBlock == 0 && gemm_n % NPerBlock == 0 && gemm_k0 % K0PerBlock == 0;
    if(!tiles_divide_evenly)
    {
        return false;
    }

    // The K loop consumes BBlockBufferSize B slices per iteration, so the number of
    // K0 block steps has to be a multiple of it.
    // TODO: add support for a remainder of K0 block steps
    const bool k_steps_fill_b_buffers = (gemm_k0 / K0PerBlock) % BBlockBufferSize == 0;
    if(!k_steps_fill_b_buffers)
    {
        return false;
    }

    // Every D tensor has to have the same [M, N] shape as E.
    bool all_ds_match = true;
    static_for<0, DsGridDesc_M_N::Size(), 1>{}([&](auto d_idx) {
        all_ds_match = all_ds_match && (gemm_m == ds_grid_desc_m_n[d_idx].GetLength(I0) &&
                                        gemm_n == ds_grid_desc_m_n[d_idx].GetLength(I1));
    });

    return all_ds_match;
}
// Tells whether the K loop needs a main body, i.e. whether K0 spans more than one
// step of BBlockBufferSize * K0PerBlock.
// TODO move this function into GEMM-pipeline class
__host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
{
    constexpr index_t k0_per_loop_step = BBlockBufferSize * K0PerBlock;

    return (K0 / k0_per_loop_step) > 1;
}
// Re-views the [K0, N, K1] B descriptor as an 8-D descriptor:
//   K0 -> [K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread]   (dims 0, 1, 2)
//   N  -> [N / (NXdlPerWave * NWaves * NPerXdl), NXdlPerWave, NWaves, NPerXdl] (dims 3..6)
//   K1 -> passed through                                             (dim 7)
template <typename BGridDesc_K0_N_K1>
__host__ __device__ static constexpr auto
MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
{
    const auto k0_length = b_grid_desc_k0_n_k1.GetLength(I0);
    const auto n_length  = b_grid_desc_k0_n_k1.GetLength(I1);

    const auto split_k0 = make_unmerge_transform(
        make_tuple(k0_length / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread));

    const auto split_n = make_unmerge_transform(make_tuple(
        n_length / (NXdlPerWave * NWaves * NPerXdl), NXdlPerWave, NWaves, NPerXdl));

    const auto keep_k1 = make_pass_through_transform(K1);

    return transform_tensor_descriptor(
        b_grid_desc_k0_n_k1,
        make_tuple(split_k0, split_n, keep_k1),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));
}
// Decomposes the calling thread's block-local id into
// [wave index along M, wave index along N, lane index inside the wave],
// using lengths [MWaves, NWaves, WaveSize].
__device__ static auto GetWaveIdx()
{
    constexpr auto wave_layout = make_tuple(MWaves, NWaves, WaveSize);

    constexpr auto tid_to_wave_adaptor =
        make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(wave_layout)),
                                         make_tuple(Sequence<0, 1, 2>{}),
                                         make_tuple(Sequence<0>{}));

    const index_t tid = get_thread_local_1d_id();

    return tid_to_wave_adaptor.CalculateBottomIndex(make_multi_index(tid));
}
// Decomposes an id into [k index, n index] using lengths
// [xdlops_gemm.K0PerXdlops, NPerXdl]. Run passes the lane index from GetWaveIdx.
__device__ static auto GetWaveKNIdx(const index_t thread_id)
{
    constexpr auto lane_layout = make_tuple(xdlops_gemm.K0PerXdlops, NPerXdl);

    constexpr auto lane_to_kn_adaptor =
        make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(lane_layout)),
                                         make_tuple(Sequence<0, 1>{}),
                                         make_tuple(Sequence<0>{}));

    return lane_to_kn_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
// Packed descriptor of the tile staged in LDS for one C-shuffle step:
// [1, CShuffleMXdlPerWavePerShuffle * MWaves * MPerXdl,
//  1, CShuffleNXdlPerWavePerShuffle * NWaves * NPerXdl].
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
    constexpr index_t m_per_shuffle = CShuffleMXdlPerWavePerShuffle * MWaves * MPerXdl;
    constexpr index_t n_per_shuffle = CShuffleNXdlPerWavePerShuffle * NWaves * NPerXdl;

    return make_naive_tensor_descriptor_packed(
        make_tuple(I1, Number<m_per_shuffle>{}, I1, Number<n_per_shuffle>{}));
}
// Returns the number of LDS bytes one workgroup needs.
//
// Run uses the same LDS region for two things, one after the other:
//   * the A tile during the K loop (B is kept in registers and needs no LDS), and
//   * the C-shuffle staging tile during the epilogue.
// The requirement is therefore the larger of the two.
//
// The staging tile is filled by a copy whose destination element type is
// CShuffleDataType, so it must be sized with sizeof(CShuffleDataType); sizing it with
// sizeof(EDataType) under-allocates whenever CShuffleDataType is wider than EDataType.
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
    // LDS allocation for A: be careful of alignment
    constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
    constexpr auto max_lds_align        = K1;
    constexpr auto a_block_space_size_aligned =
        math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

    // LDS allocation for the C-shuffle staging tile
    constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
    constexpr auto c_block_size =
        c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

    return math::max((a_block_space_size_aligned) * sizeof(ABDataType),
                     c_block_size * sizeof(CShuffleDataType));
}
// Computes one workgroup's [MPerBlock, NPerBlock] output tile.
//
// Data flow:
//   * A is copied global -> LDS by the whole thread block.
//   * B is copied global -> per-thread register buffers (BBlockBufferSize of them);
//     it never goes through LDS.
//   * The accumulator tile is staged through LDS in CShuffleDataType and then combined
//     with the D tensors and written to E.
//
// Template parameters:
//   HasMainK0BlockLoop - whether the K loop has a main body (see
//                        CalculateHasMainK0BlockLoop); the tail always runs.
//
// Parameters:
//   p_a_grid, p_b_grid - global A / B data
//   p_ds_grid          - tuple of global D pointers
//   p_c_grid           - global E (output) data
//   p_shared           - LDS scratch of at least GetSharedMemoryNumberOfByte() bytes
//   a_element_op       - applied while reading A
//   b_element_op       - ignored: B is loaded without an element-wise operation
//   cde_element_op     - applied in the VGPR->LDS copy when DoElementwiseBeforeCShuffle
//                        is true, otherwise in the LDS->global copy
//   *_grid_desc_*      - tensor descriptors for A, B, Ds and E
//   block_2_ctile_map  - maps the 1-D block id to the [M block, N block] tile index
template <bool HasMainK0BlockLoop,
          typename AGridDesc_K0_M_K1,
          typename BGridDesc_K0_N_K1,
          typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename Block2CTileMap>
__device__ static void Run(const ABDataType* __restrict__ p_a_grid,
                           const ABDataType* __restrict__ p_b_grid,
                           DsGridPointer p_ds_grid,
                           EDataType* __restrict__ p_c_grid,
                           void* __restrict__ p_shared,
                           const AElementwiseOperation& a_element_op,
                           const BElementwiseOperation& b_element_op,
                           const CDEElementwiseOperation& cde_element_op,
                           const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                           const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                           const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                               ds_grid_desc_mblock_mperblock_nblock_nperblock,
                           const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                               c_grid_desc_mblock_mperblock_nblock_nperblock,
                           const Block2CTileMap& block_2_ctile_map)
{
    constexpr index_t NumDTensor = DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock::Size();

    // 8-D view of B that lets each thread address its own slice directly in global memory
    const auto b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
        MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(b_grid_desc_k0_n_k1);

    const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
    const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_b_grid, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize());
    const auto ds_grid_buf = generate_tuple(
        [&](auto i) {
            return make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_ds_grid[i],
                ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
        },
        Number<NumDTensor>{});
    auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

    const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

    // divide block work by [M, N]
    const auto block_work_idx =
        block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

    // HACK: this force m/n_block_data_idx_on_grid into SGPR
    const index_t m_block_data_idx_on_grid =
        __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

    // A matrix in LDS memory, dst of blockwise copy
    constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

    // A matrix blockwise copy
    auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
        ThisThreadBlock,
        AElementwiseOperation,
        ck::tensor_operation::element_wise::PassThrough,
        InMemoryDataOperationEnum::Set,
        Sequence<K0PerBlock * BBlockBufferSize, MPerBlock, K1>,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABDataType,
        ABDataType,
        decltype(a_grid_desc_k0_m_k1),
        decltype(a_block_desc_k0_m_k1),
        ABlockTransferSrcAccessOrder,
        Sequence<1, 0, 2>,
        ABlockTransferSrcVectorDim,
        2,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        1,
        1,
        AThreadTransferSrcResetCoordinateAfterRun,
        true,
        1>(a_grid_desc_k0_m_k1,
           make_multi_index(0, m_block_data_idx_on_grid, 0),
           a_element_op,
           a_block_desc_k0_m_k1,
           make_multi_index(0, 0, 0),
           ck::tensor_operation::element_wise::PassThrough{});

    // B is loaded as-is; no element-wise operation is applied on this path
    ignore = b_element_op;

    // B matrix threadwise copy
    constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 =
        make_naive_tensor_descriptor_packed(make_tuple(I1,
                                                       I1,
                                                       Number<K0PerThread>{}, // K0PerThread
                                                       I1,                    // NBlockId
                                                       Number<NXdlPerWave>{}, // repeat
                                                       I1,                    // waves
                                                       I1,                    // NPerXdlops
                                                       Number<K1>{}));

    // One register buffer per B prefetch stage
    auto b_thread_buf = generate_tuple(
        [&](auto i) {
            ignore = i;
            return StaticBuffer<AddressSpaceEnum::Vgpr,
                                ABDataType,
                                b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
                                true>{};
        },
        Number<BBlockBufferSize>{});

    const auto wave_id     = GetWaveIdx();
    const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);

    auto b_threadwise_copy =
        ThreadwiseTensorSliceTransfer_v2<ABDataType,
                                         ABDataType,
                                         decltype(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3),
                                         decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
                                         Sequence<I1,
                                                  I1,
                                                  Number<K0PerThread>{},
                                                  I1,
                                                  Number<NXdlPerWave>{},
                                                  I1,
                                                  I1,
                                                  Number<K1>{}>,
                                         Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                         7,
                                         BBlockTransferSrcScalarPerVector,
                                         BThreadTransferSrcResetCoordinateAfterRun,
                                         true>(
            b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
            make_multi_index(
                0, wave_k_n_id[I0], 0, block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0));

    // GEMM definition
    //   c_mtx += transpose(a_mtx) * b_mtx
    //     a_mtx[K0PerBlock, MPerBlock] is in LDS
    //     b_mtx[K0PerBlock, NPerBlock] is held in per-thread registers (VGPR)
    //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
    //       register
    auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1<
        BlockSize,
        ABDataType,
        AccDataType,
        decltype(a_block_desc_k0_m_k1),
        decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerXdl,
        NPerXdl,
        MXdlPerWave,
        NXdlPerWave,
        K1>{};

    auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

    // LDS allocation for A
    auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<ABDataType*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

    // gridwise GEMM pipeline
    // A advances by a whole LDS tile; B advances by one K0 block per buffered slice
    constexpr auto a_block_slice_copy_step =
        make_multi_index(K0PerBlock * BBlockBufferSize, 0, 0);
    constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);

    // preload data to register and LDS
    {
        // Read
        a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);

        static_for<0, BBlockBufferSize, 1>{}([&](auto ii) {
            b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                  b_grid_buf,
                                  b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  b_thread_buf(Number<ii>{}));
            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                 b_thread_slice_copy_step);
        });

        // Initialize C
        c_thread_buf.Clear();

        // a data write to lds
        a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);

        // main body: runs (K0 / (BBlockBufferSize * K0PerBlock)) - 1 times
        if constexpr(HasMainK0BlockLoop)
        {
            index_t K0BlockMainLoop =
                __builtin_amdgcn_readfirstlane(K0 / (BBlockBufferSize * K0PerBlock));
            index_t i = 0;
            do
            {
                // read the next A tile from global; it is written to LDS further down
                a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
                blockwise_gemm.ResetABlockStartWindow();
                block_sync_lds();

                // consume each buffered B slice, then refill it with the next slice
                static_for<0, BBlockBufferSize, 1>{}([&](auto ii) {
                    blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<ii>{}), c_thread_buf);
                    blockwise_gemm.MoveABlockSliceWindow();
                    s_nop();
                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                          b_grid_buf,
                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                          b_thread_buf(Number<ii>{}));
                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                         b_thread_slice_copy_step);
                });

                // all waves must be done reading the current A tile before it is overwritten
                block_sync_lds();
                a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);

                // move a window (b window was already moved while refilling)
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
                                                    a_block_slice_copy_step);
                i += 1;
            } while(i < (K0BlockMainLoop - 1));
        }

        // tail: consume the last A tile and the last buffered B slices
        {
            block_sync_lds();
            blockwise_gemm.ResetABlockStartWindow();

            static_for<0, BBlockBufferSize, 1>{}([&](auto ii) {
                blockwise_gemm.Run(a_block_buf, b_thread_buf(Number<ii>{}), c_thread_buf);
                blockwise_gemm.MoveABlockSliceWindow();
            });
        }
    }

    // shuffle C and write out
    {
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                      "wrong!");

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        // TODO: hacky, fix it!
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        // TODO: hacky, fix it!
        // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        // The staging tile holds CShuffleDataType elements (see the two copies below),
        // so the LDS view must use that type and the allocation must be large enough.
        static_assert(GetSharedMemoryNumberOfByte() >=
                          static_cast<index_t>(
                              c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
                                  .GetElementSpaceSize() *
                              sizeof(CShuffleDataType)),
                      "LDS allocation is too small for the CShuffleDataType staging tile");

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            make_tuple(
                make_freeze_transform(I0),
                make_unmerge_transform(make_tuple(
                    Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                    M1,                                      // M1 = MWave
                    M2,                                      // M2 * M3 * M4 = MPerXdl
                    M3,
                    M4)),
                make_freeze_transform(I0),
                make_unmerge_transform(make_tuple(
                    Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                    N1,                                      // N1 = NWave
                    N2))),                                   // N2 = NPerXdl
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(
                Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

        // calculate origin of thread output tensor on global memory
        //     blockwise GEMM c matrix starting index
        const auto c_thread_mtx_on_block =
            blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));

        // Exactly one of the two copies below applies cde_element_op; the other one is a
        // pass-through. DoElementwiseBeforeCShuffle selects which.
        // NOTE(review): with DoElementwiseBeforeCShuffle the operation has to be callable
        // in the form ThreadwiseTensorSliceTransfer_v1r3 uses -- confirm for operations
        // other than PassThrough.
        tensor_operation::element_wise::PassThrough pass_through{};
        const auto& vgpr_to_lds_element_op = [&] {
            if constexpr(DoElementwiseBeforeCShuffle)
            {
                return cde_element_op;
            }
            else
            {
                return pass_through;
            }
        };
        const auto& lds_to_global_element_op = [&] {
            if constexpr(!DoElementwiseBeforeCShuffle)
            {
                return cde_element_op;
            }
            else
            {
                return pass_through;
            }
        };

        // shuffle: threadwise copy C from VGPR to LDS
        auto c_thread_copy_vgpr_to_lds =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               CShuffleDataType,
                                               decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                               decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                               conditional_t<
                                                   DoElementwiseBeforeCShuffle,
                                                   CDEElementwiseOperation,
                                                   tensor_operation::element_wise::PassThrough>,
                                               Sequence<CShuffleMXdlPerWavePerShuffle,
                                                        CShuffleNXdlPerWavePerShuffle,
                                                        I1,
                                                        I1,
                                                        M2,
                                                        I1,
                                                        M4,
                                                        I1>,
                                               Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                               7,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>{
                c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                make_multi_index(0,
                                 0,
                                 m_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I1],
                                 m_thread_data_on_block_idx[I2],
                                 m_thread_data_on_block_idx[I3],
                                 m_thread_data_on_block_idx[I4],
                                 n_thread_data_on_block_idx[I2]),
                vgpr_to_lds_element_op()};

        // tuple of reference to C/Ds tensor descriptors
        const auto c_ds_desc_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            generate_tie([&](auto i) -> const auto& // return type should be reference
                         { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                         Number<NumDTensor>{}));

        // tuple of reference to C/Ds buffers
        const auto c_ds_buf_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_buf),
            generate_tie([&](auto i) -> const auto& // return type should be reference
                         { return ds_grid_buf[i]; },
                         Number<NumDTensor>{}));

        // tuple of starting index of C/Ds blockwise copy
        const auto idx_c_ds_block_begin = container_concat(
            make_tuple(make_multi_index(0, 0, 0, 0)),
            generate_tuple(
                [&](auto) {
                    return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
                },
                Number<NumDTensor>{}));

        // blockwise copy C/D/E between LDS and global
        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
            ThisThreadBlock,
            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
            Tuple<EDataType>,
            decltype(c_ds_desc_refs),
            decltype(tie(c_grid_desc_mblock_mperblock_nblock_nperblock)),
            conditional_t<!DoElementwiseBeforeCShuffle,
                          CDEElementwiseOperation,
                          tensor_operation::element_wise::PassThrough>,
            Sequence<static_cast<index_t>(CGlobalMemoryDataOperation)>, // FIXME: make Sequence
                                                                        // support arbitrary type
            Sequence<1,
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     1,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
            CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
            Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
            3,                    // index_t VectorDim,
            CDEShuffleBlockTransferScalarPerVector_NPerBlock,
            sequence_merge_t<
                Sequence<true>,
                uniform_sequence_gen_t<NumDTensor,
                                       false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
            Sequence<false>>                    // ThreadTransferDstResetCoordinateAfterRunFlags
            {c_ds_desc_refs,
             idx_c_ds_block_begin,
             tie(c_grid_desc_mblock_mperblock_nblock_nperblock),
             make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
             lds_to_global_element_op()};

        // space filling curve for threadwise C in VGPR
        constexpr auto sfc_c_vgpr =
            SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                              Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                              Sequence<CShuffleMXdlPerWavePerShuffle,
                                       CShuffleNXdlPerWavePerShuffle,
                                       1,
                                       1,
                                       M2,
                                       1,
                                       M4,
                                       1>>{};

        // space filling curve for shuffled blockwise C/D/E
        constexpr auto sfc_cde_block =
            SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                              Sequence<0, 2, 1, 3>,
                              Sequence<1,
                                       CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                       1,
                                       CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

        static_for<0, num_access, 1>{}([&](auto access_id) {
            // make sure it's safe to write to LDS
            block_sync_lds();

            // each thread write its data from VGPR to LDS
            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          c_thread_buf,
                                          c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          c_shuffle_block_buf);

            // make sure it's safe to read from LDS
            block_sync_lds();

            // each block copy its data from LDS to global
            cde_block_copy_lds_and_global.Run(
                c_ds_desc_refs,
                c_ds_buf_refs,
                tie(c_grid_desc_mblock_mperblock_nblock_nperblock),
                tie(c_grid_buf));

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_lds_and_global_step =
                    sfc_cde_block.GetForwardStep(access_id);

                // move on Ds (source 0 is the LDS staging tile and stays in place)
                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                });

                // move on E
                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(c_grid_desc_mblock_mperblock_nblock_nperblock),
                    I0,
                    cde_lds_and_global_step);
            }
        });
    }
}
};
} // namespace ck