mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Merge branch 'develop' into amd-develop
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core==1.1.1
|
||||
rocm-docs-core==1.1.2
|
||||
sphinxcontrib-bibtex==2.6.2
|
||||
|
||||
@@ -103,7 +103,7 @@ requests==2.31.0
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core==1.1.1
|
||||
rocm-docs-core==1.1.2
|
||||
# via -r requirements.in
|
||||
six==1.16.0
|
||||
# via
|
||||
|
||||
@@ -117,7 +117,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
#define MEDIAN 1
|
||||
if(stream_config.time_kernel_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
|
||||
__func__,
|
||||
@@ -142,7 +142,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
{
|
||||
return 0.0;
|
||||
}
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
}
|
||||
@@ -186,7 +186,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
total_time += cur_time;
|
||||
#endif
|
||||
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
|
||||
#if CK_TIME_KERNEL
|
||||
if(stream_config.time_kernel_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
|
||||
__func__,
|
||||
@@ -41,7 +41,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
|
||||
}
|
||||
|
||||
const int nrepeat = stream_config.nrepeat_;
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
}
|
||||
@@ -95,7 +95,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
#if CK_TIME_KERNEL
|
||||
if(stream_config.time_kernel_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
|
||||
__func__,
|
||||
@@ -117,7 +117,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
}
|
||||
|
||||
const int nrepeat = stream_config.nrepeat_;
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
}
|
||||
|
||||
@@ -795,11 +795,6 @@ struct BlockwiseGemmXdlops_v2
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
__host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
|
||||
: a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
|
||||
{
|
||||
}
|
||||
|
||||
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
|
||||
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
|
||||
{
|
||||
|
||||
@@ -587,7 +587,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1}
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", "
|
||||
<< a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl;
|
||||
|
||||
@@ -658,7 +658,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
{
|
||||
std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
|
||||
|
||||
@@ -719,7 +719,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
@@ -516,7 +516,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
float ave_time = 0;
|
||||
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
|
||||
|
||||
@@ -644,7 +644,7 @@ struct
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
|
||||
@@ -614,7 +614,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
|
||||
@@ -579,7 +579,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
|
||||
@@ -431,7 +431,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
|
||||
@@ -401,7 +401,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl;
|
||||
std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
|
||||
@@ -1272,7 +1272,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
|
||||
float ave_time = 0;
|
||||
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
|
||||
|
||||
@@ -1220,7 +1220,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
|
||||
float ave_time = 0;
|
||||
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1{"
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
|
||||
|
||||
@@ -334,7 +334,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
|
||||
|
||||
@@ -510,7 +510,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
|
||||
|
||||
@@ -514,7 +514,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
|
||||
|
||||
@@ -299,7 +299,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
|
||||
@@ -553,7 +553,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
|
||||
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
|
||||
@@ -337,6 +337,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
elementwise_d_grid_descs_m_n_.reserve(group_count_);
|
||||
ds_grid_pointer_.reserve(group_count_);
|
||||
group_grid_size_.reserve(group_count_);
|
||||
e_ptrs_.reserve(group_count_);
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
|
||||
{
|
||||
@@ -380,7 +381,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
const index_t block_end = grid_size_ + grid_size_grp;
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
group_grid_size_[i] = grid_size_grp;
|
||||
group_grid_size_.push_back(grid_size_grp);
|
||||
// block-to-e-tile map
|
||||
auto grouped_block_2_ctile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
|
||||
@@ -421,9 +422,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n);
|
||||
elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n);
|
||||
ds_grid_pointer_.push_back(p_ds_grid);
|
||||
// Store a copy of E pointers for elementwise kernel destination
|
||||
e_ptrs_.push_back(p_Es[i]);
|
||||
}
|
||||
// Store a copy of E pointers for elementwise kernel destination
|
||||
e_ptrs_ = p_Es;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -467,7 +468,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
gemm_kernel_args_[i].block_start_ = block_start;
|
||||
gemm_kernel_args_[i].block_end_ = block_end;
|
||||
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
index_t tiles = (block_end - block_start) / K_BATCH;
|
||||
std::cout << "block_start: " << block_start << "\n"
|
||||
@@ -494,7 +495,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
arg.karg_.p_c_grid = p_workspace + offset;
|
||||
index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
|
||||
offset += tiles * MPerBlock * NPerBlock;
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "block_start: " << arg.block_start_ << "\n"
|
||||
<< "block_end: " << arg.block_end_ << "\n"
|
||||
@@ -774,13 +775,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(dev_gemm_args),
|
||||
arg.group_count_,
|
||||
arg.gemm_kernel_args_.size(),
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
PassThrough{});
|
||||
|
||||
// Elementwise kernels
|
||||
for(int i = 0; i < arg.group_count_; ++i)
|
||||
for(size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
@@ -818,7 +819,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
|
||||
arg.skipped_group_count_) != arg.group_count_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "The group count is not equal to sum of skipped groups "
|
||||
"and kernel args size!"
|
||||
@@ -835,7 +836,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg);
|
||||
if(not group_arg_valid)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " has invalid GridwiseGemm settings!" << std::endl;
|
||||
|
||||
@@ -620,7 +620,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(
|
||||
M, N, K)))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
|
||||
<< K << "] are not supported by current template parameters!"
|
||||
|
||||
@@ -514,7 +514,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
|
||||
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
|
||||
|
||||
@@ -529,7 +529,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
|
||||
arg.skipped_group_count_) != arg.group_count_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "The group count is not equal to sum of skipped groups "
|
||||
"and kernel args size!"
|
||||
@@ -545,7 +545,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
|
||||
if(not group_arg_valid)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " has invalid GridwiseGemm settings!" << std::endl;
|
||||
|
||||
@@ -935,7 +935,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(!(karg.M % MPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
@@ -952,7 +952,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(!(karg.N % NPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
@@ -971,7 +971,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
auto K_t = karg.KBatch * KPerBlock;
|
||||
if(!(karg.K % K_t == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
|
||||
<< karg.K << " " << __FILE__ << ":" << __LINE__
|
||||
@@ -995,7 +995,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
@@ -1009,7 +1009,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
@@ -1024,7 +1024,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
@@ -1038,7 +1038,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
@@ -1053,7 +1053,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of "
|
||||
@@ -1069,7 +1069,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of "
|
||||
@@ -1084,7 +1084,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
|
||||
@@ -1113,7 +1113,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(!(karg.M % MPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
@@ -1130,7 +1130,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(!(karg.N % NPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
@@ -1149,7 +1149,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
auto K_t = karg.KBatch * KPerBlock;
|
||||
if(!(karg.K % K_t == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
|
||||
<< karg.K << " " << __FILE__ << ":" << __LINE__
|
||||
@@ -1173,7 +1173,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
@@ -1187,7 +1187,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
@@ -1202,7 +1202,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
@@ -1216,7 +1216,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
@@ -1231,7 +1231,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of "
|
||||
@@ -1247,7 +1247,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of "
|
||||
|
||||
@@ -446,7 +446,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(!(karg.M % MPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
@@ -463,7 +463,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(!(karg.N % NPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
@@ -482,7 +482,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
auto K_t = karg.k_batch * K0PerBlock * K1;
|
||||
if(!(karg.K % K_t == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
|
||||
<< karg.K << " " << __FILE__ << ":" << __LINE__
|
||||
@@ -496,7 +496,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
@@ -510,7 +510,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
@@ -525,7 +525,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
@@ -539,7 +539,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
@@ -554,7 +554,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of "
|
||||
@@ -569,7 +569,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of "
|
||||
@@ -584,7 +584,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
const auto num_k_loop = karg.K0Padded / K0PerBlock;
|
||||
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "The number of k loops (" << num_k_loop
|
||||
<< ") value is not supported by GridwiseGemm Pipeline."
|
||||
|
||||
@@ -124,7 +124,7 @@ struct EnvVar
|
||||
|
||||
#define CK_DECLARE_ENV_VAR_STR(name) CK_DECLARE_ENV_VAR(name, std::string, "")
|
||||
|
||||
#define ENV(name) \
|
||||
#define CK_ENV(name) \
|
||||
ck::env::name {}
|
||||
|
||||
template <class EnvVar>
|
||||
|
||||
@@ -129,8 +129,8 @@ constexpr double fp16_to_double_hip(const fp16_hip_t& x)
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr fp16_hip_t float_to_fp16_hip(const float& x)
|
||||
{
|
||||
return __float2half(x);
|
||||
// return static_cast<fp16_hip_t>(x);
|
||||
// return __float2half(x);
|
||||
return static_cast<fp16_hip_t>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
|
||||
@@ -56,7 +56,6 @@ CK_TILE_LEFT_UNARY_OP(+)
|
||||
CK_TILE_LEFT_UNARY_OP(-)
|
||||
CK_TILE_LEFT_UNARY_OP(~)
|
||||
CK_TILE_LEFT_UNARY_OP(!)
|
||||
CK_TILE_LEFT_UNARY_OP(*)
|
||||
|
||||
CK_TILE_BINARY_OP(+)
|
||||
CK_TILE_BINARY_OP(-)
|
||||
|
||||
@@ -88,7 +88,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
|
||||
|
||||
c_m_n_host_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
|
||||
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
|
||||
|
||||
@@ -87,7 +87,7 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
|
||||
c_m_n_host_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
|
||||
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
|
||||
|
||||
@@ -82,7 +82,7 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification,
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
c_m_n_host_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
|
||||
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
|
||||
|
||||
@@ -88,7 +88,7 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification,
|
||||
|
||||
c_m_n_host_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
|
||||
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
|
||||
|
||||
@@ -6,6 +6,12 @@ if(result EQUAL 0)
|
||||
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance)
|
||||
add_dependencies(test_grouped_gemm test_grouped_gemm_two_stage_splitk)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using RRR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, F16, F16, F16>>;
|
||||
using RCR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, F16, F16, F16>>;
|
||||
using RRR_F16_F16_F16_LargeK =
|
||||
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, F16, F16, F16>>;
|
||||
using RCR_F16_F16_F16_LargeK =
|
||||
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, F16, F16, F16>>;
|
||||
using RRR_BF16_BF16_BF16 =
|
||||
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, BF16, BF16, BF16>>;
|
||||
using RCR_BF16_BF16_BF16 =
|
||||
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, BF16, BF16, BF16>>;
|
||||
using RRR_BF16_I8_BF16 =
|
||||
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, BF16, I8, BF16>>;
|
||||
using RCR_BF16_I8_BF16 =
|
||||
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, BF16, I8, BF16>>;
|
||||
|
||||
const std::vector<int> KBATCH{1, 2, 3, 5, 8};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN,
|
||||
RRR_F16_F16_F16,
|
||||
testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK,
|
||||
RCR_F16_F16_F16,
|
||||
testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16,
|
||||
RRR_BF16_BF16_BF16,
|
||||
testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16,
|
||||
RCR_BF16_BF16_BF16,
|
||||
testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16_INT8,
|
||||
RRR_BF16_I8_BF16,
|
||||
testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16_INT8,
|
||||
RCR_BF16_I8_BF16,
|
||||
testing::ValuesIn(KBATCH));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_KN,
|
||||
RRR_F16_F16_F16_LargeK,
|
||||
testing::Values(32, 64));
|
||||
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_NK,
|
||||
RCR_F16_F16_F16_LargeK,
|
||||
testing::Values(32, 64));
|
||||
|
||||
#include "test_grouped_gemm_ut_cases.inc"
|
||||
#include "test_grouped_gemm_two_stage_ut_cases.inc"
|
||||
61
test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc
Normal file
61
test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc
Normal file
@@ -0,0 +1,61 @@
|
||||
#pragma once
|
||||
|
||||
TEST_P(RRR_BF16_BF16_BF16, MNKPadded)
|
||||
{
|
||||
const std::vector<int> Ms{127, 150, 188, 210};
|
||||
constexpr int N = 136;
|
||||
constexpr int K = 280;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RCR_BF16_BF16_BF16, MNKPadded)
|
||||
{
|
||||
const std::vector<int> Ms{127, 150, 188, 210};
|
||||
constexpr int N = 136;
|
||||
constexpr int K = 280;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RRR_BF16_I8_BF16, MNKPadded)
|
||||
{
|
||||
const std::vector<int> Ms{127, 150, 188, 210};
|
||||
constexpr int N = 136;
|
||||
constexpr int K = 280;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), N);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
|
||||
TEST_P(RCR_BF16_I8_BF16, MNKPadded)
|
||||
{
|
||||
const std::vector<int> Ms{127, 150, 188, 210};
|
||||
constexpr int N = 136;
|
||||
constexpr int K = 280;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
const std::vector<int> StrideAs(Ms.size(), K);
|
||||
const std::vector<int> StrideBs(Ms.size(), K);
|
||||
const std::vector<int> StrideCs(Ms.size(), N);
|
||||
|
||||
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "profiler/profile_grouped_gemm_impl.hpp"
|
||||
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
@@ -90,6 +91,58 @@ class TestGroupedGemm : public testing::TestWithParam<int>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedGemmTwoStage : public testing::TestWithParam<int>
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using ELayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using EDataType = std::tuple_element_t<5, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1; // decimal value initialization
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
|
||||
void SetUp() override {}
|
||||
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1,
|
||||
int n_warmup = 1,
|
||||
int n_iter = 10)
|
||||
{
|
||||
bool pass = ck::profiler::profile_grouped_gemm_two_stage_impl<ADataType,
|
||||
BDataType,
|
||||
EDataType,
|
||||
float,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_iter);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
|
||||
Reference in New Issue
Block a user