mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
External Interface (#304)
* add client example * clean * clean * reorg * clean up profiler * reorg * clea * fix profiler * function for getinstances * update client example * update client example * update client example * update * update example * update Jenkins file * update cmake * update Jenkins
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceBatchedGemm : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t Batch) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
using DeviceBatchedGemmPtr = std::unique_ptr<
|
||||
DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,54 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation>
|
||||
struct DeviceBatchedGemmReduce : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
void* p_dxs,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
ck::index_t Batch) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename DxsInElementwiseOperation,
|
||||
typename DxsReduceAccElementwiseOperation>
|
||||
using DeviceBatchedGemmReducePtr =
|
||||
std::unique_ptr<DeviceBatchedGemmReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_reduce.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp"
|
||||
#include "ck/device_utility/device_prop.hpp"
|
||||
@@ -111,7 +111,7 @@ __global__ void
|
||||
ignore = d_grid_desc_mblock_mperblock;
|
||||
ignore = compute_base_ptr_of_batch_;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__))
|
||||
#endif
|
||||
}
|
||||
|
||||
// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
|
||||
@@ -169,11 +169,11 @@ template <typename ALayout,
|
||||
index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
: public DeviceGemmReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation>
|
||||
: public DeviceBatchedGemmReduce<AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
DxsInElementwiseOperation,
|
||||
DxsReduceAccElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
|
||||
|
||||
@@ -594,12 +594,12 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
index_t BatchCount)
|
||||
index_t Batch)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
p_ds_grid_{p_ds_grid},
|
||||
BatchCount_(BatchCount),
|
||||
Batch_(Batch),
|
||||
a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
|
||||
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
|
||||
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
|
||||
@@ -637,7 +637,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
DPtrsGlobal p_ds_grid_;
|
||||
index_t BatchCount_;
|
||||
index_t Batch_;
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
@@ -663,7 +663,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
{
|
||||
#if 0
|
||||
{
|
||||
std::cout << "arg.BatchCount_ = " << arg.BatchCount_ << std::endl;
|
||||
std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
|
||||
|
||||
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
|
||||
@@ -692,7 +692,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
|
||||
@@ -728,7 +728,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.Batch_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
@@ -771,7 +771,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.Batch_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
@@ -839,7 +839,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
index_t BatchCount)
|
||||
index_t Batch)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
@@ -856,7 +856,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
c_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
BatchCount};
|
||||
Batch};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -878,7 +878,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
CElementwiseOperation c_element_op,
|
||||
DxsInElementwiseOperation dxs_in_element_op,
|
||||
DxsReduceAccElementwiseOperation dxs_out_element_op,
|
||||
index_t BatchCount) override
|
||||
index_t Batch) override
|
||||
{
|
||||
DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
@@ -896,7 +896,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
|
||||
c_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op,
|
||||
BatchCount);
|
||||
Batch);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "ck/device_utility/device_prop.hpp"
|
||||
@@ -152,7 +152,7 @@ template <typename ADataType,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector>
|
||||
struct DeviceBatchedGemmXdl
|
||||
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
: public DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -339,11 +339,11 @@ struct DeviceBatchedGemmXdl
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t BatchCount)
|
||||
index_t Batch)
|
||||
: p_a_grid_{p_a_grid},
|
||||
p_b_grid_{p_b_grid},
|
||||
p_c_grid_{p_c_grid},
|
||||
BatchCount_(BatchCount),
|
||||
Batch_(Batch),
|
||||
a_grid_desc_k0_m_k1_{
|
||||
DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)},
|
||||
b_grid_desc_k0_n_k1_{
|
||||
@@ -376,7 +376,7 @@ struct DeviceBatchedGemmXdl
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
index_t BatchCount_;
|
||||
index_t Batch_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
@@ -420,7 +420,7 @@ struct DeviceBatchedGemmXdl
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
|
||||
|
||||
const auto K =
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
|
||||
@@ -451,7 +451,7 @@ struct DeviceBatchedGemmXdl
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.Batch_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
@@ -485,7 +485,7 @@ struct DeviceBatchedGemmXdl
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.BatchCount_,
|
||||
arg.Batch_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
|
||||
@@ -539,7 +539,7 @@ struct DeviceBatchedGemmXdl
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t BatchCount)
|
||||
index_t Batch)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
@@ -555,7 +555,7 @@ struct DeviceBatchedGemmXdl
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
BatchCount};
|
||||
Batch};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -573,7 +573,7 @@ struct DeviceBatchedGemmXdl
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t BatchCount) override
|
||||
index_t Batch) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
@@ -589,7 +589,7 @@ struct DeviceBatchedGemmXdl
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
BatchCount);
|
||||
Batch);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemmSplitK : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t KBatch) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
using DeviceGemmSplitKPtr = std::unique_ptr<
|
||||
DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp"
|
||||
#include "ck/device_utility/device_prop.hpp"
|
||||
@@ -57,7 +57,7 @@ template <typename ADataType,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector>
|
||||
struct DeviceGemmXdlSplitK
|
||||
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
: public DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp"
|
||||
#include "ck/device_utility/device_prop.hpp"
|
||||
@@ -59,7 +59,7 @@ template <typename ADataType,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
|
||||
struct DeviceGemmXdlSplitKCShuffle
|
||||
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
: public DeviceGemmSplitK<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -420,21 +420,22 @@ struct DeviceGemmXdlSplitKCShuffle
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||
sizeof(CDataType)));
|
||||
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
|
||||
Reference in New Issue
Block a user